TPU 流水线#

本指南作为 TPU 特定流水线问题的参考。我们将回顾 TPU 上的内存层次结构和计算单元,以及流水线 API 的 TPU 特定功能。有关流水线的更通用概述,请参阅 软件流水线

#@title Imports

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np

TPU 及其内存空间#

TPU 及其 TensorCore 由内存空间(数组可以驻留的地方)、寄存器(临时存储标量和数组值)和计算单元(用寄存器中的值进行计算)组成。下面是 TPU 的一个图示,其中 xy 是驻留在高带宽内存(HBM)中的数组。

TPU Memory Space Cartoon.png

让我们更详细地讨论图中的组件。

  • 内存空间:TPU 具有高带宽内存(HBM),这通常是我们所说的“设备内存”。还有向量内存(VMEM),这是一个用于存储向量和数组值的缓存,以及标量内存(SMEM),这是一个设计用于存储标量值的缓存。

  • 寄存器:TensorCore 有两种主要的寄存器类型:向量寄存器(VREGs)存储数组值,标量寄存器(SREGs)存储标量值。值可以从各自的缓存(VMEM 用于 VREGs,SMEM 用于 SREGs)加载到内存中。

  • 计算单元:TensorCore 具有标量单元、向量单元(VPU)和矩阵单元(MXU),可以进行数值计算。这些计算单元中的每一个都可以异步运行,但这由 TPU 编译器管理,因此从程序员的角度来看,TPU 程序是单线程的。计算单元操作驻留在 SREGs 和 VREGs 中的值,并将输出值也放入这些寄存器中。

TPU 特定的流水线功能#

Pallas TPU 支持以下平台特定功能。

TPU 内存空间#

Pallas 将 TPU 内存层次结构的所有级别公开给用户。下表将 Pallas TPU 内存空间映射到它们的标准内存类型(DRAM/SRAM)。

Pallas 枚举

TPU 内存空间

类型(DRAM/SRAM)

pltpu.MemorySpace.ANY

HBM(通常)或 VMEM

DRAM

pltpu.MemorySpace.VMEM

VMEM

SRAM

pltpu.MemorySpace.SMEM

SMEM

SRAM

pltpu.MemorySpace.SEMAPHORE

信号量

SRAM

  • MemorySpace.VMEM 表示向量 SRAM。如果未指定,则它是默认内存空间。

  • MemorySpace.SMEM 表示标量 SRAM。只能对 SMEM 执行标量加载和存储。

  • MemorySpace.ANY 是对编译器的提示,表明内存空间不受限制。在大多数情况下,XLA 会将此缓冲区放置在 HBM 中。分配给 ANY 内存空间的缓冲区不能使用数组索引语法(例如 x[...])正常取消引用。相反,我们必须首先使用 pltpu.sync_copypltpu.async_copy 将值复制到 VMEM 或 SMEM 缓冲区。

  • MemorySpace.SEMAPHORE 用于分配信号量以构建屏障或跟踪异步操作。也可以从内核返回信号量以构建异步内核 - 这是一个实验性功能;有关更多详细信息,请参阅 Pallas 异步操作

TPU 上的流水线通常在 HBM(DRAM)和 VMEM(向量 SRAM)之间进行。TPU 上 pallas_call 的默认行为是 pallas_call 的参数假定驻留在 HBM 中,并且用户内核体的输入存储在 VMEM 中。

虽然这并非 TPU 流水线特有,但有可能手动控制输入和输出缓冲区的内存空间。您可以在 BlockSpec 上指定 memory_space 参数。请注意,除非将 memory_space 标记为 VMEM,否则不允许流水线。内存空间也可用于通过 pallas_call 上的 scratch_shapes 参数指定内核的暂存参数。暂存缓冲区在内核迭代之间持久存在,可用于存储中间结果,例如部分累加和约简。暂存缓冲区必须驻留在 VMEMSMEMSEMAPHORE 中。

作为在内核中使用多个手动内存空间分配的示例,以下程序将 HBM 缓冲区 x_hbm_ref 的一个切片复制到暂存 VMEM 缓冲区 scratch_vmem_ref 中,然后使用它进行算术运算并将结果存储到输出 VMEM 缓冲区中。

def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref):
  pltpu.sync_copy(x_hbm_ref.at[0:1], scratch_vmem_ref)
  out_vmem_ref[...] = scratch_vmem_ref[...] + 1

x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)
out = pl.pallas_call(hbm_vmem_kernel,
  in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],
  out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),
  scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)
)(x)

np.testing.assert_allclose(out, x[0:1] + 1)

多缓冲#

可以通过 pl.BlockSpec 上的 pipeline_mode 选项按参数指定流水线的多缓冲。为此,请将 pl.Buffered 对象传递给 pl.BlockSpec,指定为此特定参数分配的缓冲区数量。

pl.BlockSpec(
  pipeline_mode=pl.Buffered(buffer_count=buffer_count)
)

所有输入和输出的默认缓冲区数量为 2。

pltpu.emit_pipeline#

pltpu.emit_pipeline 是 Pallas 中实现的一个流水线 API,它允许您在内核内部构建流水线,而不仅仅是在内核入口处。这比使用 pl.pallas_call 有几个用例,例如:

  • 用于构建嵌套流水线。例如,一个在芯片之间通信的外部流水线,以及一个执行 HBM-VMEM 流水线的内部流水线。

  • 用于使用 emit_pipeline 特有的功能,例如前瞻预取和动态块形状(下面将介绍)。

pltpu.emit_pipeline 遵循与 pl.pallas_call 类似的签名,并要求您指定一个 kernel 主体、一个网格以及输入和输出的块规范。

def emit_pipeline(
    kernel: Callable,
    grid: tuple[int],
    in_specs: PyTree[BlockSpec] = None,
    out_specs: PyTree[BlockSpec] = None,
    dimension_semantics: tuple[GridDimensionSemantics] = None,
    core_axis: int | None = None,
) -> Callable:
  ... # Returns a custom pipeline given an inner kernel and BlockSpecs.

dimension_semanticscore_axis 参数用于将内核网格分区到 Megacore 上(见下文)。

前瞻预取#

前瞻预取是一项流水线功能,其中流水线将尝试在缓冲槽可用时预取下一个输入块,而不是在直接使用它之前的迭代。例如,如果内核的网格为 (8,),并且每次迭代要获取的块索引为 0, 0, 0, 0, 1, 1, 1, 1,那么前瞻预取将在迭代 0 上开始获取块 01,而标准流水线调度将在迭代 0 上获取块 0,但直到迭代 3 才开始获取块 1。执行前瞻会有少量控制流开销,因此默认情况下是禁用的。

当每个块的计算工作量可变时,前瞻预取尤其有用,例如当某些块包含跳过的工作或减少的工作量时。在这些情况下,在前一个迭代中可能没有足够的工作来在需要该块的步骤之前完全与内存传输重叠。因此,我们希望在流水线的早期开始获取块。

前瞻预取可以与多缓冲结合使用,也可以通过将 pl.Buffered 传递给 pipeline_mode 参数来启用。

pl.BlockSpec(
  pipeline_mode=pl.Buffered(buffer_count=buffer_count, use_lookahead=True)
)

动态块形状#

pltpu.emit_pipeline 支持对具有动态但有界的形状的块进行流水线处理。为了指定这种块形状,块中大小动态的维度应该用 pl.BoundedSlice(max_size) 标记,而不是静态整数大小,其中 max_size 是块的最大大小。此外,index_map 返回的相应索引应该是通过 pl.ds(start, size) 构建的动态切片,其中 startsize 都是元素索引(而不是块索引),并且可以是动态的。

以下是一个具有动态第一个维度的块规范的示例。

pl.BlockSpec(
   block_shape=(pl.BoundedSlice(32), 256),
   index_map=lambda *grid_idxs: (pl.ds(start, end), 0),
)
# The following kernel copies `x` to the output in dynamic-sized chunks
# passed in via `slices`.

def dynamic_block_example_kernel(x_hbm, slices_hbm, o_hbm, slices_smem):
    pltpu.sync_copy(slices_hbm, slices_smem)  # Copy slices into SMEM.
    def pipeline_body(x_vmem, o_vmem):
        o_vmem[...] = x_vmem[...]
    def index_map(i):
        start = slices_smem[i, 0]
        size = slices_smem[i, 1] - slices_smem[i, 0]
        return (pl.ds(start, size), 0)
    block_spec = pl.BlockSpec(block_shape=(pl.BoundedSlice(8), 128),
                              index_map=index_map)
    pltpu.emit_pipeline(
        pipeline_body,
        grid=(slices.shape[0],),
        in_specs=[block_spec],
        out_specs=block_spec
    )(x_hbm, o_hbm)

x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)
slices = jnp.array([[0, 2], [2, 3], [3, 5], [5, 8]], dtype=jnp.int32)

hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
out = pl.pallas_call(dynamic_block_example_kernel,
               in_specs=[hbm_block_spec, hbm_block_spec],
               out_specs=hbm_block_spec,
               out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
               scratch_shapes=(pltpu.MemorySpace.SMEM(slices.shape, jnp.int32),)
              )(x, slices)

np.testing.assert_allclose(x, out)

Megacore 配置中的 TPU#

某些 TPU 芯片有两个 TensorCores,但对 JAX 用户来说表现为一个设备。这被称为“megacore”。单独的 TensorCores 拥有各自独立的 VMEM、VREGs、SMEM、SREGs 和计算单元,但*共享 HBM*。

TPU Memory Space Cartoon (Megacore).png

从概念上讲,Megacore 中的 TPU 行为类似于非常简单的 GPU,即它们只有两个线程。我们如何修改内核以同时利用两个 TensorCores?

基本思想是,如果我们计算中存在独立并行的维度,我们可以将这些维度分配到 TensorCores 上。通过为 pallas_call 提供一个名为 dimension_semantics 的注解,我们可以指明哪些维度是可并行的。

def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
  # Load x and y from VMEM into VREGs
  x_vregs = x_vmem_ref[:, :]
  y_vregs = y_vmem_ref[:, :]
  # Execute a vectorized add
  z_vregs = x_vregs + y_vregs
  # Store the output values in VREGs back into VMEM
  z_vmem_ref[:, :] = z_vregs

def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel",))
  )(x, y)

x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

dimension_semantics 应该是一个与 grid 长度相同的元组,其中每个条目是 "parallel""arbitrary""parallel" 向 Pallas 表明,对应于该维度的 for 循环的迭代可以独立执行,而不会影响程序的正确性。"arbitrary" 向 Pallas 表明,不能对该网格维度做任何假设,因此不能对其进行并行化。

通过指定 dimension_semantics,我们现在可以在每个 TensorCore 上同时执行内核。Pallas 将自动处理网格的分割。

请注意,Megacore 目前仅在 TPU v4 和 TPU v5p 上可用。在其他平台上提供 dimension_semantics 注解是无操作的,但*不*指定它将导致只使用一个 TensorCore(即使有多个可用)。

在使用 pltpu.emit_pipeline 时,应将 core_axis 传递给 emit_pipelinecore_axis 应该是用于分区网格的并行网格轴的索引。例如,以下模板可用于在领先的并行网格维度上分区内核。

def kernel_body(...):
  def inner_pipeline_body(...):
    ...
  pltpu.emit_pipeline(inner_pipeline_body,
                      grid=(4, 4), 
                      core_axis=0,
                      dimension_semantics=("parallel", "sequential"))

pl.pallas_call(
      kernel_body,
      grid=(num_cores,),
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel",))
  )