Mosaic GPU 流水线#
本指南介绍了使用 Pallas 的 Mosaic GPU 后端进行软件流水线操作。
有关 Pallas 中流水线 API 的一般概述,我们建议用户首先阅读软件流水线。Pallas 中的流水线是显式编程的。对于熟悉 Triton 的用户来说,这是一个显著的编程模型差异,因为在 Triton 中,流水线是编译器自动完成的优化。
import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu
from jax.experimental import pallas as pl
import numpy as np
使用 Mosaic GPU 进行流水线操作#
使用 Mosaic GPU 进行流水线操作的推荐方法是使用 plgpu.emit_pipeline
函数对顺序循环进行流水线操作(并使用 plgpu.kernel
在 CUDA 网格中并行划分问题)。emit_pipeline
遵循与 pl.pallas_call
类似的 API,只是它公开了一些额外的 GPU 特定选项。
body
和grid
具有与pl.pallas_call
中类似的语义。grid
表示要运行body
函数的调用次数。与 CUDA 网格不同,流水线网格保证按顺序运行。in_specs
和out_specs
的工作方式也与pl.pallas_call
类似,只是它们还接受plgpu.BlockSpec
实例,可用于指定 GPU 特定的转换,例如交错。有关可用转换的更多详细信息,请参阅内存引用转换。max_concurrent_steps
控制最大并发内存传输数量。使用额外的并发步骤将消耗更多 SMEM 来保存临时缓冲区,但可以提高内存子系统的利用率。我们建议对该参数进行自动调优。较低的值(例如 2)有时可以实现更高的占用率(由于较低的 SMEM 使用),这可以提高 ALU 密集型内核的吞吐量,但由于硬件负责调度,会引入更多噪声。较大的值(4 到 6 之间)对于无法利用额外占用率的内核效果最佳。delay_release
允许用户指定在缓冲区被流水线重用之前额外等待的迭代次数。例如,在迭代 0 上复制到 SMEM 的缓冲区,如果delay_release=1
且max_concurrent_steps=2
,将直到迭代 3 才被重用,而标准双缓冲策略下是迭代 2。delay_release=1
是必要的,如果你不对流水线操作数上的plgpu.wgmma
操作进行等待,否则流水线将在 WGMMA 仍在读取缓冲区时开始覆盖它们。这对于某些优化很有用,例如允许多个异步矩阵乘法在进行中以保持张量核心流水线的填充,但使用这种策略时必须小心,因为省略此参数会导致静默数据竞争,并且会降低emit_pipeline
的效率,因为我们重叠的内存传输更少。
使用 pl.pallas_call
的兼容性 API#
作为 emit_pipeline
的替代方案,并为了保持与 Pallas TPU 的兼容性,Mosaic GPU 还实现了现有的 pl.pallas_call
API。默认情况下,Mosaic GPU 上的 pl.pallas_call
将在 CUDA 网格中并行划分您的内核。您可以通过传递 plgpu.GPUCompilerParams
对象作为 compiler_params
参数来选择启用流水线,该对象指定了与流水线相关的以下选项:
dimension_semantics
:一个由Literal['parallel', 'sequential']
组成的元组,用于指定每个网格维度的迭代语义。parallel
将在 CUDA 网格中划分相应的维度,而sequential
维度将按顺序进行流水线处理。请注意,如果没有维度被标记为sequential
,则不会发生流水线!max_concurrent_steps
:与plgpu.emit_pipeline
中的选项相同。delay_release
:与plgpu.emit_pipeline
中的选项相同。
流水线允许您在网格的顺序迭代中重用暂存缓冲区(例如,用于实现归约)。此外,当使用 Mosaic GPU 后端时,pallas_call
支持使用 plgpu.BlockSpec
对象代替 pl.BlockSpec
对象,从而允许您指定 GPU 特定的内存转换。
我们建议用户使用 plgpu.kernel
而不是 pl.pallas_call
,因为 plgpu.kernel
支持更多功能(例如指定 warpgroup 的数量和 warp 专用化)。
GPU 内存空间#
Ref 主要存在于两种内存空间之一,这可以通过 BlockSpec
的 memory_space
参数显式指定,即 BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)
。
plgpu.GPUMemorySpace.SMEM
在共享内存 (SMEM) 中分配一个 Ref。SMEM Ref 可以使用数组索引语法进行解引用,以将值存储在寄存器中进行计算,即x = y_ref[...]
。当使用emit_pipeline
时,此内存空间用于 Ref。plgpu.GPUMemorySpace.GMEM
在全局内存 (GMEM/HBM) 中分配一个 Ref。任何在 GMEM 中分配的 Ref 都不会进行流水线处理,并且无法直接使用数组索引操作访问值。相反,GMEM 必须通过 SMEM 访问,使用plgpu.copy_gmem_to_smem
进行读取,或使用plgpu.copy_smem_to_gmem
进行写入,或使用plgpu.emit_pipeline
流水线到 SMEM 中。
emit_pipeline
的主要目的是将 TensorCore 计算与 GMEM 和 SMEM 之间的数据传输重叠,因为 GMEM/SMEM 之间的异步复制具有较长的延迟,但所有 TensorCore 计算都必须在寄存器上操作(或者在矩阵乘法的情况下,在 SMEM Ref 上操作)。
示例:Hopper GPU 上的矩阵乘法内核#
我们从一个专为 Hopper GPU 运行的矩阵乘法示例开始。该内核利用 Hopper 特定的 wgmma
(warpgroup 矩阵乘累加)指令。wgmma
由单个 Mosaic GPU 线程发出,并在 TensorCore 上异步运行。
我们的示例内核实现了两个形状为 [M, K] @ [K, N] = [M, N]
的矩阵的块式矩阵乘法,其中每个输出块在 CUDA 网格上并行计算。此网格被指定为外部 plgpu.kernel
的 grid
参数,并在矩阵乘法的非收缩维度 M、N 上并行化。
在程序实例中,我们使用 plgpu.emit_pipeline
运行顺序流水线,该流水线对矩阵乘法的收缩维度 K 进行归约。在流水线的每次迭代中,我们从每个输入矩阵加载一个分块,进行乘法运算,然后将结果存储在累加器 Ref(plgpu.ACC
)中。plgpu.ACC
是一种特殊类型的 Ref,它存在于寄存器中,并保存 WGMMA 的中间结果。一旦我们在整个收缩维度上累加完毕,就将结果写入输出 Ref。
为了执行实际的矩阵乘法,我们调用 plgpu.wgmma
,并将累加器、LHS 和 RHS Ref 作为参数,以便将参数推送到 TensorCore 流水线中。所有 WGMMA 操作都按顺序执行,因此这可以看作是将操作推入队列。由于 wgmma
是一个异步指令,plgpu.wgmma_wait(N)
用于等待,直到飞行中的 wgmma
操作不超过 N 个。在此特定实现中,我们等待 1 个飞行中的 WGMMA,这意味着我们在当前迭代中排队的 WGMMA 将在下一次迭代中被等待。
wgmma
要求其参数采用CUDA 文档中定义的特定格式。这些格式由输入 BlockSpec 上的TilingTransform
和SwizzleTransform
转换实现。请注意,将来 Mosaic GPU 将自动推断转换,无需手动指定。有关使用此指令的完整详细信息,请参阅 wgmma 参考。我们结合使用
delay_release
参数和plgpu.wgmma_wait(1)
,始终允许一个WGMMA
操作保持在飞行中,以确保 TensorCore 的良好利用率。如果没有这个,我们将在内核的每次迭代中刷新 TensorCore 流水线。
def matmul(a, b, tile_m=128, tile_n=128, swizzle=128):
dtype = jnp.float16
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
tile_k = swizzle_elems
grid_m = m // tile_m
grid_k = k // tile_k
grid_n = n // tile_n
assert tile_m % swizzle_elems == 0
# Note: Transforms will be inferred automatically
# by Mosaic GPU in the future.
transforms = (
plgpu.TilingTransform((8, swizzle_elems)),
plgpu.SwizzleTransform(swizzle),
)
def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc):
def pipeline_step(_, a_smem, b_smem):
plgpu.wgmma(acc, a_smem, b_smem)
plgpu.wgmma_wait(1)
# pl.program_id obtains the index into the grid.
pid_m = pl.program_id(0)
pid_n = pl.program_id(1)
pipeline = plgpu.emit_pipeline(
pipeline_step,
in_specs=[
plgpu.BlockSpec(
(tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
),
plgpu.BlockSpec(
(tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms
),
],
grid=(grid_k,),
max_concurrent_steps=2,
delay_release=1,
)
pipeline(a_gmem, b_gmem)
# Store WGMMA accumulator to SMEM and then to GMEM.
o_smem[...] = acc[...].astype(dtype)
plgpu.commit_smem()
m_slice = pl.ds(pid_m * tile_m, tile_m)
n_slice = pl.ds(pid_n * tile_n, tile_n)
plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
plgpu.wait_smem_to_gmem(0)
return plgpu.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
scratch_shapes=[
plgpu.SMEM((tile_m, tile_n), jnp.float16),
plgpu.ACC((tile_m, tile_n), jnp.float32)
],
# grid specifies the CUDA grid.
# Instances of `kernel` will be executed in parallel over this grid.
grid=(grid_m, grid_n),
grid_names=("m", "n"),
)(a, b)
m = 132 * 128
n = 4 * 128
k = 10 * 64
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)
result = matmul(a, b)
np.testing.assert_allclose(result, a @ b)
Warp 专用化#
Warp 专用化是一种技术,我们对每个 warp/warpgroup 进行编程,使其执行单个任务,以便为 GPU 硬件提供运行时调度的灵活性。回想一下,GPU 中的每个流式多处理器 (SM) 都包含 warp 调度器,可以在 warp 之间交换执行,例如,当一个 warp 停滞时,它可以开始执行另一个 warp。实际上,这可能比编程单个指令流更具性能,因为编译器必须静态调度操作并尝试最佳地重叠它们。
特别地,我们对 Hopper+ GPU 上的 warpgroup 专用化感兴趣,在这些 GPU 上,让一个独立的 warpgroup 负责 TMA(GMEM/SMEM 复制)操作,而其他 warpgroup 执行算术操作可能很有用,因为索引计算和发出 TMA 会占用大量时间并可能导致 TensorCore 空闲。下图左侧描绘了一个标准的、非专用化的内核,其中 TMA(异步复制)和矩阵乘法由单个指令流发出;右侧则是一个 warp 专用化的版本,其中通信和算术由独立的 warpgroup 处理。一个消耗屏障用于专门化的 warpgroup 之间进行同步,它向内存 warpgroup 发出信号,告知何时可以开始下一次 TMA。
在 Pallas 中,可以通过使用 plgpu.emit_pipeline_warp_specialized
助手来启用 Warp 专用化。这个流水线助手处理内存线程中的所有逻辑,用户只需指定计算线程中完成的工作。它与标准 emit_pipeline
共享类似的 API,目前支持以下参数:
plgpu.emit_pipeline_warp_specialized(
body: Callable,
*
grid: tuple[int, ...],
in_specs: Sequence[pallas_core.BlockSpec] = (),
out_specs: Sequence[pallas_core.BlockSpec] = (),
max_concurrent_steps: int,
compute_context: Callable
num_compute_wgs: int,
memory_registers: int
wg_axis: str,
memory_thread_idx: int | None = None,
)
此流水线发射器特有的几个参数是:
num_compute_wgs
指定要使用的计算线程/warpgroup 的数量。流水线发射器始终使用单个内存线程,因此在plgpu.kernel
中,您应该指定num_threads=num_compute_wgs+1
。memory_registers
控制分配给内存线程的寄存器数量。剩余的寄存器在计算线程之间平均分配。默认值为 40,应根据是否遇到寄存器溢出进行调整。wg_axis
线程/warpgroup 轴的名称(由plgpu.kernel
的thead_name
参数指定)。memory_thread_idx
指定哪个 Pallas 线程被指定为内存线程。默认为最后一个线程。compute_context
允许您指定仅在计算线程中运行的流水线序言/尾声。此函数允许您定义在流水线中循环进位的初始化和消耗。所有计算线程特定的数组都应在此处实例化,以便内存线程不会在寄存器中实例化它们——否则,您可能会由于寄存器溢出而遇到性能下降。
warp 专用化流水线的流水线主体由所有计算线程并行运行,并且 SMEM 在计算线程之间共享,因为它们在同一个 CUDA 块中进行调度。lax.axis_index
可用于内核内部以获取 Pallas 线程索引,从而在计算线程之间分配工作。
示例:使用 Warp 专用化的矩阵乘法#
以下示例扩展了先前的矩阵乘法示例,以使用 warp 专用化。此特定内核使用 2 个计算线程,它们操作 RHS 矩阵的不同列,但共享相同的 LHS。因此,流水线的每次调用都会在输出矩阵中计算 2 个相邻块。
我们使用 compute_context
模式来初始化 WGMMA 累加器,并将最终累加器从寄存器复制到 SMEM。在这里,计算上下文在函数 compute_thread
中定义。至关重要的是,累加器必须在 compute_thread
函数内部创建,以避免在内存线程中分配它而浪费寄存器。为了执行 WGMMA,我们将 wgmma
指令封装在 pl.run_state
中,以创建一个初始化为进位值的累加器引用。
我们没有使用 pl.pallas_call
来调用内核,而是使用了 GPU 特定的 plgpu.kernel
入口点。plgpu.kernel
允许我们通过 num_threads
参数指定每个 CUDA 块要启动的线程数,并允许我们指定一个 thread_name
,我们可以在内核内部使用它来查询 Pallas 线程索引。
def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128,
compute_wgs=2):
dtype = jnp.float16
elems_128b = swizzle // jnp.dtype(dtype).itemsize
tile_k = elems_128b
grid_m = m // tile_m
grid_k = k // tile_k
grid_n = n // tile_n
assert tile_m % elems_128b == 0
transforms = (
plgpu.TilingTransform((8, elems_128b)),
plgpu.SwizzleTransform(128),
)
def kernel(a_gmem, b_gmem, o_gmem, o_smem):
wg_idx = lax.axis_index("wg")
wg_slice = pl.ds(wg_idx * tile_n, tile_n)
# pl.program_id obtains the index into the pallas_call grid.
pid_m = pl.program_id(0)
pid_n = pl.program_id(1)
def compute_thread(pipeline):
acc = plgpu.layout_cast(
jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
)
# yield marks the place where the pipelined loop will be inserted.
# Its argument are the initial carry values, and its result is the carry
# value after the loop completes.
final_acc = pipeline(acc)
o_smem[:, wg_slice] = final_acc[...].astype(dtype)
def kernel_body(_, a_smem, b_smem, carry):
acc = carry
b_smem_wg = b_smem.at[:, wg_slice]
def do_wgmma(acc_ref):
plgpu.wgmma(acc_ref, a_smem, b_smem_wg)
acc = pl.run_state(do_wgmma)(
plgpu.ACC.init(acc))
return acc
pipeline = plgpu.emit_pipeline_warp_specialized(
kernel_body,
in_specs=[
plgpu.BlockSpec(
(tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms
),
plgpu.BlockSpec(
(tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms
),
],
grid=(grid_k,),
compute_context=compute_thread,
max_concurrent_steps=2,
num_compute_wgs=compute_wgs,
memory_registers=40,
memory_thread_idx=2,
wg_axis="wg",
)
# Call the pipeline
pipeline(a_gmem, b_gmem)
# Copy the output from SMEM to GMEM.
plgpu.commit_smem()
m_slice = pl.ds(pid_m * tile_m, tile_m)
n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2)
plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])
plgpu.wait_smem_to_gmem(0)
return plgpu.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
scratch_shapes=[
plgpu.SMEM((tile_m, tile_n * 2), jnp.float16)
],
grid=(grid_m, grid_n // 2),
grid_names=("m", "n"),
num_threads=3, # 2 compute, 1 memory.
thread_name="wg"
)(a, b)
m = 132 * 128
n = 4 * 128
k = 10 * 64
key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)
b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)
result = matmul_warp_specialized(a, b)
np.testing.assert_allclose(result, a @ b)