Mosaic GPU 流水线#

本指南涵盖了使用 Pallas 的 Mosaic GPU 后端进行软件流水线化的相关内容。

关于 Pallas 流水线 API 的概述,建议用户先阅读 软件流水线。Pallas 中的流水线是显式编程的。对于熟悉 Triton 的用户来说,这是一个显著的编程模型差异,因为在 Triton 中,流水线是由编译器自动执行的优化。

import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu
from jax.experimental import pallas as pl
import numpy as np

Mosaic GPU 流水线#

使用 Mosaic GPU 进行流水线化的推荐方法是使用 plgpu.emit_pipeline 函数对顺序循环进行流水线化(并使用 plgpu.kernel 在 CUDA 网格上并行划分任务)。emit_pipeline 遵循与 pl.pallas_call 类似的 API,区别在于它提供了一些额外的 GPU 特定选项。

  • bodygrid 的语义与 pl.pallas_call 中的类似。grid 表示要运行多少次 body 函数。与 CUDA 网格不同,流水线网格保证按顺序运行。

  • in_specsout_specs 的工作方式也与 pl.pallas_call 类似,区别在于它们还接受 plgpu.BlockSpec 实例,可用于指定 GPU 特定的转换,例如数据重排(swizzling)。有关可用转换的更多详细信息,请参阅 内存引用转换

  • max_concurrent_steps 控制并发内存传输的最大数量。使用额外的并发步骤会消耗更多的共享内存(SMEM)来存放临时缓冲区,但可以提高内存子系统的利用率。我们建议对此参数进行自动调优。较低的值(例如 2)有时可以实现更高的占用率(因为 SMEM 使用量较少),这可以提高 ALU 密集型 Kernel 的吞吐量,但会因为硬件负责调度而引入更多噪声。较大的值(在 4 到 6 之间)最适合无法利用额外占用率的 Kernel。

  • delay_release 允许用户指定在缓冲区被流水线重用之前额外等待的迭代次数。例如,在 delay_release=1max_concurrent_steps=2 的情况下,在迭代 0 时复制到 SMEM 的缓冲区直到迭代 3 才会被重用,而不是标准双缓冲策略中的迭代 2。如果您不对流水线操作数执行 plgpu.wgmma 等待操作,则必须使用 delay_release=1,否则流水线会在 WGMMA 仍在读取缓冲区时就开始覆盖它们。这对于某些优化(例如允许同时进行多个异步矩阵乘法以保持 Tensor Core 流水线填满)很有用,但使用此类策略时必须小心,因为省略此参数会导致静默数据竞争,并且它会降低 emit_pipeline 的效率,因为我们重叠的内存传输更少了。

使用 pl.pallas_call 的兼容性 API#

作为 emit_pipeline 的替代方案,为了保持与 Pallas TPU 的兼容性,Mosaic GPU 也实现了现有的 pl.pallas_call API。默认情况下,Mosaic GPU 上的 pl.pallas_call 将在 CUDA 网格上并行划分您的 Kernel。您可以通过传入 plgpu.CompilerParams 对象作为 compiler_params 参数来选择流水线化,该对象指定了以下与流水线化相关的选项。

  • dimension_semantics:由 Literal['parallel', 'sequential'] 组成的元组,指定每个网格维度的迭代语义。parallel 会在 CUDA 网格上划分相应维度,而 sequential 维度将按顺序进行流水线化。注意,如果没有标记为 sequential 的维度,则不会发生流水线化!

  • max_concurrent_steps:与 plgpu.emit_pipeline 中的选项相同。

  • delay_release:与 plgpu.emit_pipeline 中的选项相同。

流水线允许您在网格的顺序迭代中重用暂存缓冲区(例如用于实现归约)。此外,当使用 Mosaic GPU 后端时,pallas_call 支持使用 plgpu.BlockSpec 对象代替 pl.BlockSpec 对象,允许您指定 GPU 特定的内存转换。

我们建议用户使用 plgpu.kernel 而不是 pl.pallas_call,因为 plgpu.kernel 支持更多功能(例如指定 warp 组数量和 Warp 专用化)。

GPU 内存空间#

引用(Refs)主要存在于两个内存空间之一,可以通过 BlockSpecmemory_space 参数显式指定,即 BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)

  • plgpu.GPUMemorySpace.SMEM 在共享内存(SMEM)中分配一个 Ref。SMEM Refs 可以使用数组索引语法进行解引用,以将值存储在寄存器中进行计算,例如 x = y_ref[...]。这是使用 emit_pipeline 时 Ref 所使用的内存空间。

  • plgpu.GPUMemorySpace.GMEM 在全局内存(GMEM/HBM)中分配一个 Ref。在 GMEM 中分配的任何 Refs 都不会被流水线化,且无法通过数组索引操作直接访问值。相反,GMEM 必须通过 SMEM 访问:读取时使用 plgpu.copy_gmem_to_smem,写入时使用 plgpu.copy_smem_to_gmem,或者使用 plgpu.emit_pipeline 流水线化到 SMEM 中。

emit_pipeline 的主要目的是将 TensorCore 计算与 GMEM 和 SMEM 之间的数据传输重叠,因为 GMEM/SMEM 之间的异步复制具有很长的延迟,而所有 TensorCore 计算都必须在寄存器(或矩阵乘法情况下的 SMEM Refs)上进行。

示例:Hopper GPU 上的矩阵乘法 Kernel#

让我们从一个旨在在 Hopper GPU 上运行的矩阵乘法示例开始。此 Kernel 利用了 Hopper 特有的 wgmma(Warp 组矩阵乘加)指令。wgmma 由单个 Mosaic GPU 线程发出,并在 TensorCore 上异步运行。

我们的示例 Kernel 实现了两个形状为 [M, K] @ [K, N] = [M, N] 的矩阵的块矩阵乘法,其中每个输出块都在 CUDA 网格上并行计算。此网格被指定为外部 plgpu.kernelgrid 参数,并在矩阵乘法的非缩约维度 M、N 上进行并行化。

在程序实例中,我们使用 plgpu.emit_pipeline 运行一个顺序流水线,该流水线在矩阵乘法的缩约维度 K 上进行归约。在流水线的每次迭代中,我们从每个输入矩阵加载一个 tile,将它们相乘,然后将结果存储在累加器 Ref (plgpu.ACC) 中。plgpu.ACC 是一种特殊的 Ref 类型,它位于寄存器中并保存 WGMMA 的中间结果。一旦我们在整个缩约维度上完成了累加,我们将结果写出到输出 Ref 中。

为了执行实际的矩阵乘法,我们以累加器、LHS 和 RHS Refs 作为参数调用 plgpu.wgmma,以将参数推送到 TensorCore 流水线中。所有 WGMMA 操作都按顺序执行,因此这可以看作是将操作推入队列。由于 wgmma 是异步指令,plgpu.wgmma_wait(N) 用于等待直到正在进行的 wgmma 操作不超过 N 个。在这个特定的实现中,我们等待 1 个正在进行的 WGMMA,这意味着我们在当前迭代中排队的 WGMMA 将在下一次迭代中被等待。

  • wgmma 要求其参数以特定的格式存在,定义在 CUDA 文档中。这些格式由输入 BlockSpecs 上的 TilingTransformSwizzleTransform 转换实现。请注意,未来 Mosaic GPU 将自动推断转换,无需手动指定。有关使用此指令的完整详细信息,请参阅 wgmma 参考

  • 我们将 delay_release 参数与 plgpu.wgmma_wait(1) 结合使用,始终允许一个 WGMMA 操作处于进行中,以确保良好的 TensorCore 利用率。如果没有这一点,我们将在 Kernel 的每次迭代中刷新 TensorCore 流水线。

def matmul(a, b, tile_m=128, tile_n=128, swizzle=128):
  dtype = jnp.float16
  swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
  tile_k = swizzle_elems
  grid_m = m // tile_m
  grid_k = k // tile_k
  grid_n = n // tile_n
  assert tile_m % swizzle_elems == 0

  # Note: Transforms will be inferred automatically
  # by Mosaic GPU in the future.
  transforms = (
    plgpu.TilingTransform((8, swizzle_elems)),
    plgpu.SwizzleTransform(swizzle),
  )

  def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc):
    def pipeline_step(_, a_smem, b_smem):
      plgpu.wgmma(acc, a_smem, b_smem)
      plgpu.wgmma_wait(1)

    # pl.program_id obtains the index into the grid.
    pid_m = pl.program_id(0)
    pid_n = pl.program_id(1)

    pipeline = plgpu.emit_pipeline(
        pipeline_step,
        in_specs=[
            plgpu.BlockSpec(
                (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
            ),
            plgpu.BlockSpec(
                (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms
            ),
        ],
        grid=(grid_k,),
        max_concurrent_steps=2,
        delay_release=1,
    )

    pipeline(a_gmem, b_gmem)
    # Store WGMMA accumulator to SMEM and then to GMEM.
    o_smem[...] = acc[...].astype(dtype)
    plgpu.commit_smem()
    m_slice = pl.ds(pid_m * tile_m, tile_m)
    n_slice = pl.ds(pid_n * tile_n, tile_n)
    plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
    plgpu.wait_smem_to_gmem(0)

  return plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
      scratch_shapes=dict(
          o_smem=plgpu.SMEM((tile_m, tile_n), jnp.float16),
          acc=plgpu.ACC((tile_m, tile_n), jnp.float32)
      ),
      # grid specifies the CUDA grid.
      # Instances of `kernel` will be executed in parallel over this grid.
      grid=(grid_m, grid_n),
      grid_names=("m", "n"),
  )(a, b)

m = 132 * 128
n = 4 * 128
k = 10 * 64
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)

result = matmul(a, b)

np.testing.assert_allclose(result, a @ b)

Warp 专用化 (Warp Specialization)#

Warp 专用化是一种技术,我们对每个 warp/warpgroup 进行编程以执行单个任务,从而赋予 GPU 硬件在运行时灵活调度它们的能力。回想一下,GPU 中的每个流式多处理器 (SM) 都包含 warp 调度器,可以在 warp 之间切换执行,例如当一个 warp 停顿时,它可以开始执行另一个 warp。在实践中,这比编写单个指令流更具性能,因为在单个指令流中,编译器必须静态调度操作并尝试以最佳方式重叠它们。

特别是,我们对 Hopper+ GPU 上的 Warp 组专用化感兴趣,让一个独立的 Warp 组发出 TMA(GMEM/SMEM 复制)与执行算术运算的 Warp 组分开是有益的,因为索引计算和发出 TMA 可能占用大量时间,并可能导致 TensorCore 空闲。下图描绘了左侧的标准非专用化 Kernel,其中 TMA(异步复制)和矩阵乘法从单个指令流发出,右侧是 Warp 专用化版本,其中通信和算术在不同的 Warp 组上处理。使用消费栅栏 (consumed barrier) 在专用 Warp 组之间进行同步,这会向内存 Warp 组发出信号,表明何时可以安全地开始下一个 TMA。

可以通过使用 plgpu.emit_pipeline_warp_specialized 辅助程序在 Pallas 中启用 Warp 专用化。此流水线辅助程序处理内存线程中的所有逻辑,用户只需要指定计算线程中完成的工作。它与标准 emit_pipeline 共享相似的 API,目前支持以下参数。

plgpu.emit_pipeline_warp_specialized(
  body: Callable,
  *
  grid: tuple[int, ...],
  in_specs: Sequence[pallas_core.BlockSpec] = (),
  out_specs: Sequence[pallas_core.BlockSpec] = (),
  max_concurrent_steps: int,
  compute_context: Callable
  num_compute_wgs: int,
  memory_registers: int
  wg_axis: str,
  memory_thread_idx: int | None = None,
)

此流水线发射器有一些特定参数,它们是:

  • num_compute_wgs 指定要使用的计算线程/Warp 组数量。流水线发射器总是使用单个内存线程,因此在 plgpu.kernel 中,您应该指定 num_threads=num_compute_wgs+1

  • memory_registers 控制分配给内存线程的寄存器数量。剩余的寄存器在计算线程之间平均分配。默认值为 40,应根据是否遇到寄存器溢出(spill)向上或向下调整。

  • wg_axis 是线程/Warp 组轴的名称(由 plgpu.kernelthread_name 参数指定)。

  • memory_thread_idx 指定将哪个 Pallas 线程指定为内存线程。默认为最后一个线程。

  • compute_context 使您能够指定流水线的序幕/尾声(prologue/epilogue),这些代码仅在计算线程中运行。该函数允许您定义流水线中循环进位(loop carry)的初始化和消费。所有计算线程特定的数组都应在此处实例化,以便内存线程不会在寄存器中实现它们——否则,您可能会因寄存器溢出而遇到减速。

Warp 专用化流水线的流水线主体由所有计算线程并行运行,由于它们在同一个 CUDA 块内调度,因此共享 SMEM。lax.axis_index 可在 Kernel 内部用于获取 Pallas 线程索引,以便在计算线程之间分配工作。

示例:使用 Warp 专用化的矩阵乘法#

以下示例扩展了之前的矩阵乘法示例以使用 Warp 专用化。此特定 Kernel 使用 2 个计算线程,它们对 RHS 矩阵的不同列进行操作,但共享相同的 LHS。因此,每次调用流水线都会计算输出矩阵中的 2 个相邻块。

我们使用 compute_context 模式来初始化 WGMMA 累加器,并将最终累加器从寄存器复制到 SMEM 中。在这里,计算上下文在函数 compute_thread 中定义。累加器必须在 compute_thread 函数内部创建,这一点至关重要,以避免将其分配在内存线程中,这将浪费寄存器。为了执行 WGMMA,我们将 wgmma 指令包装在 pl.run_state 中,以创建初始化为进位值的累加器 ref。

我们没有使用 pl.pallas_call 来调用 Kernel,而是使用 GPU 特定的 plgpu.kernel 入口点。plgpu.kernel 允许我们通过 num_threads 参数指定每个 CUDA 块启动的线程数,并允许我们指定一个 thread_name,我们可以在 Kernel 内部使用它来查询 Pallas 线程索引。

def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128,
                            compute_wgs=2):
  dtype = jnp.float16
  elems_128b = swizzle // jnp.dtype(dtype).itemsize
  tile_k = elems_128b
  grid_m = m // tile_m
  grid_k = k // tile_k
  grid_n = n // tile_n
  assert tile_m % elems_128b == 0

  transforms = (
          plgpu.TilingTransform((8, elems_128b)),
          plgpu.SwizzleTransform(128),
      )

  def kernel(a_gmem, b_gmem, o_gmem, o_smem):
    wg_idx = lax.axis_index("wg")
    wg_slice = pl.ds(wg_idx * tile_n, tile_n)
    # pl.program_id obtains the index into the pallas_call grid.
    pid_m = pl.program_id(0)
    pid_n = pl.program_id(1)

    def compute_thread(pipeline):
      acc = plgpu.layout_cast(
          jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
      )
      # yield marks the place where the pipelined loop will be inserted.
      # Its argument are the initial carry values, and its result is the carry
      # value after the loop completes.
      final_acc = pipeline(acc)
      o_smem[:, wg_slice] = final_acc[...].astype(dtype)

    def kernel_body(_, a_smem, b_smem, carry):
      acc = carry
      b_smem_wg = b_smem.at[:, wg_slice]
      def do_wgmma(acc_ref):
        plgpu.wgmma(acc_ref, a_smem, b_smem_wg)
      acc = pl.run_state(do_wgmma)(
                          plgpu.ACC.init(acc))
      return acc

    pipeline = plgpu.emit_pipeline_warp_specialized(
        kernel_body,
        in_specs=[
            plgpu.BlockSpec(
              (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
            ),
            plgpu.BlockSpec(
              (tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms
            ),
        ],
        grid=(grid_k,),
        compute_context=compute_thread,
        max_concurrent_steps=2,
        num_compute_wgs=compute_wgs,
        memory_registers=40,
        memory_thread_idx=2,
        wg_axis="wg",
    )
    # Call the pipeline
    pipeline(a_gmem, b_gmem)
    # Copy the output from SMEM to GMEM.
    plgpu.commit_smem()
    m_slice = pl.ds(pid_m * tile_m, tile_m)
    n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2)
    plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
    plgpu.wait_smem_to_gmem(0)

  return plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
      scratch_shapes=dict(
          o_smem=plgpu.SMEM((tile_m, tile_n * 2), jnp.float16)
      ),
      grid=(grid_m, grid_n // 2),
      grid_names=("m", "n"),
      num_threads=3,  # 2 compute, 1 memory.
      thread_name="wg"
  )(a, b)

m = 132 * 128
n = 4 * 128
k = 10 * 64
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)

result = matmul_warp_specialized(a, b)

np.testing.assert_allclose(result, a @ b)