在 Blackwell 上编写高性能矩阵乘法核#
在本指南中,我们将逐步迭代矩阵乘法核。第一个实现将非常简单,但也非常慢。然而,只需几个简单的步骤,就可以将其修改为最先进的核,其性能可与 cuBLAS 和 CUTLASS 等高度优化的实现相媲美甚至超越。
警告
下表中显示的利用率可能与您在线看到的有所不同,但这些差异很可能可以通过不同的输入数据分布来解释。我们这里的所有基准测试都使用具有 iid 正常 float16 条目的数组,事实证明这是您可以选择的较慢的分布之一。通过更改 BENCHMARK
变量为 True
,运行 我们的测试文件,您可以自己重现这些数字。
简而言之:如果不指定输入数据分布,请勿相信 matmul 基准测试。
实现 |
TensorCore 利用率 |
cuBLAS 利用率百分比 |
---|---|---|
0. 基本核 |
37.62% |
59.4% |
1. Warp 特化 |
45.47% |
71.7% |
2. 块状结尾 |
55.82% |
88.1% |
3. 集合 (2CTA) MMA |
59.41% |
93.7% |
4. 持久核 |
61.46% |
97.0% |
5. 专用结尾 Warpgroup |
63.38% |
100.0% |
6. 网格分块 |
69.44% |
109.6% |
cuBLAS |
63.38% |
100.0% |
CUTLASS |
69.30% |
109.3% |
cuBLAS 基线是通过测量 jax.dot
的性能获得的。CUTLASS 性能是通过从以下 cutlass_profiler
调用中获取最佳结果来测量的(不包括稀疏 matmuls)
cutlass_profiler --dist=gaussian,mean:0,stddev:1,scale:-1 --output=results.csv --accumulator-type=f32 --m=4096 --k=4096 --n=8192 --kernels='*sm100*' --A=f16 --B=f16 --C=void --D=f16
在每一步,我们将展示核的完整实现,或者展示上一步和当前步骤中代码列表之间的差异。完整实现可以在 我们的测试文件 中找到。您还可以在 Pallas ops 包中找到一个完整的独立优化核实现:Pallas ops 包。
0. 基本核#
我们从一个简单的单 CTA(块)单 Warpgroup 示例开始。为了方便起见,我们将核的调优参数分离到一个单独的类中
@dataclasses.dataclass(frozen=True)
class TuningConfig:
tile_m: int
tile_n: int
tile_k: int
max_concurrent_steps: int
tile_m
、tile_n
和 tile_k
指定了流水线每一步执行的 matmul 的大小。通常,tile_k
理想情况下应等于 128 除以输入元素类型的字节宽度。max_concurrent_steps
指定了计算/内存流水线中的内存预取深度,这在其他实现中经常被称为阶段数。
核的实现首先是一些设置代码
def matmul0(a, b, config: TuningConfig):
dtype = a.dtype
m, k = a.shape
_, n = b.shape
tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
transforms = (
plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle)
)
if m % tile_m != 0:
raise ValueError(f"{m=} must be divisible by {tile_m=}")
if n % tile_n != 0:
raise ValueError(f"{n=} must be divisible by {tile_n=}")
if k % tile_k != 0:
raise ValueError(f"{k=} must be divisible by {tile_k=}")
m_iters = m // tile_m
n_iters = n // tile_n
k_iters = k // tile_k
max_concurrent_steps = config.max_concurrent_steps
我们解包配置变量以便于访问,设置分块和混叠变换,使 SMEM 数据格式与 MMA 指令所期望的 匹配。
核的实现本身相对简短。第一部分使用 plgpu.emit_pipeline
设置一个 计算/内存流水线。在每一步,计算函数(do_mma
)消耗 LHS 的 (tile_m, tile_k)
切片和 RHS 的 (tile_k, tile_n)
切片。如前所述,我们指定了 transforms
,以及 delay_release=1
。最后一个参数确保传递给 do_mma
的输入窗口(a_smem
, b_smem
)在下一次调用 do_mma
完成之前至少不会被覆盖。这是必要的,因为我们只在下一步中等待 MMA 的完成,这就是为什么 arrive_barrier_slot
和 wait_barrier_slot
在每次调用时都在 0 和 1 之间切换。
def kernel(a_gmem, b_gmem, out_gmem, acc_tmem, acc_smem, consumed_barriers):
mi = lax.axis_index("m")
ni = lax.axis_index("n")
m_slice = pl.ds(mi * tile_m, tile_m)
n_slice = pl.ds(ni * tile_n, tile_n)
def do_mma(idxs, a_smem, b_smem):
(ki,) = idxs
arrive_barrier_slot = ki % 2
wait_barrier_slot = 1 - arrive_barrier_slot
plgpu.tcgen05_mma(
acc_tmem,
a_smem,
b_smem,
barrier=consumed_barriers.at[arrive_barrier_slot],
accumulate=(ki > 0),
)
plgpu.barrier_wait(consumed_barriers.at[wait_barrier_slot])
# Make sure the wait succeeds in the first iteration.
plgpu.barrier_arrive(consumed_barriers.at[1])
block_kwargs = dict(transforms=transforms, delay_release=1)
plgpu.emit_pipeline(
do_mma,
in_specs=[
plgpu.BlockSpec((tile_m, tile_k), lambda ki: (mi, ki), **block_kwargs),
plgpu.BlockSpec((tile_k, tile_n), lambda ki: (ki, ni), **block_kwargs),
],
grid=(k_iters,),
max_concurrent_steps=max_concurrent_steps,
)(a_gmem, b_gmem)
核本身以一个结尾部分结束。在做任何事情之前,我们等待流水线发出的最后一个 MMA 完成。然后,我们从 TMEM 加载最终累加器,将其写入 SMEM(记住调用 plgpu.commit_smem
),并使用 TMA 将其复制回 GMEM。
def kernel(...):
... # compute pipeline as above
final_barrier = 1 - (k_iters % 2)
plgpu.barrier_wait(consumed_barriers.at[final_barrier])
acc_smem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[m_slice, n_slice])
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
剩下的就是将核实际变成一个可以用 JAX 数组调用的函数。我们为此使用了 plgpu.kernel
。网格现在只是二维的,并迭代输出块。我们分配核使用的中间缓冲区
用作累加器的 TMEM 缓冲区
在复制到 GMEM 之前用于暂存累加器的 SMEM 缓冲区
用于等待 MMA 操作完成的屏障。
def matmul0(a, b, config):
... # Setup code from the first snippet
def kernel(...):
... # The whole kernel body
f = plgpu.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), dtype),
grid=(m_iters, n_iters),
grid_names=("m", "n"),
scratch_shapes=dict(
acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32),
acc_smem=plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms),
consumed_barriers=plgpu.Barrier(
num_arrivals=1, num_barriers=2, orders_tensor_core=True
),
)
)
return f(a, b)
省略设置代码,这只需要 50 行!不幸的是,它目前还不是很 G,但它已经实现了 cuBLAS 利用率的一半!
1. Warp 特化#
注意
回想一下,在 Blackwell 上,单个 Pallas:MGPU 执行线程对应一个 CUDA 车道/线程的 Warpgroup。
上面的核使用单个 Warpgroup 来完成所有工作:从获取数据、发出 MMA 操作,到将结果存储到 GMEM。虽然人们可能认为 TensorCore 执行中的异步性允许我们重叠异步复制(TMA)和控制流的开销,但事实似乎并非如此。
Hopper 代 GPU 中的一个常见解决方案是利用 *Warpgroup* 特化。在 Pallas 中,plgpu.kernel
可以使用 num_threads=2
调用,这意味着网格中的每个程序将导致两次调用主体。然后,通常使用 lax.axis_index
查询线程索引,并用于选择多个角色中的一个,例如 *只* 发出异步复制或 *只* 运行 MMA 操作。
此解决方案在 Blackwell 代中也有效,但实际上更简单。由于异步复制(TMA)以及 tcgen05
MMA 指令 仅需要一个 CUDA 车道来发出它们,我们甚至不需要使用多个 *Warpgroup*。我们可以将单个 Warpgroup 分成 *四个 Warp* 并对它们进行特化!
在 Pallas 中,这可以通过 pl.core_map
和 plgpu.WarpMesh
实现。对于调用此类 core_map
的每个 Pallas 线程,将精确调用主体四次。core_map
在入口和出口处同步所有 Warp。请注意,主体中只允许标量操作。
这将是我们在此序列中对该核进行的最大重写,因此我们将再次列出完整的核源代码。
def matmul1(a, b, config: TuningConfig):
... # Setup code remains unmodified
def kernel(a_gmem, b_gmem, out_gmem,
a_smem, b_smem, acc_tmem, acc_smem,
load_barriers, consumed_barriers, mma_done_barrier):
m_index = lax.axis_index("m")
n_index = lax.axis_index("n")
m_slice = pl.ds(m_index * tile_m, tile_m)
n_slice = pl.ds(n_index * tile_n, tile_n)
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
def _per_warp():
warp_id = lax.axis_index("warp")
@pl.when(warp_id == 0)
def _memory():
def _loop_body(ki, _):
slot = lax.rem(ki, max_concurrent_steps)
@pl.when(ki >= max_concurrent_steps)
def _(): # Make sure the data has been consumed before overwriting.
plgpu.barrier_wait(consumed_barriers.at[slot])
k_slice = pl.ds(ki * tile_k, tile_k)
plgpu.copy_gmem_to_smem(
a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot]
)
plgpu.copy_gmem_to_smem(
b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot]
)
lax.fori_loop(0, k_iters, _loop_body, None)
@pl.when(warp_id == 1)
def _compute():
def _loop_body(ki, _):
slot = lax.rem(ki, max_concurrent_steps)
plgpu.barrier_wait(load_barriers.at[slot]) # Wait for data to arrive.
plgpu.tcgen05_mma(
acc_tmem,
a_smem.at[slot],
b_smem.at[slot],
consumed_barriers.at[slot],
accumulate=(ki > 0),
)
lax.fori_loop(0, k_iters, _loop_body, None)
plgpu.tcgen05_commit_arrive(mma_done_barrier)
plgpu.barrier_wait(mma_done_barrier)
acc_smem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[m_slice, n_slice])
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
核的结构与之前完全相同:我们首先执行计算,然后是结尾部分。结尾部分保持不变(我们只使用不同的屏障来等待完成),因此我们不再讨论它。
调用 plgpu.emit_pipeline
和 do_mma
函数已被单个 pl.core_map
调用取代。您可以看到,在其主体进入后,每个 Pallas 线程(现在代表一个 Warp!)会立即找出它是四个线程中的哪一个。然后,我们使用索引为 0 的线程 *只* 发出异步复制,循环获取 MMA 操作数,而索引为 1 的线程进入另一个循环,在该循环中反复调用 plgpu.tcgen05_mma
。
这里的一个有趣方面是同步。我们维护一个 load_barriers
数组,每个数组跟踪一个未完成的 GMEM->SMEM 复制的进度。计算线程必须等待它们完成,然后才能将相应的数据提供给 MMA 操作。反之,负责异步复制的线程必须等待消耗数据的 MMA 完成,然后才能通过发出另一个异步复制来覆盖内存。这通过 consumed_barriers
进行跟踪。最后,当计算线程完成所有 MMA 操作的发出后,它调用 plgpu.tcgen05_commit_arrive(mma_done_barrier)
,请求 TensorCore 在所有 MMA 操作完成后完成 mma_done_barrier
。
现在我们可以将注意力转移到 plgpu.kernel
定义上。与前一个版本唯一的区别是,我们显式分配了两个额外的 SMEM 缓冲区来保存 MMA 操作数(以前它们是由 plgpu.emit_pipeline
隐式分配的),以及额外的屏障。请注意,load_barriers
具有 num_arrivals=2
,因为我们在同一个屏障上发出两个异步复制。orders_tensor_core
对于指定用于指示 TensorCore 操作完成的屏障是必需的。
def matmul1(a, b, config: TuningConfig):
... # Setup code remains unmodified
def kernel(...):
... # Kernel code above
f = plgpu.kernel(
kernel,
..., # Other parameters remain unchanged
scratch_shapes=dict(
a_smem=plgpu.SMEM(
(max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms
),
b_smem=plgpu.SMEM(
(max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms
),
acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32),
acc_smem=plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms),
load_barriers=plgpu.Barrier(
num_arrivals=2, num_barriers=max_concurrent_steps
),
consumed_barriers=plgpu.Barrier(
num_arrivals=1,
num_barriers=max_concurrent_steps,
orders_tensor_core=True,
),
mma_done_barrier=plgpu.Barrier(
num_arrivals=1, num_barriers=1, orders_tensor_core=True
),
)
)
return f(a, b)
这个相对简单的修改已经为我们带来了显著的性能提升,使我们达到了接近 cuBLAS 性能的 70%。
2. 块状结尾#
这次,我们将注意力从核的计算部分转移到其结尾部分。我们可以通过将 TMEM 到 SMEM 的复制与 SMEM 到 GMEM 的传输流水线化来提高其效率。为此,我们将 scratch_shapes
更改为分配两个较小的缓冲区,而不是一个可以容纳整个输出的 SMEM 窗口(这也减少了我们的 SMEM 使用量)
def matmul2(a, b, config):
... # Setup and kernel code
f = plgpu.kernel(
...
scratch_shapes=dict(
...
# Previously: plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms),
acc_smem=plgpu.SMEM(
(2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms
),
...
)
)
然后,在核中,我们以 epilogue_tile_n
的块为单位循环遍历输出列,并逐步将输出发送到 GMEM
def matmul2(a, b, config):
... # Setup code remains unchanged
def kernel(...):
... # Compute part remains unchanged
plgpu.barrier_wait(mma_done_barrier)
out_gmem_window = out_gmem.at[m_slice, n_slice]
for ni in range(tile_n // config.epilogue_tile_n):
acc_smem_ni = acc_smem.at[ni % 2]
ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n)
# Make sure that previous copy is done before we overwrite.
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
acc_smem_ni[...] = plgpu.async_load_tmem(acc_tmem.at[:, ni_slice]).astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice])
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
3. 集合 (2CTA) MMA#
如果您对我们最新的核进行基准测试,您很快就会发现它无法很好地利用其计算单元,因为它们一直在等待内存提供 MMA 操作数。这意味着我们的核是内存绑定的,因为它的 *算术强度* 太低:我们为加载的每个字节执行的浮点运算次数太少。
Blackwell 架构的一个非常有效的技巧是,它允许我们将算术强度加倍,这就是 集合 MMA。核心思想很简单:我们使用两个块(在两个 SM 上)的集群来计算单个 matmul。每个块只加载每个操作数的一半,但 MMA 操作在运行时会交换来自每个块的 SMEM 数据。
我们再次从核配置更改开始
def matmul3(a, b, config):
... # Setup code
cluster_tile_m = 2 * tile_m
cluster_tile_n = 2 * tile_n
m_iters = m // cluster_tile_m
n_iters = n // cluster_tile_n
... # Setup code and kernel
f = plgpu.kernel(
...
grid=(m_iters, n_iters),
...
cluster=(2,),
cluster_names=("cluster",),
scratch_shapes=dict(
...
# Previously: plgpu.TMEM((tile_m, tile_n), jnp.float32),
acc_tmem=plgpu.TMEM(
(tile_m, cluster_tile_n), jnp.float32, collective=True
),
...
)
)
我们将 cluster
参数添加到 plgpu.kernel
,以指示我们打算让程序对协作(作为 CUDA 块集群)。我们还将 collective=True
添加到我们的 TMEM 分配中,以确保它将允许被集合 MMA 使用,并将其列数加倍(到 cluster_tile_n
)。
另一个值得注意的变化是,我们的块对最终将计算出 4 倍大的输出块,因此我们相应地缩小了网格。
我们首先更新核的入口
def kernel(...):
is_lead_block = lax.axis_index("cluster") == 0
m_index = lax.axis_index("m")
n_index = lax.axis_index("n")
m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m)
n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n)
这里的唯一更改是,我们使用 cluster_tile_m
和 cluster_tile_n
来计算两个块将共同计算的输出切片,并且我们还检查当前调用是否对应于集群中的第一个(领导者)块。这很重要,因为 *只有领导者块应该发出 MMA 指令*
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
def _per_warp():
warp_id = lax.axis_index("warp")
@pl.when(warp_id == 0)
def _memory():
def _loop_body(ki, _):
... # Wait for the data to be consumed, as previously.
plgpu.copy_gmem_to_smem(
..., collective_axes="cluster", partitioned_axis=0
)
plgpu.copy_gmem_to_smem(
..., collective_axes="cluster", partitioned_axis=1
)
lax.fori_loop(0, k_iters, _loop_body, None)
@pl.when(jnp.logical_and(warp_id == 1, is_lead_block))
def _compute():
def _loop_body(ki, _):
... # Wait for the data to arrive, as previously.
plgpu.tcgen05_mma(
...,
collective_axis="cluster",
)
lax.fori_loop(0, k_iters, _loop_body, None)
plgpu.tcgen05_commit_arrive(mma_done_barrier, collective_axis="cluster")
您会看到一些修改。首先,两个块都必须发出异步复制。在两个块中,我们请求复制整个集群的完整窗口,但添加 collective_axes="cluster"
表明加载由两个块共同执行。partitioned_axis=
指定了操作数的哪个轴将在集群之间分割。我们分割了 LHS 的行和 RHS 的列。
警告
分区的集合复制只在集群的领导块中完成传递给 copy_gmem_to_smem
的屏障!这就是为什么您会看到核在第二个块中从不等待加载。
其次,如前所述,我们还通过谓词化 _compute
主体,以便只有领导块运行 MMA 指令。所有 tcgen05
调用都另外获得一个 collective_axis=
参数,以指示 MMA 的完成应该完成集群中两个块的屏障。
最后,我们对结尾部分进行了一次小修改。即使集群中的两个块共同计算出形状为 (cluster_tile_m, cluster_tile_n)
的结果,每个单独的块也只持有形状为 (tile_m, cluster_tile_n)
的结果。我们更改了输出切片代码,使其需要切片出正确的 out_gmem_window
def matmul3(a, b, config):
...
def kernel(...):
... # Compute
plgpu.barrier_wait(mma_done_barrier)
out_m_index = m_index * 2 + lax.axis_index("cluster")
out_m_slice = pl.ds(out_m_index * tile_m, tile_m)
out_gmem_window = out_gmem.at[out_m_slice, n_slice]
for ni in range(cluster_tile_n // config.epilogue_tile_n):
...
...
4. 持久核#
我们的下一步是使核持久化。这意味着我们只会启动 GPU 上实际可以并行运行的集群数量(SM 数量除以 2),并且每个集群将循环处理固定数量的输出块。此技术允许我们更好地分摊块(反)初始化成本(因为它们只在每个 SM 上执行一次),并实现 SMEM 到 GMEM 的复制与下一个输出块的计算之间的一定程度的重叠。
def matmul4(a, b, config):
...
num_sms = jax.extend.backend.get_default_device().core_count
f = plgpu.kernel(
...
grid=(num_sms // 2,),
grid_names=("cluster_grid",),
...
)
更改相对简单。我们利用 plgpu.nd_loop
辅助函数来指定我们的迭代空间为 (m_iters, n_iters)
,但我们也要求它应该使用 collective_axes=
参数跨集群网格进行分割。
def matmul4(a, b, config):
...
def kernel(...):
is_lead_block = lax.axis_index("cluster") == 0
@plgpu.nd_loop((m_iters, n_iters), collective_axes="cluster_grid")
def _mn_loop(loop_info: plgpu.NDLoopInfo):
m_index, n_index = loop_info.index
m_slice = ...
n_slice = ...
... # Compute + epilogue
核主体计算部分的唯一有意义的修改是确保内存 Warp 中的前几次对 consumed_barriers
的等待仅在处理第一个输出块时跳过(如 loop_info.local_index == 0
所示)。在处理第二个(或更晚)块时,SMEM 缓冲区用于计算前一个输出块,因此我们需要确保在覆盖它们之前,这些计算已经完成。
def matmul4(a, b, config):
...
def kernel(...):
...
def _mn_loop(...):
...
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
def _per_warp():
warp_id = lax.axis_index("warp")
@pl.when(warp_id == 0)
def _memory():
def _loop_body(ki, _):
slot = lax.rem(ki, max_concurrent_steps)
@pl.when(jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0))
def _(): # Make sure the data has been consumed before overwriting.
plgpu.barrier_wait(consumed_barriers.at[slot])
最后,我们通过添加一行来修改核的结尾部分
def matmul4(a, b, config):
...
def kernel(...):
...
def _mn_loop(...):
... # Compute + epilogue
plgpu.wait_load_tmem() # Load must complete before MMA can overwrite TMEM.
正如注释所示,由于 TMEM 加载是异步的,在移动到下一个输出块并发出另一个 MMA 来覆盖我们的 TMEM 分配之前,我们必须等待它们完成。
5. 专用结尾 Warpgroup#
虽然持久化本身很有用,但它也解锁了另一项优化。当核中的单个 Pallas 线程完成核的计算部分时,它会执行整个结尾部分。然而,这意味着在完成之前,它无法再为 TensorCore 发出任何工作!
这引出了一个简单的解决方案:使用 2 个 Pallas 线程(Warpgroup)!第一个将只专注于获取 MMA 操作数并发出 MMA 操作,而第二个将只执行结尾部分!当然,为了让它们能够并行运行,我们需要双缓冲用于累加器的 TMEM,并使用额外的屏障进行同步。
def matmul5(a, b, config):
...
f = plgpu.kernel(
...,
num_threads=2,
thread_name="wg",
scratch_shapes=dict(
...
# Previously: plgpu.TMEM((tile_m, cluster_tile_n), jnp.float32, collective=True),
acc_tmem=plgpu.TMEM(
(tile_m, 2 * cluster_tile_n), jnp.float32, collective=True
),
...
# mma_done_barrier (now 2 barriers) + a new store_done_barrier (also 2 barriers)
# Previously: plgpu.Barrier(num_arrivals=1, num_barriers=1, orders_tensor_core=True),
mma_done_barrier=plgpu.Barrier(
num_arrivals=1, num_barriers=2, orders_tensor_core=True
),
store_done_barrier=plgpu.ClusterBarrier(
collective_axes=("cluster",),
num_arrivals=1,
num_barriers=2,
orders_tensor_core=True,
),
),
)
核的开头与我们之前的核类似。我们将 acc_tmem
重命名为 acc_tmem_slots
,并在循环遍历输出块时在它们的两个半部分之间切换。
def matmul(a, b, config):
...
def kernel(a_gmem, b_gmem, out_gmem,
a_smem, b_smem, acc_tmem_slots, acc_smem,
load_barriers, consumed_barriers, mma_done_barrier, store_done_barrier):
wg_idx = lax.axis_index("wg")
is_lead_block = ...
@plgpu.nd_loop(...)
def _mn_loop(...):
...
acc_slot = lax.rem(loop_info.local_index, jnp.int32(2))
acc_tmem = acc_tmem_slots.at[:, pl.ds(acc_slot * cluster_tile_n, cluster_tile_n)]
...
计算部分另外基于 wg_idx == 0
进行谓词化。在使用屏障的方式上也有两个重要的更改。首先,如果我们想重用我们的 TMEM 分配来进行 MMA(这只发生在 loop_info.local_index >= 2
时),我们需要等待我们想要重用的 TMEM 部分的 store_done_barrier
(如 acc_slot
所示)。其次,一旦我们想请求 TensorCore 在完成时到达 mma_done_barrier
,我们再次需要选择当前使用的 TMEM 部分对应的两个屏障之一。
警告
请注意,即使集群中只有一个块发出 MMA,它们都会等待 store_done_barrier
。这只是必要的,因为在没有 wait
的情况下两次到达同一个屏障有时会导致硬件断言。
def matmul(a, b, config):
...
def kernel(...):
...
def _mn_loop(...):
acc_slot = ...
acc_tmem = ...
@pl.when(wg_idx == 0)
def _compute_wg():
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
def _per_warp():
warp_id = lax.axis_index("warp")
@pl.when(warp_id == 0)
def _memory():
... # Memory code remains unchanged
# Wait for store to complete (except for the first two steps).
@pl.when(jnp.logical_and(warp_id == 1, loop_info.local_index >= 2))
def _wait_store():
plgpu.barrier_wait(store_done_barrier.at[acc_slot])
@pl.when(jnp.logical_and(warp_id == 1, is_lead_block))
def _compute():
... # Compute loop remains unchanged
plgpu.tcgen05_commit_arrive(mma_done_barrier.at[acc_slot], collective_axis="cluster")
最后,我们修改了结尾部分,只让第二个 Warpgroup 执行它,并让 Warpgroup 通过到达与其使用的 TMEM 部分关联的 store_done_barrier
来发出存储完成信号。
def matmul(a, b, config):
...
def kernel(...):
...
def _mn_loop(...):
... # Compute
@pl.when(wg_idx == 1)
def _store_wg():
... # Unmodified epilogue
plgpu.wait_load_tmem() # Load must complete before we signal.
plgpu.barrier_arrive(store_done_barrier.at[acc_slot])
6. 网格分块#
我们对该核的最后一个更改是改变我们生成输出块的顺序,以更好地利用 L2。如前所述,计算单元与内存系统相比速度极快,因此我们可以获得任何帮助来尝试让它们保持忙碌。
注意
这个技巧有很多不同的名称。CUTLASS 称之为“光栅化顺序”,ThunderKittens 称之为“超级分组”,而 Triton 教程称之为“程序重排”。我们使用“网格分块”这个名称。
我们的策略受到 CUTLASS 的启发,工作方式如下。首先,您选择迭代空间中的哪个维度是变化最快的(我们称之为 grid_minor_dim
)。然后,您选择该维度上的块大小(grid_tile_width
)。我们不是在增加主索引之前遍历网格的整个次要维度,而是在每次遍历 grid_tile_width
个元素时执行此操作。一旦我们用完了元素,我们就进入下一个块。但有一个转折!我们不是跳转到第二个块的开头,而是从末尾开始,向后移动。这确保了在我们切换块时,我们可以重用其中一个操作数的最近块。
由于这种策略非常普遍,我们提供了一个辅助函数:plgpu.planar_snake
。使用此辅助函数时,对核的更改非常微不足道
def matmul(a, b, config):
...
def kernel(...):
...
# We now only iterate over a 1D loop (but we still split it across clusters).
@plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
def _mn_loop(loop_info: plgpu.NDLoopInfo):
(lin_idx,) = loop_info.index
m_index, n_index = plgpu.planar_snake(
lin_idx, # Linear index.
(m_iters, n_iters), # The 2D iteration space.
config.grid_minor_dim, # 0 or 1, indicates the fastest changing dim.
config.grid_tile_width, # The width of tiles along the fastest changing dim.
)
... # Rest of the code remains unmodified
这个简单的技巧 *效果惊人*,是实现最先进性能的关键。
最终核#
恭喜您完成了本教程!在前面的部分,我们只关注了不同核之间的差异,很少列出完整的源代码。这在扩展实现时隐藏不相关细节很有用,但看到完整的源代码也可能很有帮助。所以,这是!整个实现不到 150 行,并且达到了最先进的性能(至少在我们基准测试使用的形状上)。
def matmul6(a, b, config: TuningConfig):
dtype = a.dtype
m, k = a.shape
_, n = b.shape
tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
transforms = (
plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle)
)
if m % tile_m != 0:
raise ValueError(f"{m=} must be divisible by {tile_m=}")
if n % tile_n != 0:
raise ValueError(f"{n=} must be divisible by {tile_n=}")
if k % tile_k != 0:
raise ValueError(f"{k=} must be divisible by {tile_k=}")
cluster_tile_m = 2 * tile_m
cluster_tile_n = 2 * tile_n
m_iters = m // cluster_tile_m
n_iters = n // cluster_tile_n
k_iters = k // tile_k
max_concurrent_steps = config.max_concurrent_steps
def kernel(a_gmem, b_gmem, out_gmem,
a_smem, b_smem, acc_tmem, acc_smem,
load_barriers, consumed_barriers, mma_done_barrier, store_done_barrier):
wg_idx = lax.axis_index("wg")
is_lead_block = lax.axis_index("cluster") == 0
@plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
def _mn_loop(loop_info: plgpu.NDLoopInfo):
(lin_idx,) = loop_info.index
m_index, n_index = plgpu.planar_snake(
lin_idx,
(m_iters, n_iters),
config.grid_minor_dim,
config.grid_tile_width,
)
m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m)
n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n)
acc_slot = lax.rem(loop_info.local_index, jnp.int32(2))
mn_acc_tmem = acc_tmem.at[:, pl.ds(acc_slot * cluster_tile_n, cluster_tile_n)]
@pl.when(wg_idx == 0)
def _compute_wg():
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
def _per_warp():
warp_id = lax.axis_index("warp")
@pl.when(warp_id == 0)
def _memory():
def _loop_body(ki, _):
slot = lax.rem(ki, max_concurrent_steps)
@pl.when(jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0))
def _(): # Make sure the data has been consumed before overwriting.
plgpu.barrier_wait(consumed_barriers.at[slot])
k_slice = pl.ds(ki * tile_k, tile_k)
plgpu.copy_gmem_to_smem(
a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot],
collective_axes="cluster", partitioned_axis=0
)
plgpu.copy_gmem_to_smem(
b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot],
collective_axes="cluster", partitioned_axis=1
)
lax.fori_loop(0, k_iters, _loop_body, None)
# Wait for store to complete (except for the first two steps).
@pl.when(jnp.logical_and(warp_id == 1, loop_info.local_index >= 2))
def _wait_store():
plgpu.barrier_wait(store_done_barrier.at[acc_slot])
@pl.when(jnp.logical_and(warp_id == 1, is_lead_block))
def _compute():
def _loop_body(ki, _):
slot = lax.rem(ki, max_concurrent_steps)
plgpu.barrier_wait(load_barriers.at[slot]) # Wait for data to arrive.
plgpu.tcgen05_mma(
mn_acc_tmem,
a_smem.at[slot],
b_smem.at[slot],
consumed_barriers.at[slot],
accumulate=(ki > 0),
collective_axis="cluster",
)
lax.fori_loop(0, k_iters, _loop_body, None)
plgpu.tcgen05_commit_arrive(
mma_done_barrier.at[acc_slot],
collective_axis="cluster",
)
@pl.when(wg_idx == 1)
def _store_wg():
# Ensure that copies from the previous mn step have completed.
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
plgpu.barrier_wait(mma_done_barrier.at[acc_slot])
out_m_index = m_index * 2 + lax.axis_index("cluster")
out_m_slice = pl.ds(out_m_index * tile_m, tile_m)
out_gmem_window = out_gmem.at[out_m_slice, n_slice]
for ni in range(cluster_tile_n // config.epilogue_tile_n):
acc_smem_ni = acc_smem.at[ni % 2]
ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n)
# Make sure that previous copy is done before we overwrite.
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
acc_smem_ni[...] = plgpu.async_load_tmem(mn_acc_tmem.at[:, ni_slice]).astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice])
plgpu.wait_load_tmem() # Load must complete before we signal.
plgpu.barrier_arrive(store_done_barrier.at[acc_slot])
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
num_sms = backend.get_default_device().core_count
f = plgpu.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), dtype),
grid=(num_sms // 2,),
grid_names=("cluster_grid",),
cluster=(2,),
cluster_names=("cluster",),
num_threads=2,
thread_name="wg",
scratch_shapes=dict(
a_smem=plgpu.SMEM(
(max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms
),
b_smem=plgpu.SMEM(
(max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms
),
acc_tmem=plgpu.TMEM(
(tile_m, 2 * cluster_tile_n), jnp.float32, collective=True
),
acc_smem=plgpu.SMEM(
(2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms
),
load_barriers=plgpu.Barrier(
num_arrivals=2, num_barriers=max_concurrent_steps
),
consumed_barriers=plgpu.Barrier(
num_arrivals=1,
num_barriers=max_concurrent_steps,
orders_tensor_core=True,
),
mma_done_barrier=plgpu.Barrier(
num_arrivals=1, num_barriers=2, orders_tensor_core=True
),
store_done_barrier=plgpu.ClusterBarrier(
collective_axes=("cluster",),
num_arrivals=1,
num_barriers=2,
orders_tensor_core=True,
),
)
)
return f(a, b)