Mosaic GPU 流水线#

本指南介绍了使用 Pallas 的 Mosaic GPU 后端进行软件流水线操作。

有关 Pallas 中流水线 API 的一般概述,我们建议用户首先阅读软件流水线。Pallas 中的流水线是显式编程的。对于熟悉 Triton 的用户来说,这是一个显著的编程模型差异,因为在 Triton 中,流水线是编译器自动完成的优化。

import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu
from jax.experimental import pallas as pl
import numpy as np

使用 Mosaic GPU 进行流水线操作#

使用 Mosaic GPU 进行流水线操作的推荐方法是使用 plgpu.emit_pipeline 函数对顺序循环进行流水线操作(并使用 plgpu.kernel 在 CUDA 网格中并行划分问题)。emit_pipeline 遵循与 pl.pallas_call 类似的 API,只是它公开了一些额外的 GPU 特定选项。

  • bodygrid 具有与 pl.pallas_call 中类似的语义。grid 表示要运行 body 函数的调用次数。与 CUDA 网格不同,流水线网格保证按顺序运行。

  • in_specsout_specs 的工作方式也与 pl.pallas_call 类似,只是它们还接受 plgpu.BlockSpec 实例,可用于指定 GPU 特定的转换,例如交错。有关可用转换的更多详细信息,请参阅内存引用转换

  • max_concurrent_steps 控制最大并发内存传输数量。使用额外的并发步骤将消耗更多 SMEM 来保存临时缓冲区,但可以提高内存子系统的利用率。我们建议对该参数进行自动调优。较低的值(例如 2)有时可以实现更高的占用率(由于较低的 SMEM 使用),这可以提高 ALU 密集型内核的吞吐量,但由于硬件负责调度,会引入更多噪声。较大的值(4 到 6 之间)对于无法利用额外占用率的内核效果最佳。

  • delay_release 允许用户指定在缓冲区被流水线重用之前额外等待的迭代次数。例如,在迭代 0 上复制到 SMEM 的缓冲区,如果 delay_release=1max_concurrent_steps=2,将直到迭代 3 才被重用,而标准双缓冲策略下是迭代 2。delay_release=1 是必要的,如果你不对流水线操作数上的 plgpu.wgmma 操作进行等待,否则流水线将在 WGMMA 仍在读取缓冲区时开始覆盖它们。这对于某些优化很有用,例如允许多个异步矩阵乘法在进行中以保持张量核心流水线的填充,但使用这种策略时必须小心,因为省略此参数会导致静默数据竞争,并且会降低 emit_pipeline 的效率,因为我们重叠的内存传输更少。

使用 pl.pallas_call 的兼容性 API#

作为 emit_pipeline 的替代方案,并为了保持与 Pallas TPU 的兼容性,Mosaic GPU 还实现了现有的 pl.pallas_call API。默认情况下,Mosaic GPU 上的 pl.pallas_call 将在 CUDA 网格中并行划分您的内核。您可以通过传递 plgpu.GPUCompilerParams 对象作为 compiler_params 参数来选择启用流水线,该对象指定了与流水线相关的以下选项:

  • dimension_semantics:一个由 Literal['parallel', 'sequential'] 组成的元组,用于指定每个网格维度的迭代语义。parallel 将在 CUDA 网格中划分相应的维度,而 sequential 维度将按顺序进行流水线处理。请注意,如果没有维度被标记为 sequential,则不会发生流水线!

  • max_concurrent_steps:与 plgpu.emit_pipeline 中的选项相同。

  • delay_release:与 plgpu.emit_pipeline 中的选项相同。

流水线允许您在网格的顺序迭代中重用暂存缓冲区(例如,用于实现归约)。此外,当使用 Mosaic GPU 后端时,pallas_call 支持使用 plgpu.BlockSpec 对象代替 pl.BlockSpec 对象,从而允许您指定 GPU 特定的内存转换。

我们建议用户使用 plgpu.kernel 而不是 pl.pallas_call,因为 plgpu.kernel 支持更多功能(例如指定 warpgroup 的数量和 warp 专用化)。

GPU 内存空间#

Ref 主要存在于两种内存空间之一,这可以通过 BlockSpecmemory_space 参数显式指定,即 BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)

  • plgpu.GPUMemorySpace.SMEM 在共享内存 (SMEM) 中分配一个 Ref。SMEM Ref 可以使用数组索引语法进行解引用,以将值存储在寄存器中进行计算,即 x = y_ref[...]。当使用 emit_pipeline 时,此内存空间用于 Ref。

  • plgpu.GPUMemorySpace.GMEM 在全局内存 (GMEM/HBM) 中分配一个 Ref。任何在 GMEM 中分配的 Ref 都不会进行流水线处理,并且无法直接使用数组索引操作访问值。相反,GMEM 必须通过 SMEM 访问,使用 plgpu.copy_gmem_to_smem 进行读取,或使用 plgpu.copy_smem_to_gmem 进行写入,或使用 plgpu.emit_pipeline 流水线到 SMEM 中。

emit_pipeline 的主要目的是将 TensorCore 计算与 GMEM 和 SMEM 之间的数据传输重叠,因为 GMEM/SMEM 之间的异步复制具有较长的延迟,但所有 TensorCore 计算都必须在寄存器上操作(或者在矩阵乘法的情况下,在 SMEM Ref 上操作)。

示例:Hopper GPU 上的矩阵乘法内核#

我们从一个专为 Hopper GPU 运行的矩阵乘法示例开始。该内核利用 Hopper 特定的 wgmma(warpgroup 矩阵乘累加)指令。wgmma 由单个 Mosaic GPU 线程发出,并在 TensorCore 上异步运行。

我们的示例内核实现了两个形状为 [M, K] @ [K, N] = [M, N] 的矩阵的块式矩阵乘法,其中每个输出块在 CUDA 网格上并行计算。此网格被指定为外部 plgpu.kernelgrid 参数,并在矩阵乘法的非收缩维度 M、N 上并行化。

在程序实例中,我们使用 plgpu.emit_pipeline 运行顺序流水线,该流水线对矩阵乘法的收缩维度 K 进行归约。在流水线的每次迭代中,我们从每个输入矩阵加载一个分块,进行乘法运算,然后将结果存储在累加器 Ref(plgpu.ACC)中。plgpu.ACC 是一种特殊类型的 Ref,它存在于寄存器中,并保存 WGMMA 的中间结果。一旦我们在整个收缩维度上累加完毕,就将结果写入输出 Ref。

为了执行实际的矩阵乘法,我们调用 plgpu.wgmma,并将累加器、LHS 和 RHS Ref 作为参数,以便将参数推送到 TensorCore 流水线中。所有 WGMMA 操作都按顺序执行,因此这可以看作是将操作推入队列。由于 wgmma 是一个异步指令,plgpu.wgmma_wait(N) 用于等待,直到飞行中的 wgmma 操作不超过 N 个。在此特定实现中,我们等待 1 个飞行中的 WGMMA,这意味着我们在当前迭代中排队的 WGMMA 将在下一次迭代中被等待。

  • wgmma 要求其参数采用CUDA 文档中定义的特定格式。这些格式由输入 BlockSpec 上的 TilingTransformSwizzleTransform 转换实现。请注意,将来 Mosaic GPU 将自动推断转换,无需手动指定。有关使用此指令的完整详细信息,请参阅 wgmma 参考

  • 我们结合使用 delay_release 参数和 plgpu.wgmma_wait(1),始终允许一个 WGMMA 操作保持在飞行中,以确保 TensorCore 的良好利用率。如果没有这个,我们将在内核的每次迭代中刷新 TensorCore 流水线。

def matmul(a, b, tile_m=128, tile_n=128, swizzle=128):
  dtype = jnp.float16
  swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
  tile_k = swizzle_elems
  grid_m = m // tile_m
  grid_k = k // tile_k
  grid_n = n // tile_n
  assert tile_m % swizzle_elems == 0

  # Note: Transforms will be inferred automatically
  # by Mosaic GPU in the future.
  transforms = (
    plgpu.TilingTransform((8, swizzle_elems)),
    plgpu.SwizzleTransform(swizzle),
  )

  def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc):
    def pipeline_step(_, a_smem, b_smem):
      plgpu.wgmma(acc, a_smem, b_smem)
      plgpu.wgmma_wait(1)

    # pl.program_id obtains the index into the grid.
    pid_m = pl.program_id(0)
    pid_n = pl.program_id(1)

    pipeline = plgpu.emit_pipeline(
        pipeline_step,
        in_specs=[
            plgpu.BlockSpec(
                (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
            ),
            plgpu.BlockSpec(
                (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms
            ),
        ],
        grid=(grid_k,),
        max_concurrent_steps=2,
        delay_release=1,
    )

    pipeline(a_gmem, b_gmem)
    # Store WGMMA accumulator to SMEM and then to GMEM.
    o_smem[...] = acc[...].astype(dtype)
    plgpu.commit_smem()
    m_slice = pl.ds(pid_m * tile_m, tile_m)
    n_slice = pl.ds(pid_n * tile_n, tile_n)
    plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
    plgpu.wait_smem_to_gmem(0)

  return plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
      scratch_shapes=[
          plgpu.SMEM((tile_m, tile_n), jnp.float16),
          plgpu.ACC((tile_m, tile_n), jnp.float32)
          ],
      # grid specifies the CUDA grid.
      # Instances of `kernel` will be executed in parallel over this grid.
      grid=(grid_m, grid_n),
      grid_names=("m", "n"),
  )(a, b)

m = 132 * 128
n = 4 * 128
k = 10 * 64
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)

result = matmul(a, b)

np.testing.assert_allclose(result, a @ b)

Warp 专用化#

Warp 专用化是一种技术,我们对每个 warp/warpgroup 进行编程,使其执行单个任务,以便为 GPU 硬件提供运行时调度的灵活性。回想一下,GPU 中的每个流式多处理器 (SM) 都包含 warp 调度器,可以在 warp 之间交换执行,例如,当一个 warp 停滞时,它可以开始执行另一个 warp。实际上,这可能比编程单个指令流更具性能,因为编译器必须静态调度操作并尝试最佳地重叠它们。

特别地,我们对 Hopper+ GPU 上的 warpgroup 专用化感兴趣,在这些 GPU 上,让一个独立的 warpgroup 负责 TMA(GMEM/SMEM 复制)操作,而其他 warpgroup 执行算术操作可能很有用,因为索引计算和发出 TMA 会占用大量时间并可能导致 TensorCore 空闲。下图左侧描绘了一个标准的、非专用化的内核,其中 TMA(异步复制)和矩阵乘法由单个指令流发出;右侧则是一个 warp 专用化的版本,其中通信和算术由独立的 warpgroup 处理。一个消耗屏障用于专门化的 warpgroup 之间进行同步,它向内存 warpgroup 发出信号,告知何时可以开始下一次 TMA。

在 Pallas 中,可以通过使用 plgpu.emit_pipeline_warp_specialized 助手来启用 Warp 专用化。这个流水线助手处理内存线程中的所有逻辑,用户只需指定计算线程中完成的工作。它与标准 emit_pipeline 共享类似的 API,目前支持以下参数:

plgpu.emit_pipeline_warp_specialized(
  body: Callable,
  *
  grid: tuple[int, ...],
  in_specs: Sequence[pallas_core.BlockSpec] = (),
  out_specs: Sequence[pallas_core.BlockSpec] = (),
  max_concurrent_steps: int,
  compute_context: Callable
  num_compute_wgs: int,
  memory_registers: int
  wg_axis: str,
  memory_thread_idx: int | None = None,
)

此流水线发射器特有的几个参数是:

  • num_compute_wgs 指定要使用的计算线程/warpgroup 的数量。流水线发射器始终使用单个内存线程,因此在 plgpu.kernel 中,您应该指定 num_threads=num_compute_wgs+1

  • memory_registers 控制分配给内存线程的寄存器数量。剩余的寄存器在计算线程之间平均分配。默认值为 40,应根据是否遇到寄存器溢出进行调整。

  • wg_axis 线程/warpgroup 轴的名称(由 plgpu.kernelthead_name 参数指定)。

  • memory_thread_idx 指定哪个 Pallas 线程被指定为内存线程。默认为最后一个线程。

  • compute_context 允许您指定仅在计算线程中运行的流水线序言/尾声。此函数允许您定义在流水线中循环进位的初始化和消耗。所有计算线程特定的数组都应在此处实例化,以便内存线程不会在寄存器中实例化它们——否则,您可能会由于寄存器溢出而遇到性能下降。

warp 专用化流水线的流水线主体由所有计算线程并行运行,并且 SMEM 在计算线程之间共享,因为它们在同一个 CUDA 块中进行调度。lax.axis_index 可用于内核内部以获取 Pallas 线程索引,从而在计算线程之间分配工作。

示例:使用 Warp 专用化的矩阵乘法#

以下示例扩展了先前的矩阵乘法示例,以使用 warp 专用化。此特定内核使用 2 个计算线程,它们操作 RHS 矩阵的不同列,但共享相同的 LHS。因此,流水线的每次调用都会在输出矩阵中计算 2 个相邻块。

我们使用 compute_context 模式来初始化 WGMMA 累加器,并将最终累加器从寄存器复制到 SMEM。在这里,计算上下文在函数 compute_thread 中定义。至关重要的是,累加器必须在 compute_thread 函数内部创建,以避免在内存线程中分配它而浪费寄存器。为了执行 WGMMA,我们将 wgmma 指令封装在 pl.run_state 中,以创建一个初始化为进位值的累加器引用。

我们没有使用 pl.pallas_call 来调用内核,而是使用了 GPU 特定的 plgpu.kernel 入口点。plgpu.kernel 允许我们通过 num_threads 参数指定每个 CUDA 块要启动的线程数,并允许我们指定一个 thread_name,我们可以在内核内部使用它来查询 Pallas 线程索引。

def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128,
                            compute_wgs=2):
  dtype = jnp.float16
  elems_128b = swizzle // jnp.dtype(dtype).itemsize
  tile_k = elems_128b
  grid_m = m // tile_m
  grid_k = k // tile_k
  grid_n = n // tile_n
  assert tile_m % elems_128b == 0

  transforms = (
          plgpu.TilingTransform((8, elems_128b)),
          plgpu.SwizzleTransform(128),
      )

  def kernel(a_gmem, b_gmem, o_gmem, o_smem):
    wg_idx = lax.axis_index("wg")
    wg_slice = pl.ds(wg_idx * tile_n, tile_n)
    # pl.program_id obtains the index into the pallas_call grid.
    pid_m = pl.program_id(0)
    pid_n = pl.program_id(1)

    def compute_thread(pipeline):
      acc = plgpu.layout_cast(
          jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
      )
      # yield marks the place where the pipelined loop will be inserted.
      # Its argument are the initial carry values, and its result is the carry
      # value after the loop completes.
      final_acc = pipeline(acc)
      o_smem[:, wg_slice] = final_acc[...].astype(dtype)

    def kernel_body(_, a_smem, b_smem, carry):
      acc = carry
      b_smem_wg = b_smem.at[:, wg_slice]
      def do_wgmma(acc_ref):
        plgpu.wgmma(acc_ref, a_smem, b_smem_wg)
      acc = pl.run_state(do_wgmma)(
                          plgpu.ACC.init(acc))
      return acc

    pipeline = plgpu.emit_pipeline_warp_specialized(
        kernel_body,
        in_specs=[
            plgpu.BlockSpec(
              (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
            ),
            plgpu.BlockSpec(
              (tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms
            ),
        ],
        grid=(grid_k,),
        compute_context=compute_thread,
        max_concurrent_steps=2,
        num_compute_wgs=compute_wgs,
        memory_registers=40,
        memory_thread_idx=2,
        wg_axis="wg",
    )
    # Call the pipeline
    pipeline(a_gmem, b_gmem)
    # Copy the output from SMEM to GMEM.
    plgpu.commit_smem()
    m_slice = pl.ds(pid_m * tile_m, tile_m)
    n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2)
    plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
    plgpu.wait_smem_to_gmem(0)

  return plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
      scratch_shapes=[
          plgpu.SMEM((tile_m, tile_n * 2), jnp.float16)
          ],
      grid=(grid_m, grid_n // 2),
      grid_names=("m", "n"),
      num_threads=3,  # 2 compute, 1 memory.
      thread_name="wg"
  )(a, b)

m = 132 * 128
n = 4 * 128
k = 10 * 64
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)

result = matmul_warp_specialized(a, b)

np.testing.assert_allclose(result, a @ b)