在 Blackwell 上编写高性能矩阵乘法核#

在本指南中,我们将逐步迭代矩阵乘法核。第一个实现将非常简单,但也非常慢。然而,只需几个简单的步骤,就可以将其修改为最先进的核,其性能可与 cuBLAS 和 CUTLASS 等高度优化的实现相媲美甚至超越。

警告

下表中显示的利用率可能与您在线看到的有所不同,但这些差异很可能可以通过不同的输入数据分布来解释。我们这里的所有基准测试都使用具有 iid 正常 float16 条目的数组,事实证明这是您可以选择的较慢的分布之一。通过更改 BENCHMARK 变量为 True,运行 我们的测试文件,您可以自己重现这些数字。

简而言之:如果不指定输入数据分布,请勿相信 matmul 基准测试。

实现

TensorCore 利用率

cuBLAS 利用率百分比

0. 基本核

37.62%

59.4%

1. Warp 特化

45.47%

71.7%

2. 块状结尾

55.82%

88.1%

3. 集合 (2CTA) MMA

59.41%

93.7%

4. 持久核

61.46%

97.0%

5. 专用结尾 Warpgroup

63.38%

100.0%

6. 网格分块

69.44%

109.6%

cuBLAS

63.38%

100.0%

CUTLASS

69.30%

109.3%

cuBLAS 基线是通过测量 jax.dot 的性能获得的。CUTLASS 性能是通过从以下 cutlass_profiler 调用中获取最佳结果来测量的(不包括稀疏 matmuls)

cutlass_profiler --dist=gaussian,mean:0,stddev:1,scale:-1 --output=results.csv --accumulator-type=f32 --m=4096 --k=4096 --n=8192 --kernels='*sm100*' --A=f16 --B=f16 --C=void --D=f16

在每一步,我们将展示核的完整实现,或者展示上一步和当前步骤中代码列表之间的差异。完整实现可以在 我们的测试文件 中找到。您还可以在 Pallas ops 包中找到一个完整的独立优化核实现:Pallas ops 包

0. 基本核#

我们从一个简单的单 CTA(块)单 Warpgroup 示例开始。为了方便起见,我们将核的调优参数分离到一个单独的类中

@dataclasses.dataclass(frozen=True)
class TuningConfig:
  tile_m: int
  tile_n: int
  tile_k: int
  max_concurrent_steps: int

tile_mtile_ntile_k 指定了流水线每一步执行的 matmul 的大小。通常,tile_k 理想情况下应等于 128 除以输入元素类型的字节宽度。max_concurrent_steps 指定了计算/内存流水线中的内存预取深度,这在其他实现中经常被称为阶段数。

核的实现首先是一些设置代码

def matmul0(a, b, config: TuningConfig):
  dtype = a.dtype
  m, k = a.shape
  _, n = b.shape
  tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
  swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
  swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
  transforms = (
      plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle)
  )
  if m % tile_m != 0:
    raise ValueError(f"{m=} must be divisible by {tile_m=}")
  if n % tile_n != 0:
    raise ValueError(f"{n=} must be divisible by {tile_n=}")
  if k % tile_k != 0:
    raise ValueError(f"{k=} must be divisible by {tile_k=}")
  m_iters = m // tile_m
  n_iters = n // tile_n
  k_iters = k // tile_k
  max_concurrent_steps = config.max_concurrent_steps

我们解包配置变量以便于访问,设置分块和混叠变换,使 SMEM 数据格式与 MMA 指令所期望的 匹配。

核的实现本身相对简短。第一部分使用 plgpu.emit_pipeline 设置一个 计算/内存流水线。在每一步,计算函数(do_mma)消耗 LHS 的 (tile_m, tile_k) 切片和 RHS 的 (tile_k, tile_n) 切片。如前所述,我们指定了 transforms,以及 delay_release=1。最后一个参数确保传递给 do_mma 的输入窗口(a_smem, b_smem)在下一次调用 do_mma 完成之前至少不会被覆盖。这是必要的,因为我们只在下一步中等待 MMA 的完成,这就是为什么 arrive_barrier_slotwait_barrier_slot 在每次调用时都在 0 和 1 之间切换。

  def kernel(a_gmem, b_gmem, out_gmem, acc_tmem, acc_smem, consumed_barriers):
    mi = lax.axis_index("m")
    ni = lax.axis_index("n")
    m_slice = pl.ds(mi * tile_m, tile_m)
    n_slice = pl.ds(ni * tile_n, tile_n)

    def do_mma(idxs, a_smem, b_smem):
      (ki,) = idxs
      arrive_barrier_slot = ki % 2
      wait_barrier_slot = 1 - arrive_barrier_slot
      plgpu.tcgen05_mma(
          acc_tmem,
          a_smem,
          b_smem,
          barrier=consumed_barriers.at[arrive_barrier_slot],
          accumulate=(ki > 0),
      )
      plgpu.barrier_wait(consumed_barriers.at[wait_barrier_slot])

    # Make sure the wait succeeds in the first iteration.
    plgpu.barrier_arrive(consumed_barriers.at[1])
    block_kwargs = dict(transforms=transforms, delay_release=1)
    plgpu.emit_pipeline(
      do_mma,
      in_specs=[
          plgpu.BlockSpec((tile_m, tile_k), lambda ki: (mi, ki), **block_kwargs),
          plgpu.BlockSpec((tile_k, tile_n), lambda ki: (ki, ni), **block_kwargs),
      ],
      grid=(k_iters,),
      max_concurrent_steps=max_concurrent_steps,
    )(a_gmem, b_gmem)

核本身以一个结尾部分结束。在做任何事情之前,我们等待流水线发出的最后一个 MMA 完成。然后,我们从 TMEM 加载最终累加器,将其写入 SMEM(记住调用 plgpu.commit_smem),并使用 TMA 将其复制回 GMEM。

  def kernel(...):
    ...  # compute pipeline as above
    final_barrier = 1 - (k_iters % 2)
    plgpu.barrier_wait(consumed_barriers.at[final_barrier])
    acc_smem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype)
    plgpu.commit_smem()
    plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[m_slice, n_slice])
    plgpu.wait_smem_to_gmem(0, wait_read_only=True)

剩下的就是将核实际变成一个可以用 JAX 数组调用的函数。我们为此使用了 plgpu.kernel。网格现在只是二维的,并迭代输出块。我们分配核使用的中间缓冲区

  1. 用作累加器的 TMEM 缓冲区

  2. 在复制到 GMEM 之前用于暂存累加器的 SMEM 缓冲区

  3. 用于等待 MMA 操作完成的屏障。

def matmul0(a, b, config):
  ... # Setup code from the first snippet
  def kernel(...):
    ... # The whole kernel body

  f = plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), dtype),
      grid=(m_iters, n_iters),
      grid_names=("m", "n"),
      scratch_shapes=dict(
        acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32),
        acc_smem=plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms),
        consumed_barriers=plgpu.Barrier(
          num_arrivals=1, num_barriers=2, orders_tensor_core=True
        ),
      )
  )
  return f(a, b)

省略设置代码,这只需要 50 行!不幸的是,它目前还不是很 G,但它已经实现了 cuBLAS 利用率的一半!

1. Warp 特化#

注意

回想一下,在 Blackwell 上,单个 Pallas:MGPU 执行线程对应一个 CUDA 车道/线程的 Warpgroup。

上面的核使用单个 Warpgroup 来完成所有工作:从获取数据、发出 MMA 操作,到将结果存储到 GMEM。虽然人们可能认为 TensorCore 执行中的异步性允许我们重叠异步复制(TMA)和控制流的开销,但事实似乎并非如此。

Hopper 代 GPU 中的一个常见解决方案是利用 *Warpgroup* 特化。在 Pallas 中,plgpu.kernel 可以使用 num_threads=2 调用,这意味着网格中的每个程序将导致两次调用主体。然后,通常使用 lax.axis_index 查询线程索引,并用于选择多个角色中的一个,例如 *只* 发出异步复制或 *只* 运行 MMA 操作。

此解决方案在 Blackwell 代中也有效,但实际上更简单。由于异步复制(TMA)以及 tcgen05 MMA 指令 仅需要一个 CUDA 车道来发出它们,我们甚至不需要使用多个 *Warpgroup*。我们可以将单个 Warpgroup 分成 *四个 Warp* 并对它们进行特化!

在 Pallas 中,这可以通过 pl.core_mapplgpu.WarpMesh 实现。对于调用此类 core_map 的每个 Pallas 线程,将精确调用主体四次。core_map 在入口和出口处同步所有 Warp。请注意,主体中只允许标量操作。

这将是我们在此序列中对该核进行的最大重写,因此我们将再次列出完整的核源代码。

def matmul1(a, b, config: TuningConfig):
  ... # Setup code remains unmodified

  def kernel(a_gmem, b_gmem, out_gmem,
             a_smem, b_smem, acc_tmem, acc_smem,
             load_barriers, consumed_barriers, mma_done_barrier):
    m_index = lax.axis_index("m")
    n_index = lax.axis_index("n")
    m_slice = pl.ds(m_index * tile_m, tile_m)
    n_slice = pl.ds(n_index * tile_n, tile_n)

    @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
    def _per_warp():
      warp_id = lax.axis_index("warp")

      @pl.when(warp_id == 0)
      def _memory():
        def _loop_body(ki, _):
          slot = lax.rem(ki, max_concurrent_steps)
          @pl.when(ki >= max_concurrent_steps)
          def _():  # Make sure the data has been consumed before overwriting.
            plgpu.barrier_wait(consumed_barriers.at[slot])
          k_slice = pl.ds(ki * tile_k, tile_k)
          plgpu.copy_gmem_to_smem(
              a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot]
          )
          plgpu.copy_gmem_to_smem(
              b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot]
          )

        lax.fori_loop(0, k_iters, _loop_body, None)

      @pl.when(warp_id == 1)
      def _compute():
        def _loop_body(ki, _):
          slot = lax.rem(ki, max_concurrent_steps)
          plgpu.barrier_wait(load_barriers.at[slot])  # Wait for data to arrive.
          plgpu.tcgen05_mma(
              acc_tmem,
              a_smem.at[slot],
              b_smem.at[slot],
              consumed_barriers.at[slot],
              accumulate=(ki > 0),
          )
        lax.fori_loop(0, k_iters, _loop_body, None)
        plgpu.tcgen05_commit_arrive(mma_done_barrier)

    plgpu.barrier_wait(mma_done_barrier)
    acc_smem[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype)
    plgpu.commit_smem()
    plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[m_slice, n_slice])
    plgpu.wait_smem_to_gmem(0, wait_read_only=True)

核的结构与之前完全相同:我们首先执行计算,然后是结尾部分。结尾部分保持不变(我们只使用不同的屏障来等待完成),因此我们不再讨论它。

调用 plgpu.emit_pipelinedo_mma 函数已被单个 pl.core_map 调用取代。您可以看到,在其主体进入后,每个 Pallas 线程(现在代表一个 Warp!)会立即找出它是四个线程中的哪一个。然后,我们使用索引为 0 的线程 *只* 发出异步复制,循环获取 MMA 操作数,而索引为 1 的线程进入另一个循环,在该循环中反复调用 plgpu.tcgen05_mma

这里的一个有趣方面是同步。我们维护一个 load_barriers 数组,每个数组跟踪一个未完成的 GMEM->SMEM 复制的进度。计算线程必须等待它们完成,然后才能将相应的数据提供给 MMA 操作。反之,负责异步复制的线程必须等待消耗数据的 MMA 完成,然后才能通过发出另一个异步复制来覆盖内存。这通过 consumed_barriers 进行跟踪。最后,当计算线程完成所有 MMA 操作的发出后,它调用 plgpu.tcgen05_commit_arrive(mma_done_barrier),请求 TensorCore 在所有 MMA 操作完成后完成 mma_done_barrier

现在我们可以将注意力转移到 plgpu.kernel 定义上。与前一个版本唯一的区别是,我们显式分配了两个额外的 SMEM 缓冲区来保存 MMA 操作数(以前它们是由 plgpu.emit_pipeline 隐式分配的),以及额外的屏障。请注意,load_barriers 具有 num_arrivals=2,因为我们在同一个屏障上发出两个异步复制。orders_tensor_core 对于指定用于指示 TensorCore 操作完成的屏障是必需的。

def matmul1(a, b, config: TuningConfig):
  ... # Setup code remains unmodified

  def kernel(...):
    ... # Kernel code above

  f = plgpu.kernel(
      kernel,
      ...,  # Other parameters remain unchanged
      scratch_shapes=dict(
        a_smem=plgpu.SMEM(
            (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms
        ),
        b_smem=plgpu.SMEM(
            (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms
        ),
        acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32),
        acc_smem=plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms),
        load_barriers=plgpu.Barrier(
            num_arrivals=2, num_barriers=max_concurrent_steps
        ),
        consumed_barriers=plgpu.Barrier(
            num_arrivals=1,
            num_barriers=max_concurrent_steps,
            orders_tensor_core=True,
        ),
        mma_done_barrier=plgpu.Barrier(
            num_arrivals=1, num_barriers=1, orders_tensor_core=True
        ),
      )
  )
  return f(a, b)

这个相对简单的修改已经为我们带来了显著的性能提升,使我们达到了接近 cuBLAS 性能的 70%。

2. 块状结尾#

这次,我们将注意力从核的计算部分转移到其结尾部分。我们可以通过将 TMEM 到 SMEM 的复制与 SMEM 到 GMEM 的传输流水线化来提高其效率。为此,我们将 scratch_shapes 更改为分配两个较小的缓冲区,而不是一个可以容纳整个输出的 SMEM 窗口(这也减少了我们的 SMEM 使用量)

def matmul2(a, b, config):
  ... # Setup and kernel code
  f = plgpu.kernel(
      ...
      scratch_shapes=dict(
        ...
        # Previously: plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms),
        acc_smem=plgpu.SMEM(
            (2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms
        ),
        ...
      )
  )

然后,在核中,我们以 epilogue_tile_n 的块为单位循环遍历输出列,并逐步将输出发送到 GMEM

def matmul2(a, b, config):
  ... # Setup code remains unchanged

  def kernel(...):
    ... # Compute part remains unchanged

    plgpu.barrier_wait(mma_done_barrier)
    out_gmem_window = out_gmem.at[m_slice, n_slice]
    for ni in range(tile_n // config.epilogue_tile_n):
      acc_smem_ni = acc_smem.at[ni % 2]
      ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n)
      # Make sure that previous copy is done before we overwrite.
      plgpu.wait_smem_to_gmem(1, wait_read_only=True)
      acc_smem_ni[...] = plgpu.async_load_tmem(acc_tmem.at[:, ni_slice]).astype(dtype)
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice])
    plgpu.wait_smem_to_gmem(0, wait_read_only=True)

3. 集合 (2CTA) MMA#

如果您对我们最新的核进行基准测试,您很快就会发现它无法很好地利用其计算单元,因为它们一直在等待内存提供 MMA 操作数。这意味着我们的核是内存绑定的,因为它的 *算术强度* 太低:我们为加载的每个字节执行的浮点运算次数太少。

Blackwell 架构的一个非常有效的技巧是,它允许我们将算术强度加倍,这就是 集合 MMA。核心思想很简单:我们使用两个块(在两个 SM 上)的集群来计算单个 matmul。每个块只加载每个操作数的一半,但 MMA 操作在运行时会交换来自每个块的 SMEM 数据。

我们再次从核配置更改开始

def matmul3(a, b, config):
  ...  # Setup code
  cluster_tile_m = 2 * tile_m
  cluster_tile_n = 2 * tile_n
  m_iters = m // cluster_tile_m
  n_iters = n // cluster_tile_n
  ... # Setup code and kernel

  f = plgpu.kernel(
      ...
      grid=(m_iters, n_iters),
      ...
      cluster=(2,),
      cluster_names=("cluster",),
      scratch_shapes=dict(
          ...
          # Previously: plgpu.TMEM((tile_m, tile_n), jnp.float32),
          acc_tmem=plgpu.TMEM(
              (tile_m, cluster_tile_n), jnp.float32, collective=True
          ),
          ...
      )
  )

我们将 cluster 参数添加到 plgpu.kernel,以指示我们打算让程序对协作(作为 CUDA 块集群)。我们还将 collective=True 添加到我们的 TMEM 分配中,以确保它将允许被集合 MMA 使用,并将其列数加倍(到 cluster_tile_n)。

另一个值得注意的变化是,我们的块对最终将计算出 4 倍大的输出块,因此我们相应地缩小了网格。

我们首先更新核的入口

  def kernel(...):
    is_lead_block = lax.axis_index("cluster") == 0
    m_index = lax.axis_index("m")
    n_index = lax.axis_index("n")
    m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m)
    n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n)

这里的唯一更改是,我们使用 cluster_tile_mcluster_tile_n 来计算两个块将共同计算的输出切片,并且我们还检查当前调用是否对应于集群中的第一个(领导者)块。这很重要,因为 *只有领导者块应该发出 MMA 指令*

    @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
    def _per_warp():
      warp_id = lax.axis_index("warp")

      @pl.when(warp_id == 0)
      def _memory():
        def _loop_body(ki, _):
          ...  # Wait for the data to be consumed, as previously.
          plgpu.copy_gmem_to_smem(
              ..., collective_axes="cluster", partitioned_axis=0
          )
          plgpu.copy_gmem_to_smem(
              ..., collective_axes="cluster", partitioned_axis=1
          )
        lax.fori_loop(0, k_iters, _loop_body, None)

      @pl.when(jnp.logical_and(warp_id == 1, is_lead_block))
      def _compute():
        def _loop_body(ki, _):
          ...  # Wait for the data to arrive, as previously.
          plgpu.tcgen05_mma(
              ...,
              collective_axis="cluster",
          )
        lax.fori_loop(0, k_iters, _loop_body, None)
        plgpu.tcgen05_commit_arrive(mma_done_barrier, collective_axis="cluster")

您会看到一些修改。首先,两个块都必须发出异步复制。在两个块中,我们请求复制整个集群的完整窗口,但添加 collective_axes="cluster" 表明加载由两个块共同执行。partitioned_axis= 指定了操作数的哪个轴将在集群之间分割。我们分割了 LHS 的行和 RHS 的列。

警告

分区的集合复制只在集群的领导块中完成传递给 copy_gmem_to_smem 的屏障!这就是为什么您会看到核在第二个块中从不等待加载。

其次,如前所述,我们还通过谓词化 _compute 主体,以便只有领导块运行 MMA 指令。所有 tcgen05 调用都另外获得一个 collective_axis= 参数,以指示 MMA 的完成应该完成集群中两个块的屏障。

最后,我们对结尾部分进行了一次小修改。即使集群中的两个块共同计算出形状为 (cluster_tile_m, cluster_tile_n) 的结果,每个单独的块也只持有形状为 (tile_m, cluster_tile_n) 的结果。我们更改了输出切片代码,使其需要切片出正确的 out_gmem_window

def matmul3(a, b, config):
  ...
  def kernel(...):
    ... # Compute

    plgpu.barrier_wait(mma_done_barrier)
    out_m_index = m_index * 2 + lax.axis_index("cluster")
    out_m_slice = pl.ds(out_m_index * tile_m, tile_m)
    out_gmem_window = out_gmem.at[out_m_slice, n_slice]
    for ni in range(cluster_tile_n // config.epilogue_tile_n):
      ...

  ...

4. 持久核#

我们的下一步是使核持久化。这意味着我们只会启动 GPU 上实际可以并行运行的集群数量(SM 数量除以 2),并且每个集群将循环处理固定数量的输出块。此技术允许我们更好地分摊块(反)初始化成本(因为它们只在每个 SM 上执行一次),并实现 SMEM 到 GMEM 的复制与下一个输出块的计算之间的一定程度的重叠。

def matmul4(a, b, config):
  ...

  num_sms = jax.extend.backend.get_default_device().core_count
  f = plgpu.kernel(
      ...
      grid=(num_sms // 2,),
      grid_names=("cluster_grid",),
      ...
  )

更改相对简单。我们利用 plgpu.nd_loop 辅助函数来指定我们的迭代空间为 (m_iters, n_iters),但我们也要求它应该使用 collective_axes= 参数跨集群网格进行分割。

def matmul4(a, b, config):
  ...

  def kernel(...):
    is_lead_block = lax.axis_index("cluster") == 0

    @plgpu.nd_loop((m_iters, n_iters), collective_axes="cluster_grid")
    def _mn_loop(loop_info: plgpu.NDLoopInfo):
      m_index, n_index = loop_info.index
      m_slice = ...
      n_slice = ...

      ...  # Compute + epilogue

核主体计算部分的唯一有意义的修改是确保内存 Warp 中的前几次对 consumed_barriers 的等待仅在处理第一个输出块时跳过(如 loop_info.local_index == 0 所示)。在处理第二个(或更晚)块时,SMEM 缓冲区用于计算前一个输出块,因此我们需要确保在覆盖它们之前,这些计算已经完成。

def matmul4(a, b, config):
  ...
  def kernel(...):
    ...
    def _mn_loop(...):
      ...

      @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
      def _per_warp():
        warp_id = lax.axis_index("warp")

        @pl.when(warp_id == 0)
        def _memory():
          def _loop_body(ki, _):
            slot = lax.rem(ki, max_concurrent_steps)
            @pl.when(jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0))
            def _():  # Make sure the data has been consumed before overwriting.
              plgpu.barrier_wait(consumed_barriers.at[slot])

最后,我们通过添加一行来修改核的结尾部分

def matmul4(a, b, config):
  ...
  def kernel(...):
    ...
    def _mn_loop(...):
      ...  # Compute + epilogue
      plgpu.wait_load_tmem()  # Load must complete before MMA can overwrite TMEM.

正如注释所示,由于 TMEM 加载是异步的,在移动到下一个输出块并发出另一个 MMA 来覆盖我们的 TMEM 分配之前,我们必须等待它们完成。

5. 专用结尾 Warpgroup#

虽然持久化本身很有用,但它也解锁了另一项优化。当核中的单个 Pallas 线程完成核的计算部分时,它会执行整个结尾部分。然而,这意味着在完成之前,它无法再为 TensorCore 发出任何工作!

这引出了一个简单的解决方案:使用 2 个 Pallas 线程(Warpgroup)!第一个将只专注于获取 MMA 操作数并发出 MMA 操作,而第二个将只执行结尾部分!当然,为了让它们能够并行运行,我们需要双缓冲用于累加器的 TMEM,并使用额外的屏障进行同步。

def matmul5(a, b, config):
  ...

  f = plgpu.kernel(
      ...,
      num_threads=2,
      thread_name="wg",
      scratch_shapes=dict(
          ...
          # Previously: plgpu.TMEM((tile_m, cluster_tile_n), jnp.float32, collective=True),
          acc_tmem=plgpu.TMEM(
              (tile_m, 2 * cluster_tile_n), jnp.float32, collective=True
          ),
          ...
          # mma_done_barrier (now 2 barriers) + a new store_done_barrier (also 2 barriers)
          # Previously: plgpu.Barrier(num_arrivals=1, num_barriers=1, orders_tensor_core=True),
          mma_done_barrier=plgpu.Barrier(
              num_arrivals=1, num_barriers=2, orders_tensor_core=True
          ),
          store_done_barrier=plgpu.ClusterBarrier(
              collective_axes=("cluster",),
              num_arrivals=1,
              num_barriers=2,
              orders_tensor_core=True,
          ),
      ),
  )

核的开头与我们之前的核类似。我们将 acc_tmem 重命名为 acc_tmem_slots,并在循环遍历输出块时在它们的两个半部分之间切换。

def matmul(a, b, config):
  ...

  def kernel(a_gmem, b_gmem, out_gmem,
             a_smem, b_smem, acc_tmem_slots, acc_smem,
             load_barriers, consumed_barriers, mma_done_barrier, store_done_barrier):
    wg_idx = lax.axis_index("wg")
    is_lead_block = ...

    @plgpu.nd_loop(...)
    def _mn_loop(...):
      ...
      acc_slot = lax.rem(loop_info.local_index, jnp.int32(2))
      acc_tmem = acc_tmem_slots.at[:, pl.ds(acc_slot * cluster_tile_n, cluster_tile_n)]

      ...

计算部分另外基于 wg_idx == 0 进行谓词化。在使用屏障的方式上也有两个重要的更改。首先,如果我们想重用我们的 TMEM 分配来进行 MMA(这只发生在 loop_info.local_index >= 2 时),我们需要等待我们想要重用的 TMEM 部分的 store_done_barrier(如 acc_slot 所示)。其次,一旦我们想请求 TensorCore 在完成时到达 mma_done_barrier,我们再次需要选择当前使用的 TMEM 部分对应的两个屏障之一。

警告

请注意,即使集群中只有一个块发出 MMA,它们都会等待 store_done_barrier。这只是必要的,因为在没有 wait 的情况下两次到达同一个屏障有时会导致硬件断言。

def matmul(a, b, config):
  ...
  def kernel(...):
    ...
    def _mn_loop(...):
      acc_slot = ...
      acc_tmem = ...

      @pl.when(wg_idx == 0)
      def _compute_wg():
        @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
        def _per_warp():
          warp_id = lax.axis_index("warp")

          @pl.when(warp_id == 0)
          def _memory():
            ... # Memory code remains unchanged

          # Wait for store to complete (except for the first two steps).
          @pl.when(jnp.logical_and(warp_id == 1, loop_info.local_index >= 2))
          def _wait_store():
            plgpu.barrier_wait(store_done_barrier.at[acc_slot])
          @pl.when(jnp.logical_and(warp_id == 1, is_lead_block))
          def _compute():
            ... # Compute loop remains unchanged
            plgpu.tcgen05_commit_arrive(mma_done_barrier.at[acc_slot], collective_axis="cluster")

最后,我们修改了结尾部分,只让第二个 Warpgroup 执行它,并让 Warpgroup 通过到达与其使用的 TMEM 部分关联的 store_done_barrier 来发出存储完成信号。

def matmul(a, b, config):
  ...
  def kernel(...):
    ...
    def _mn_loop(...):
      ... # Compute

      @pl.when(wg_idx == 1)
      def _store_wg():
        ... # Unmodified epilogue
        plgpu.wait_load_tmem()  # Load must complete before we signal.
        plgpu.barrier_arrive(store_done_barrier.at[acc_slot])

6. 网格分块#

我们对该核的最后一个更改是改变我们生成输出块的顺序,以更好地利用 L2。如前所述,计算单元与内存系统相比速度极快,因此我们可以获得任何帮助来尝试让它们保持忙碌。

注意

这个技巧有很多不同的名称。CUTLASS 称之为“光栅化顺序”,ThunderKittens 称之为“超级分组”,而 Triton 教程称之为“程序重排”。我们使用“网格分块”这个名称。

我们的策略受到 CUTLASS 的启发,工作方式如下。首先,您选择迭代空间中的哪个维度是变化最快的(我们称之为 grid_minor_dim)。然后,您选择该维度上的块大小(grid_tile_width)。我们不是在增加主索引之前遍历网格的整个次要维度,而是在每次遍历 grid_tile_width 个元素时执行此操作。一旦我们用完了元素,我们就进入下一个块。但有一个转折!我们不是跳转到第二个块的开头,而是从末尾开始,向后移动。这确保了在我们切换块时,我们可以重用其中一个操作数的最近块。

由于这种策略非常普遍,我们提供了一个辅助函数:plgpu.planar_snake。使用此辅助函数时,对核的更改非常微不足道

def matmul(a, b, config):
  ...
  def kernel(...):
    ...
    # We now only iterate over a 1D loop (but we still split it across clusters).
    @plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
    def _mn_loop(loop_info: plgpu.NDLoopInfo):
      (lin_idx,) = loop_info.index
      m_index, n_index = plgpu.planar_snake(
          lin_idx,  # Linear index.
          (m_iters, n_iters),  # The 2D iteration space.
          config.grid_minor_dim,  # 0 or 1, indicates the fastest changing dim.
          config.grid_tile_width,  # The width of tiles along the fastest changing dim.
      )
      ... # Rest of the code remains unmodified

这个简单的技巧 *效果惊人*,是实现最先进性能的关键。

最终核#

恭喜您完成了本教程!在前面的部分,我们只关注了不同核之间的差异,很少列出完整的源代码。这在扩展实现时隐藏不相关细节很有用,但看到完整的源代码也可能很有帮助。所以,这是!整个实现不到 150 行,并且达到了最先进的性能(至少在我们基准测试使用的形状上)。

def matmul6(a, b, config: TuningConfig):
  dtype = a.dtype
  m, k = a.shape
  _, n = b.shape
  tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
  swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
  swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
  transforms = (
      plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle)
  )
  if m % tile_m != 0:
    raise ValueError(f"{m=} must be divisible by {tile_m=}")
  if n % tile_n != 0:
    raise ValueError(f"{n=} must be divisible by {tile_n=}")
  if k % tile_k != 0:
    raise ValueError(f"{k=} must be divisible by {tile_k=}")
  cluster_tile_m = 2 * tile_m
  cluster_tile_n = 2 * tile_n
  m_iters = m // cluster_tile_m
  n_iters = n // cluster_tile_n
  k_iters = k // tile_k
  max_concurrent_steps = config.max_concurrent_steps

  def kernel(a_gmem, b_gmem, out_gmem,
             a_smem, b_smem, acc_tmem, acc_smem,
             load_barriers, consumed_barriers, mma_done_barrier, store_done_barrier):
    wg_idx = lax.axis_index("wg")
    is_lead_block = lax.axis_index("cluster") == 0

    @plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
    def _mn_loop(loop_info: plgpu.NDLoopInfo):
      (lin_idx,) = loop_info.index
      m_index, n_index = plgpu.planar_snake(
          lin_idx,
          (m_iters, n_iters),
          config.grid_minor_dim,
          config.grid_tile_width,
      )
      m_slice = pl.ds(m_index * cluster_tile_m, cluster_tile_m)
      n_slice = pl.ds(n_index * cluster_tile_n, cluster_tile_n)
      acc_slot = lax.rem(loop_info.local_index, jnp.int32(2))
      mn_acc_tmem = acc_tmem.at[:, pl.ds(acc_slot * cluster_tile_n, cluster_tile_n)]

      @pl.when(wg_idx == 0)
      def _compute_wg():
        @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
        def _per_warp():
          warp_id = lax.axis_index("warp")

          @pl.when(warp_id == 0)
          def _memory():
            def _loop_body(ki, _):
              slot = lax.rem(ki, max_concurrent_steps)
              @pl.when(jnp.logical_or(ki >= max_concurrent_steps, loop_info.local_index > 0))
              def _():  # Make sure the data has been consumed before overwriting.
                plgpu.barrier_wait(consumed_barriers.at[slot])
              k_slice = pl.ds(ki * tile_k, tile_k)
              plgpu.copy_gmem_to_smem(
                  a_gmem.at[m_slice, k_slice], a_smem.at[slot], load_barriers.at[slot],
                  collective_axes="cluster", partitioned_axis=0
              )
              plgpu.copy_gmem_to_smem(
                  b_gmem.at[k_slice, n_slice], b_smem.at[slot], load_barriers.at[slot],
                  collective_axes="cluster", partitioned_axis=1
              )

            lax.fori_loop(0, k_iters, _loop_body, None)

          # Wait for store to complete (except for the first two steps).
          @pl.when(jnp.logical_and(warp_id == 1, loop_info.local_index >= 2))
          def _wait_store():
            plgpu.barrier_wait(store_done_barrier.at[acc_slot])
          @pl.when(jnp.logical_and(warp_id == 1, is_lead_block))
          def _compute():
            def _loop_body(ki, _):
              slot = lax.rem(ki, max_concurrent_steps)
              plgpu.barrier_wait(load_barriers.at[slot])  # Wait for data to arrive.
              plgpu.tcgen05_mma(
                  mn_acc_tmem,
                  a_smem.at[slot],
                  b_smem.at[slot],
                  consumed_barriers.at[slot],
                  accumulate=(ki > 0),
                  collective_axis="cluster",
              )
            lax.fori_loop(0, k_iters, _loop_body, None)
            plgpu.tcgen05_commit_arrive(
                mma_done_barrier.at[acc_slot],
                collective_axis="cluster",
            )

      @pl.when(wg_idx == 1)
      def _store_wg():
        # Ensure that copies from the previous mn step have completed.
        plgpu.wait_smem_to_gmem(0, wait_read_only=True)
        plgpu.barrier_wait(mma_done_barrier.at[acc_slot])
        out_m_index = m_index * 2 + lax.axis_index("cluster")
        out_m_slice = pl.ds(out_m_index * tile_m, tile_m)
        out_gmem_window = out_gmem.at[out_m_slice, n_slice]
        for ni in range(cluster_tile_n // config.epilogue_tile_n):
          acc_smem_ni = acc_smem.at[ni % 2]
          ni_slice = pl.ds(ni * config.epilogue_tile_n, config.epilogue_tile_n)
          # Make sure that previous copy is done before we overwrite.
          plgpu.wait_smem_to_gmem(1, wait_read_only=True)
          acc_smem_ni[...] = plgpu.async_load_tmem(mn_acc_tmem.at[:, ni_slice]).astype(dtype)
          plgpu.commit_smem()
          plgpu.copy_smem_to_gmem(acc_smem_ni, out_gmem_window.at[:, ni_slice])
        plgpu.wait_load_tmem()  # Load must complete before we signal.
        plgpu.barrier_arrive(store_done_barrier.at[acc_slot])
    plgpu.wait_smem_to_gmem(0, wait_read_only=True)

  num_sms = backend.get_default_device().core_count
  f = plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), dtype),
      grid=(num_sms // 2,),
      grid_names=("cluster_grid",),
      cluster=(2,),
      cluster_names=("cluster",),
      num_threads=2,
      thread_name="wg",
      scratch_shapes=dict(
          a_smem=plgpu.SMEM(
              (max_concurrent_steps, tile_m, tile_k), dtype, transforms=transforms
          ),
          b_smem=plgpu.SMEM(
              (max_concurrent_steps, tile_k, tile_n), dtype, transforms=transforms
          ),
          acc_tmem=plgpu.TMEM(
              (tile_m, 2 * cluster_tile_n), jnp.float32, collective=True
          ),
          acc_smem=plgpu.SMEM(
              (2, tile_m, config.epilogue_tile_n), dtype, transforms=transforms
          ),
          load_barriers=plgpu.Barrier(
              num_arrivals=2, num_barriers=max_concurrent_steps
          ),
          consumed_barriers=plgpu.Barrier(
              num_arrivals=1,
              num_barriers=max_concurrent_steps,
              orders_tensor_core=True,
          ),
          mma_done_barrier=plgpu.Barrier(
              num_arrivals=1, num_barriers=2, orders_tensor_core=True
          ),
          store_done_barrier=plgpu.ClusterBarrier(
              collective_axes=("cluster",),
              num_arrivals=1,
              num_barriers=2,
              orders_tensor_core=True,
          ),
      )
  )
  return f(a, b)