Mosaic GPU 流水线#
本指南涵盖了使用 Pallas 的 Mosaic GPU 后端进行软件流水线化的相关内容。
关于 Pallas 流水线 API 的概述,建议用户先阅读 软件流水线。Pallas 中的流水线是显式编程的。对于熟悉 Triton 的用户来说,这是一个显著的编程模型差异,因为在 Triton 中,流水线是由编译器自动执行的优化。
import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu
from jax.experimental import pallas as pl
import numpy as np
Mosaic GPU 流水线#
使用 Mosaic GPU 进行流水线化的推荐方法是使用 plgpu.emit_pipeline 函数对顺序循环进行流水线化(并使用 plgpu.kernel 在 CUDA 网格上并行划分任务)。emit_pipeline 遵循与 pl.pallas_call 类似的 API,区别在于它提供了一些额外的 GPU 特定选项。
body和grid的语义与pl.pallas_call中的类似。grid表示要运行多少次body函数。与 CUDA 网格不同,流水线网格保证按顺序运行。in_specs和out_specs的工作方式也与pl.pallas_call类似,区别在于它们还接受plgpu.BlockSpec实例,可用于指定 GPU 特定的转换,例如数据重排(swizzling)。有关可用转换的更多详细信息,请参阅 内存引用转换。max_concurrent_steps控制并发内存传输的最大数量。使用额外的并发步骤会消耗更多的共享内存(SMEM)来存放临时缓冲区,但可以提高内存子系统的利用率。我们建议对此参数进行自动调优。较低的值(例如 2)有时可以实现更高的占用率(因为 SMEM 使用量较少),这可以提高 ALU 密集型 Kernel 的吞吐量,但会因为硬件负责调度而引入更多噪声。较大的值(在 4 到 6 之间)最适合无法利用额外占用率的 Kernel。delay_release允许用户指定在缓冲区被流水线重用之前额外等待的迭代次数。例如,在delay_release=1和max_concurrent_steps=2的情况下,在迭代 0 时复制到 SMEM 的缓冲区直到迭代 3 才会被重用,而不是标准双缓冲策略中的迭代 2。如果您不对流水线操作数执行plgpu.wgmma等待操作,则必须使用delay_release=1,否则流水线会在 WGMMA 仍在读取缓冲区时就开始覆盖它们。这对于某些优化(例如允许同时进行多个异步矩阵乘法以保持 Tensor Core 流水线填满)很有用,但使用此类策略时必须小心,因为省略此参数会导致静默数据竞争,并且它会降低emit_pipeline的效率,因为我们重叠的内存传输更少了。
使用 pl.pallas_call 的兼容性 API#
作为 emit_pipeline 的替代方案,为了保持与 Pallas TPU 的兼容性,Mosaic GPU 也实现了现有的 pl.pallas_call API。默认情况下,Mosaic GPU 上的 pl.pallas_call 将在 CUDA 网格上并行划分您的 Kernel。您可以通过传入 plgpu.CompilerParams 对象作为 compiler_params 参数来选择流水线化,该对象指定了以下与流水线化相关的选项。
dimension_semantics:由Literal['parallel', 'sequential']组成的元组,指定每个网格维度的迭代语义。parallel会在 CUDA 网格上划分相应维度,而sequential维度将按顺序进行流水线化。注意,如果没有标记为sequential的维度,则不会发生流水线化!max_concurrent_steps:与plgpu.emit_pipeline中的选项相同。delay_release:与plgpu.emit_pipeline中的选项相同。
流水线允许您在网格的顺序迭代中重用暂存缓冲区(例如用于实现归约)。此外,当使用 Mosaic GPU 后端时,pallas_call 支持使用 plgpu.BlockSpec 对象代替 pl.BlockSpec 对象,允许您指定 GPU 特定的内存转换。
我们建议用户使用 plgpu.kernel 而不是 pl.pallas_call,因为 plgpu.kernel 支持更多功能(例如指定 warp 组数量和 Warp 专用化)。
GPU 内存空间#
引用(Refs)主要存在于两个内存空间之一,可以通过 BlockSpec 的 memory_space 参数显式指定,即 BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)。
plgpu.GPUMemorySpace.SMEM在共享内存(SMEM)中分配一个 Ref。SMEM Refs 可以使用数组索引语法进行解引用,以将值存储在寄存器中进行计算,例如x = y_ref[...]。这是使用emit_pipeline时 Ref 所使用的内存空间。plgpu.GPUMemorySpace.GMEM在全局内存(GMEM/HBM)中分配一个 Ref。在 GMEM 中分配的任何 Refs 都不会被流水线化,且无法通过数组索引操作直接访问值。相反,GMEM 必须通过 SMEM 访问:读取时使用plgpu.copy_gmem_to_smem,写入时使用plgpu.copy_smem_to_gmem,或者使用plgpu.emit_pipeline流水线化到 SMEM 中。
emit_pipeline 的主要目的是将 TensorCore 计算与 GMEM 和 SMEM 之间的数据传输重叠,因为 GMEM/SMEM 之间的异步复制具有很长的延迟,而所有 TensorCore 计算都必须在寄存器(或矩阵乘法情况下的 SMEM Refs)上进行。
示例:Hopper GPU 上的矩阵乘法 Kernel#
让我们从一个旨在在 Hopper GPU 上运行的矩阵乘法示例开始。此 Kernel 利用了 Hopper 特有的 wgmma(Warp 组矩阵乘加)指令。wgmma 由单个 Mosaic GPU 线程发出,并在 TensorCore 上异步运行。
我们的示例 Kernel 实现了两个形状为 [M, K] @ [K, N] = [M, N] 的矩阵的块矩阵乘法,其中每个输出块都在 CUDA 网格上并行计算。此网格被指定为外部 plgpu.kernel 的 grid 参数,并在矩阵乘法的非缩约维度 M、N 上进行并行化。
在程序实例中,我们使用 plgpu.emit_pipeline 运行一个顺序流水线,该流水线在矩阵乘法的缩约维度 K 上进行归约。在流水线的每次迭代中,我们从每个输入矩阵加载一个 tile,将它们相乘,然后将结果存储在累加器 Ref (plgpu.ACC) 中。plgpu.ACC 是一种特殊的 Ref 类型,它位于寄存器中并保存 WGMMA 的中间结果。一旦我们在整个缩约维度上完成了累加,我们将结果写出到输出 Ref 中。
为了执行实际的矩阵乘法,我们以累加器、LHS 和 RHS Refs 作为参数调用 plgpu.wgmma,以将参数推送到 TensorCore 流水线中。所有 WGMMA 操作都按顺序执行,因此这可以看作是将操作推入队列。由于 wgmma 是异步指令,plgpu.wgmma_wait(N) 用于等待直到正在进行的 wgmma 操作不超过 N 个。在这个特定的实现中,我们等待 1 个正在进行的 WGMMA,这意味着我们在当前迭代中排队的 WGMMA 将在下一次迭代中被等待。
wgmma要求其参数以特定的格式存在,定义在 CUDA 文档中。这些格式由输入 BlockSpecs 上的TilingTransform和SwizzleTransform转换实现。请注意,未来 Mosaic GPU 将自动推断转换,无需手动指定。有关使用此指令的完整详细信息,请参阅 wgmma 参考。我们将
delay_release参数与plgpu.wgmma_wait(1)结合使用,始终允许一个WGMMA操作处于进行中,以确保良好的 TensorCore 利用率。如果没有这一点,我们将在 Kernel 的每次迭代中刷新 TensorCore 流水线。
def matmul(a, b, tile_m=128, tile_n=128, swizzle=128):
dtype = jnp.float16
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
tile_k = swizzle_elems
grid_m = m // tile_m
grid_k = k // tile_k
grid_n = n // tile_n
assert tile_m % swizzle_elems == 0
# Note: Transforms will be inferred automatically
# by Mosaic GPU in the future.
transforms = (
plgpu.TilingTransform((8, swizzle_elems)),
plgpu.SwizzleTransform(swizzle),
)
def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc):
def pipeline_step(_, a_smem, b_smem):
plgpu.wgmma(acc, a_smem, b_smem)
plgpu.wgmma_wait(1)
# pl.program_id obtains the index into the grid.
pid_m = pl.program_id(0)
pid_n = pl.program_id(1)
pipeline = plgpu.emit_pipeline(
pipeline_step,
in_specs=[
plgpu.BlockSpec(
(tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
),
plgpu.BlockSpec(
(tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms
),
],
grid=(grid_k,),
max_concurrent_steps=2,
delay_release=1,
)
pipeline(a_gmem, b_gmem)
# Store WGMMA accumulator to SMEM and then to GMEM.
o_smem[...] = acc[...].astype(dtype)
plgpu.commit_smem()
m_slice = pl.ds(pid_m * tile_m, tile_m)
n_slice = pl.ds(pid_n * tile_n, tile_n)
plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
plgpu.wait_smem_to_gmem(0)
return plgpu.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
scratch_shapes=dict(
o_smem=plgpu.SMEM((tile_m, tile_n), jnp.float16),
acc=plgpu.ACC((tile_m, tile_n), jnp.float32)
),
# grid specifies the CUDA grid.
# Instances of `kernel` will be executed in parallel over this grid.
grid=(grid_m, grid_n),
grid_names=("m", "n"),
)(a, b)
m = 132 * 128
n = 4 * 128
k = 10 * 64
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)
result = matmul(a, b)
np.testing.assert_allclose(result, a @ b)
Warp 专用化 (Warp Specialization)#
Warp 专用化是一种技术,我们对每个 warp/warpgroup 进行编程以执行单个任务,从而赋予 GPU 硬件在运行时灵活调度它们的能力。回想一下,GPU 中的每个流式多处理器 (SM) 都包含 warp 调度器,可以在 warp 之间切换执行,例如当一个 warp 停顿时,它可以开始执行另一个 warp。在实践中,这比编写单个指令流更具性能,因为在单个指令流中,编译器必须静态调度操作并尝试以最佳方式重叠它们。
特别是,我们对 Hopper+ GPU 上的 Warp 组专用化感兴趣,让一个独立的 Warp 组发出 TMA(GMEM/SMEM 复制)与执行算术运算的 Warp 组分开是有益的,因为索引计算和发出 TMA 可能占用大量时间,并可能导致 TensorCore 空闲。下图描绘了左侧的标准非专用化 Kernel,其中 TMA(异步复制)和矩阵乘法从单个指令流发出,右侧是 Warp 专用化版本,其中通信和算术在不同的 Warp 组上处理。使用消费栅栏 (consumed barrier) 在专用 Warp 组之间进行同步,这会向内存 Warp 组发出信号,表明何时可以安全地开始下一个 TMA。
可以通过使用 plgpu.emit_pipeline_warp_specialized 辅助程序在 Pallas 中启用 Warp 专用化。此流水线辅助程序处理内存线程中的所有逻辑,用户只需要指定计算线程中完成的工作。它与标准 emit_pipeline 共享相似的 API,目前支持以下参数。
plgpu.emit_pipeline_warp_specialized(
body: Callable,
*
grid: tuple[int, ...],
in_specs: Sequence[pallas_core.BlockSpec] = (),
out_specs: Sequence[pallas_core.BlockSpec] = (),
max_concurrent_steps: int,
compute_context: Callable
num_compute_wgs: int,
memory_registers: int
wg_axis: str,
memory_thread_idx: int | None = None,
)
此流水线发射器有一些特定参数,它们是:
num_compute_wgs指定要使用的计算线程/Warp 组数量。流水线发射器总是使用单个内存线程,因此在plgpu.kernel中,您应该指定num_threads=num_compute_wgs+1。memory_registers控制分配给内存线程的寄存器数量。剩余的寄存器在计算线程之间平均分配。默认值为 40,应根据是否遇到寄存器溢出(spill)向上或向下调整。wg_axis是线程/Warp 组轴的名称(由plgpu.kernel的thread_name参数指定)。memory_thread_idx指定将哪个 Pallas 线程指定为内存线程。默认为最后一个线程。compute_context使您能够指定流水线的序幕/尾声(prologue/epilogue),这些代码仅在计算线程中运行。该函数允许您定义流水线中循环进位(loop carry)的初始化和消费。所有计算线程特定的数组都应在此处实例化,以便内存线程不会在寄存器中实现它们——否则,您可能会因寄存器溢出而遇到减速。
Warp 专用化流水线的流水线主体由所有计算线程并行运行,由于它们在同一个 CUDA 块内调度,因此共享 SMEM。lax.axis_index 可在 Kernel 内部用于获取 Pallas 线程索引,以便在计算线程之间分配工作。
示例:使用 Warp 专用化的矩阵乘法#
以下示例扩展了之前的矩阵乘法示例以使用 Warp 专用化。此特定 Kernel 使用 2 个计算线程,它们对 RHS 矩阵的不同列进行操作,但共享相同的 LHS。因此,每次调用流水线都会计算输出矩阵中的 2 个相邻块。
我们使用 compute_context 模式来初始化 WGMMA 累加器,并将最终累加器从寄存器复制到 SMEM 中。在这里,计算上下文在函数 compute_thread 中定义。累加器必须在 compute_thread 函数内部创建,这一点至关重要,以避免将其分配在内存线程中,这将浪费寄存器。为了执行 WGMMA,我们将 wgmma 指令包装在 pl.run_state 中,以创建初始化为进位值的累加器 ref。
我们没有使用 pl.pallas_call 来调用 Kernel,而是使用 GPU 特定的 plgpu.kernel 入口点。plgpu.kernel 允许我们通过 num_threads 参数指定每个 CUDA 块启动的线程数,并允许我们指定一个 thread_name,我们可以在 Kernel 内部使用它来查询 Pallas 线程索引。
def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128,
compute_wgs=2):
dtype = jnp.float16
elems_128b = swizzle // jnp.dtype(dtype).itemsize
tile_k = elems_128b
grid_m = m // tile_m
grid_k = k // tile_k
grid_n = n // tile_n
assert tile_m % elems_128b == 0
transforms = (
plgpu.TilingTransform((8, elems_128b)),
plgpu.SwizzleTransform(128),
)
def kernel(a_gmem, b_gmem, o_gmem, o_smem):
wg_idx = lax.axis_index("wg")
wg_slice = pl.ds(wg_idx * tile_n, tile_n)
# pl.program_id obtains the index into the pallas_call grid.
pid_m = pl.program_id(0)
pid_n = pl.program_id(1)
def compute_thread(pipeline):
acc = plgpu.layout_cast(
jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
)
# yield marks the place where the pipelined loop will be inserted.
# Its argument are the initial carry values, and its result is the carry
# value after the loop completes.
final_acc = pipeline(acc)
o_smem[:, wg_slice] = final_acc[...].astype(dtype)
def kernel_body(_, a_smem, b_smem, carry):
acc = carry
b_smem_wg = b_smem.at[:, wg_slice]
def do_wgmma(acc_ref):
plgpu.wgmma(acc_ref, a_smem, b_smem_wg)
acc = pl.run_state(do_wgmma)(
plgpu.ACC.init(acc))
return acc
pipeline = plgpu.emit_pipeline_warp_specialized(
kernel_body,
in_specs=[
plgpu.BlockSpec(
(tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
),
plgpu.BlockSpec(
(tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms
),
],
grid=(grid_k,),
compute_context=compute_thread,
max_concurrent_steps=2,
num_compute_wgs=compute_wgs,
memory_registers=40,
memory_thread_idx=2,
wg_axis="wg",
)
# Call the pipeline
pipeline(a_gmem, b_gmem)
# Copy the output from SMEM to GMEM.
plgpu.commit_smem()
m_slice = pl.ds(pid_m * tile_m, tile_m)
n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2)
plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
plgpu.wait_smem_to_gmem(0)
return plgpu.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
scratch_shapes=dict(
o_smem=plgpu.SMEM((tile_m, tile_n * 2), jnp.float16)
),
grid=(grid_m, grid_n // 2),
grid_names=("m", "n"),
num_threads=3, # 2 compute, 1 memory.
thread_name="wg"
)(a, b)
m = 132 * 128
n = 4 * 128
k = 10 * 64
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)
result = matmul_warp_specialized(a, b)
np.testing.assert_allclose(result, a @ b)