Mosaic GPU 流水线#

本指南涵盖了使用 Pallas 的 Mosaic GPU 后端进行软件流水线。

有关 Pallas 中流水线 API 的一般概述,我们建议用户首先阅读 软件流水线。Pallas 中的流水线是显式编程的。对于熟悉 Triton 的用户来说,这与 Triton 的编程模型有显著区别,因为在 Triton 中,流水线是由编译器自动完成的优化。

import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu
from jax.experimental import pallas as pl
import numpy as np

使用 Mosaic GPU 进行流水线#

使用 Mosaic GPU 进行流水线的推荐方法是使用 plgpu.emit_pipeline 函数对顺序循环进行流水线(并使用 plgpu.kernel 在 CUDA 网格上并行划分问题)。emit_pipeline 的 API 与 pl.pallas_call 类似,但它暴露了一些额外的 GPU 特定选项。

  • bodygrid 的语义与 pl.pallas_call 中类似。grid 表示要运行多少次 body 函数的调用。与 CUDA 网格不同,流水线网格保证顺序执行。

  • in_specsout_specs 的工作方式也与 pl.pallas_call 类似,但它们也接受 plgpu.BlockSpec 实例,可用于指定 GPU 特定的转换,例如混淆。有关可用转换的更多详细信息,请参阅 内存引用转换

  • max_concurrent_steps 控制并发内存传输的最大数量。使用额外的并发步骤将消耗更多的 SMEM 来存储临时缓冲区,但这可以提高内存子系统的利用率。我们建议对该参数进行自动调整。较低的值(例如 2)有时可以实现更高的占用率(由于 SMEM 使用量较低),这可以提高 ALU 密集型内核的吞吐量,但会引入更多由于硬件处理调度而产生的噪声。较大的值(在 4 到 6 之间)最适合无法利用额外占用率的内核。

  • delay_release 允许用户指定在流水线重用缓冲区之前等待的额外迭代次数。例如,在迭代 0 中复制到 SMEM 的缓冲区,如果设置 delay_release=1max_concurrent_steps=2,则直到迭代 3 才会重用,而不是标准双缓冲策略下的迭代 2。delay_release=1 在您没有等待流水线操作数上的 plgpu.wgmma 操作时是必需的,否则流水线将在 WGMMA 仍在读取缓冲区时开始覆盖它们。这对于某些优化很有用,例如允许多个异步矩阵乘法操作同时进行以填充张量核心流水线,但使用此类策略时必须小心,因为省略此参数将导致数据竞争静默,并且它降低了 emit_pipeline 的效率,因为我们重叠的内存传输更少。

使用 pl.pallas_call 的兼容 API#

作为 emit_pipeline 的替代方案,并为了与 Pallas TPU 保持兼容,Mosaic GPU 还实现了现有的 pl.pallas_call API。默认情况下,Mosaic GPU 上的 pl.pallas_call 会在 CUDA 网格上并行划分您的内核。您可以通过将 plgpu.GPUCompilerParams 对象作为 compiler_params 参数传入来选择加入流水线,该参数指定了以下与流水线相关的选项:

  • dimension_semantics:一个包含 Literal['parallel', 'sequential'] 的元组,指定每个网格维度的迭代语义。parallel 将在 CUDA 网格上划分相应的维度,而 sequential 维度将按顺序进行流水线。注意,如果没有任何维度被标记为 sequential,则不会发生流水线!

  • max_concurrent_steps:与 plgpu.emit_pipeline 中的选项相同。

  • delay_release:与 plgpu.emit_pipeline 中的选项相同。

流水线允许您在网格的顺序迭代之间重用临时缓冲区(例如,用于实现规约)。此外,pallas_call 在使用 Mosaic GPU 后端时支持使用 plgpu.BlockSpec 对象代替 pl.BlockSpec 对象,从而允许您指定 GPU 特定的内存转换。

我们建议用户使用 plgpu.kernel 而不是 pl.pallas_call,因为 plgpu.kernel 支持更多功能(例如指定 warp 组数量和 warp 特化)。

GPU 内存空间#

Refs 主要存在于两种内存空间之一,可以通过 `BlockSpec` 的 `memory_space` 参数显式指定,即 `BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)`。

  • plgpu.GPUMemorySpace.SMEM 在共享内存(SMEM)中分配一个 Ref。SMEM Refs 可以使用数组索引语法进行解引用,将值存储在寄存器中进行计算,即 `x = y_ref[...]`。当使用 `emit_pipeline` 时,此内存空间用于 Ref。

  • plgpu.GPUMemorySpace.GMEM 在全局内存(GMEM/HBM)中分配一个 Ref。在 GMEM 中分配的任何 Ref 都不会进行流水线处理,并且无法直接通过数组索引操作访问值。相反,必须通过 SMEM 访问 GMEM,使用 plgpu.copy_gmem_to_smem 进行读取,或使用 plgpu.copy_smem_to_gmem 进行写入,或通过 plgpu.emit_pipeline 流水线传输到 SMEM。

emit_pipeline 的主要目的是用于重叠张量核心计算与 GMEM 和 SMEM 之间的数据传输,因为 GMEM/SMEM 之间的异步复制具有较长的延迟,但所有张量核心计算必须在寄存器上进行(或在矩阵乘法的情况下为 SMEM Refs)。

示例:Hopper GPU 上的矩阵乘法内核#

让我们从一个专为 Hopper GPU 设计的矩阵乘法示例开始。此内核利用了 Hopper 特有的 `wgmma`(warpgroup 矩阵乘法累加)指令。`wgmma` 由单个 Mosaic GPU 线程发出,并在张量核心上异步运行。

我们的示例内核实现了两个形状为 `[M, K] @ [K, N] = [M, N]` 的矩阵的块状矩阵乘法,其中每个输出块在 CUDA 网格上并行计算。此网格作为外部 `plgpu.kernel` 的 `grid` 参数指定,并并行化矩阵乘法的非收缩维度 M、N。

在程序实例内,我们使用 `plgpu.emit_pipeline` 运行顺序流水线,该流水线对矩阵乘法的收缩维度 K 进行规约。在流水线的每次迭代中,我们加载每个输入矩阵的一个图块,将它们相乘,然后将结果存储在累加器 Ref (`plgpu.ACC`) 中。`plgpu.ACC` 是一种特殊的 Ref 类型,它驻留在寄存器中并保存 WGMMA 的中间结果。一旦我们完成了整个收缩维度的累加,我们就将结果写出到输出 Ref。

为了执行实际的矩阵乘法,我们使用累加器、LHS 和 RHS Ref 作为参数调用 `plgpu.wgmma`,以便将参数推入张量核心流水线。所有 WGMMA 操作按顺序执行,因此这可以看作是将操作推入队列。由于 `wgmma` 是一条异步指令,因此使用 `plgpu.wgmma_wait(N)` 来等待,直到没有超过 N 个 `wgmma` 操作处于飞行状态。在此特定实现中,我们等待 1 个处于飞行状态的 WGMMA,这意味着我们在当前迭代中排队的 WGMMA 将在下一迭代中等待。

  • `wgmma` 要求其参数采用特定格式,该格式在 CUDA 文档 中定义。这些通过输入 BlockSpecs 上的 `TilingTransform` 和 `SwizzleTransform` 转换来实现。请注意,将来转换将由 Mosaic GPU 自动推断,无需手动指定。有关使用此指令的完整详细信息,请参阅 wgmma 参考

  • 我们将 `delay_release` 参数与 `plgpu.wgmma_wait(1)` 结合使用,以始终允许一个 `WGMMA` 操作保持飞行状态,从而确保良好的张量核心利用率。否则,我们将在内核的每次迭代中冲刷张量核心流水线。

def matmul(a, b, tile_m=128, tile_n=128, swizzle=128):
  dtype = jnp.float16
  swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
  tile_k = swizzle_elems
  grid_m = m // tile_m
  grid_k = k // tile_k
  grid_n = n // tile_n
  assert tile_m % swizzle_elems == 0

  # Note: Transforms will be inferred automatically
  # by Mosaic GPU in the future.
  transforms = (
    plgpu.TilingTransform((8, swizzle_elems)),
    plgpu.SwizzleTransform(swizzle),
  )

  def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc):
    def pipeline_step(_, a_smem, b_smem):
      plgpu.wgmma(acc, a_smem, b_smem)
      plgpu.wgmma_wait(1)

    # pl.program_id obtains the index into the grid.
    pid_m = pl.program_id(0)
    pid_n = pl.program_id(1)

    pipeline = plgpu.emit_pipeline(
        pipeline_step,
        in_specs=[
            plgpu.BlockSpec(
                (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
            ),
            plgpu.BlockSpec(
                (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms
            ),
        ],
        grid=(grid_k,),
        max_concurrent_steps=2,
        delay_release=1,
    )

    pipeline(a_gmem, b_gmem)
    # Store WGMMA accumulator to SMEM and then to GMEM.
    o_smem[...] = acc[...].astype(dtype)
    plgpu.commit_smem()
    m_slice = pl.ds(pid_m * tile_m, tile_m)
    n_slice = pl.ds(pid_n * tile_n, tile_n)
    plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
    plgpu.wait_smem_to_gmem(0)

  return plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
      scratch_shapes=dict(
          o_smem=plgpu.SMEM((tile_m, tile_n), jnp.float16),
          acc=plgpu.ACC((tile_m, tile_n), jnp.float32)
      ),
      # grid specifies the CUDA grid.
      # Instances of `kernel` will be executed in parallel over this grid.
      grid=(grid_m, grid_n),
      grid_names=("m", "n"),
  )(a, b)

m = 132 * 128
n = 4 * 128
k = 10 * 64
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)

result = matmul(a, b)

np.testing.assert_allclose(result, a @ b)

Warp 特化#

Warp 特化是一种技术,我们为每个 warp/warpgroup 编程以执行单个任务,从而使 GPU 硬件在运行时具有调度灵活性。回想一下,GPU 中的每个流式多处理器 (SM) 都包含 warp 调度器,可以在 warp 之间切换执行,因此例如当一个 warp 停滞时,它可以开始执行另一个 warp。在实践中,这可能比编程单个指令流更有效,在单个指令流中,编译器必须静态调度操作并尝试最优地重叠它们。

特别是,我们对 Hopper+ GPU 上的 warpgroup 特化感兴趣,在这种情况下,让单独的 warpgroup 发出 TMA(GMEM/SMEM 副本),而其他 warpgroup 执行算术运算可能很有用,因为索引计算和发出 TMA 可能需要大量时间,并可能使 TensorCore 保持空闲。下图描绘了左侧一个标准的、非特化的内核,其中 TMA(异步副本)和矩阵乘法从单个指令流发出,以及右侧一个 warp 特化的版本,其中通信和算术在单独的 warpgroup 上处理。使用一个*消耗的屏障*在特化的 warpgroup 之间进行同步,该屏障向内存 warpgroup 发出信号,表示何时可以开始下一个 TMA。

Pallas 可以通过使用 `plgpu.emit_pipeline_warp_specialized` 辅助函数来启用 Warp 特化。此流水线辅助函数处理内存线程中的所有逻辑,用户只需指定计算线程中完成的工作。它具有与标准 `emit_pipeline` 类似的 API,目前支持以下参数:

plgpu.emit_pipeline_warp_specialized(
  body: Callable,
  *
  grid: tuple[int, ...],
  in_specs: Sequence[pallas_core.BlockSpec] = (),
  out_specs: Sequence[pallas_core.BlockSpec] = (),
  max_concurrent_steps: int,
  compute_context: Callable
  num_compute_wgs: int,
  memory_registers: int
  wg_axis: str,
  memory_thread_idx: int | None = None,
)

此流水线发射器有一些特定于它的参数,即:

  • num_compute_wgs 指定要使用的计算线程/warpgroup 的数量。流水线发射器始终使用单个内存线程,因此在 `plgpu.kernel` 中,您应该指定 `num_threads=num_compute_wgs+1`。

  • memory_registers 控制分配给内存线程的寄存器数量。剩余的寄存器平均分配给计算线程。默认值为 40,应根据是否遇到寄存器溢出进行调整。

  • wg_axis 线程/warpgroup 轴的名称(由 `plgpu.kernel` 的 `thead_name` 参数指定)。

  • memory_thread_idx 指定要指定为内存线程的 Pallas 线程。默认为最后一个线程。

  • compute_context 允许您为仅在计算线程中运行的流水线指定序言/尾声。该函数允许您定义流水线中的循环携带的初始化和消耗。所有计算线程特定的数组都应在此处实例化,以便内存线程不会在寄存器中具体化它们——否则,您可能会因寄存器溢出而导致性能下降。

warp 特化流水线的流水线主体由所有计算线程并行运行,并且 SMEM 在计算线程之间共享,因为它们在同一个 CUDA 块内进行调度。`lax.axis_index` 可以在内核内部使用,以获取 Pallas 线程索引,从而在计算线程之间划分工作。

示例:带有 Warp 特化的矩阵乘法#

以下示例将之前的矩阵乘法示例扩展为使用 warp 特化。此特定内核使用 2 个计算线程,它们操作 RHS 矩阵的不同列但共享同一个 LHS。因此,流水线的每次调用都计算输出矩阵中的 2 个相邻块。

我们使用 `compute_context` 模式来初始化 WGMMA 累加器,并将最终累加器从寄存器复制到 SMEM。在这里,计算上下文在 `compute_thread` 函数中定义。至关重要的是,累加器必须在 `compute_thread` 函数内部创建,以避免在内存线程中分配它,这会浪费寄存器。为了执行 WGMMA,我们将 `wgmma` 指令包装在 `pl.run_state` 中,以创建一个初始化为携带值的累加器 ref。

我们使用 GPU 特定的 `plgpu.kernel` 入口点,而不是使用 `pl.pallas_call` 来调用内核。plgpu.kernel 允许我们通过 `num_threads` 参数指定每个 CUDA 块启动的线程数,并允许我们指定一个 `thread_name`,我们可以在内核中用它来查询 Pallas 线程索引。

def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128,
                            compute_wgs=2):
  dtype = jnp.float16
  elems_128b = swizzle // jnp.dtype(dtype).itemsize
  tile_k = elems_128b
  grid_m = m // tile_m
  grid_k = k // tile_k
  grid_n = n // tile_n
  assert tile_m % elems_128b == 0

  transforms = (
          plgpu.TilingTransform((8, elems_128b)),
          plgpu.SwizzleTransform(128),
      )

  def kernel(a_gmem, b_gmem, o_gmem, o_smem):
    wg_idx = lax.axis_index("wg")
    wg_slice = pl.ds(wg_idx * tile_n, tile_n)
    # pl.program_id obtains the index into the pallas_call grid.
    pid_m = pl.program_id(0)
    pid_n = pl.program_id(1)

    def compute_thread(pipeline):
      acc = plgpu.layout_cast(
          jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
      )
      # yield marks the place where the pipelined loop will be inserted.
      # Its argument are the initial carry values, and its result is the carry
      # value after the loop completes.
      final_acc = pipeline(acc)
      o_smem[:, wg_slice] = final_acc[...].astype(dtype)

    def kernel_body(_, a_smem, b_smem, carry):
      acc = carry
      b_smem_wg = b_smem.at[:, wg_slice]
      def do_wgmma(acc_ref):
        plgpu.wgmma(acc_ref, a_smem, b_smem_wg)
      acc = pl.run_state(do_wgmma)(
                          plgpu.ACC.init(acc))
      return acc

    pipeline = plgpu.emit_pipeline_warp_specialized(
        kernel_body,
        in_specs=[
            plgpu.BlockSpec(
              (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
            ),
            plgpu.BlockSpec(
              (tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms
            ),
        ],
        grid=(grid_k,),
        compute_context=compute_thread,
        max_concurrent_steps=2,
        num_compute_wgs=compute_wgs,
        memory_registers=40,
        memory_thread_idx=2,
        wg_axis="wg",
    )
    # Call the pipeline
    pipeline(a_gmem, b_gmem)
    # Copy the output from SMEM to GMEM.
    plgpu.commit_smem()
    m_slice = pl.ds(pid_m * tile_m, tile_m)
    n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2)
    plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
    plgpu.wait_smem_to_gmem(0)

  return plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
      scratch_shapes=dict(
          o_smem=plgpu.SMEM((tile_m, tile_n * 2), jnp.float16)
      ),
      grid=(grid_m, grid_n // 2),
      grid_names=("m", "n"),
      num_threads=3,  # 2 compute, 1 memory.
      thread_name="wg"
  )(a, b)

m = 132 * 128
n = 4 * 128
k = 10 * 64
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)

result = matmul_warp_specialized(a, b)

np.testing.assert_allclose(result, a @ b)