使用 Pallas 编写 Mosaic GPU 内核#

此页面是 Pallas:MGPU 后端最重要功能的参考。它不是一个教程,因此我们不期望每个人都从头到尾阅读它。尽管如此,还是值得通读一遍,以便熟悉在其他教程中可以找到的一些模式。

在以下示例中,我们将假设以下导入在作用域内

import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu

什么是 GPU?#

从技术上讲,NVIDIA GPU 架构如下:GPU 被划分为流式多处理器 (SMs)。这在 CUDA 编程模型中表现为:每个CUDA 线程块(或 CTA)都精确调度到一个 SM 上,但一个 SM 上可以同时调度多个块。

每个 SM 包含一块称为共享内存 (SMEM) 的快速内存,以及 4 个子分区,每个子分区包含一个warp 调度器和计算单元(ALU、TensorCore 等)。这在 CUDA 程序中也有所体现:每个warp(一个块中连续 32 个 CUDA 线程的组)以循环方式分配给其中一个子分区。与块类似,每个 warp 精确分配到一个子分区(它从不迁移),但多个 warp 可以分配到同一个 SM 子分区。在每个时钟周期,每个子分区的 warp 调度器都会尝试选择其常驻 warp 之一来执行下一条指令。

A diagram of one NVIDIA SM

进一步来说,最近的 CUDA 版本还概述了warpgroup的概念,它由 4 个连续的 warp 组成。了解硬件的结构,我们就能明白其来源:4 个连续的 warp 占据一个 SM 的 4 个象限,并允许我们发出利用整个 SM 的指令。

注意

GPU 可以从许多不同的角度来看待,在这里我们想专注于一个稍微简化的、以 TensorCore 为中心的模型。这应该有助于您应对编写涉及 TensorCore 的内核的复杂性,但请记住,实际情况更为复杂。

就我们的目的而言,TensorCore 操作已经变得如此庞大,以至于再遵循 CUDA 模型意义不大。因此,对我们来说,GPU 是单线程核心(SM)的集合,其中一个 Pallas:MGPU 线程对应一个 CUDA warpgroup。在这个模型中,您在内核中执行的每个操作都占用整个 CUDA warpgroup,并且其组成 warps 总是同步运行(除了硬件调度引起的抖动),并且从不通过控制流走不同的路径(除了我们稍后将讨论的 core_map 小例外)。这里一个值得注意的补充是,我们仍然允许您在同一个 SM 上共同调度多个 Pallas 级线程,以便它们可以通过共享内存进行协作和通信(我们通过将它们放在同一个 CUDA 块中来实现这一点)。

注意

从现在开始,每当我们提到“线程”时,我们指的是 Pallas 线程,而不是 CUDA 线程/lane。

注意

这与 Triton 推广的编程模型非常相似,但正如您将看到的,存在一些差异。Mosaic GPU 倾向于更底层,这通常意味着您需要投入更多工作,但它也让您拥有更多控制权。在我们看来,这两种方法都有其优点,我们鼓励您选择最适合您需求的后端!Pallas 支持并将继续支持 Triton 作为另一种 GPU 后端。

顺序执行与使用多个硬件单元#

与更复杂的 CPU 架构不同,GPU 只支持顺序执行。然而,这并不意味着在任何给定时间只有一条指令在运行!每个 SM 季度都有多个独立的函数单元:TensorCore、算术逻辑单元 (ALU)、加载/存储单元 (LSU)、特殊函数单元 (SFU)。如果第一条指令面向其中一个单元,并且后面跟着另一条指令(不使用第一条指令的结果),那么 warp 调度器可以在第一条指令完成之前发出第二条指令。这通常被称为指令级并行 (ILP),是现代 TensorCore 内核中的一个常见主题:TensorCore 操作如此庞大,需要大量周期才能完成,因此不尝试同时使用其他单元是一种浪费。

为了进一步扩展这一点,我们可以通过允许多个 Pallas 线程并发运行来利用这种硬件单元级并行性。如果其中一个线程主要占用 ALU,而另一个主要发出与 TensorCore 相关的指令,我们可以利用 warp 调度器中内置的高效上下文切换来保持两个单元的繁忙。这是 FlashAttention 3CUTLASS ping-pong 矩阵乘法内核等算法背后的核心思想之一。

有关 warp 调度和指令发出工作原理的更多信息,我们建议阅读 《分析现代 NVIDIA GPU 核心》

内存空间#

GPU 具有几种不同的内存空间,它们可以按容量从大到小、速度从慢到快(总带宽和单次访问延迟)进行完全排序。

A diagram of memory spaces of an NVIDIA GPU

最大的内存空间是 plgpu.GMEM,即全局内存。在最近的数据中心级 GPU 中,这个内存空间通常以几十甚至上百 GB 来衡量,但它也是最慢的。

下一个用于 L2 缓存的内存空间,在某种程度上也是全局的,因为它由整个 GPU 共享,但它的使用只能通过缓存提示间接影响。因此,无法手动将值放入其中,所以 Pallas:MGPU 中不暴露此内存空间。虽然只有大约 100MB 大小,但这种内存比 GMEM 具有更高的带宽,因此在编写高性能内核时仍然经常建议利用它。

接下来是共享内存,或 plgpu.SMEM。此内存直接位于每个 SM 内部,因此它是分区化的。除非使用块集群(参见下面的集群部分),否则每个块只允许访问其自己的 SMEM 分配。

最后,最低级别的内存空间是寄存器内存。Pallas 内核中的每个值(即 JAX 数组)都将位于此处。如果编译器耗尽寄存器来存储这些数组,它将插入溢出,这意味着它会定期将值存储到内存并重新加载。这些溢出通常会导致其他显著的性能下降,因此我们建议避免它们。有关溢出的警告消息可以在内核编译期间的 ptxas 消息中清楚地看到。要使它们可见,请在您的环境中运行 MOSAIC_GPU_DUMP_PTXAS=1

Blackwell GPU 一代新增了一个名为张量内存plgpu.TMEM 的内存空间。TMEM 与寄存器内存非常相似,只是它由您显式分配和管理。它用于存储 MMA 累加器、操作数元数据(用于稀疏性或缩放),以及可选的左 MMA 操作数。有关 TMEM 的更多信息,请参见 Blackwell MMA 部分。

在特定内存空间中请求/分配内存#

内核输入或输出默认放置在 SMEM 中。如果想将它们作为 GMEM 引用访问,请将 memory_space=plgpu.GMEM 添加到它们的 BlockSpec 中。如果希望内核在 GMEM 中使用整个输入或输出数组进行调用,只需指定 BlockSpec(memory_space=plgpu.GMEM) 即可。

SMEMTMEM 可以在 pl.pallas_callscratch_shapes 参数中显式分配,或使用 pl.run_scoped。要分配一个引用,只需使用请求的形状和 dtype 调用内存空间对象即可。例如:plgpu.SMEM((128, 128), jnp.float16) 将在共享内存中分配一个 128x128 的 float16 元素数组。

利用 L2 缓存#

虽然 L2 缓存无法手动管理,但其相对于全局内存显著更高的带宽使其值得考虑。利用它的最简单方法是重新排列并行网格维度,以便在相似时间段内调度的调用也能访问相同的输入数据。

虽然 CUDA 编程模型不保证块分配给 SM 的顺序,但在最近的几代中,启发式方法似乎只是简单地以列主序遍历 (x, y, z) CUDA 网格(即 x 是变化最快的维度,z 是最慢的)。类似地,Pallas:MGPU 不保证用户指定的网格如何映射到 CUDA 网格(Pallas 支持任意维度的网格,而不仅仅是 3D)。但是,您可以假定迭代将以行主序发生。也就是说,如果网格的维度为 (a, b),那么 b 将是变化最快的维度,而 a 将是较慢的维度。

为了提供一个实际例子,考虑一个普通的矩阵乘法内核。通常,会使用两个并行网格维度 (m, n),对应于对两个非收缩维度进行分块。如果我们使用这种简单的方案,在 Pallas:MGPU 中,所有 ID 为 (0, ...) 的程序都将在任何 ID 为 (1, ...) 的块之前调度。并且,集体地,m=0 的程序必须读取所有 B 操作数!如果 nk 维度非常大,那么我们几乎不可能从 (1, ...) 程序对 (0, ...) 程序所做的访问中获得缓存命中。为简单起见,假设我们一次只能运行 16 个块,我们从第一个调度波中看到这种访问模式

Your browser does not support SVGs or scripting is disabled. This would be an image showing the access pattern of first 16 blocks without grid tiling.

然而,如果我们简单地将网格重新排列为 (m // mt, n, mt)(然后在内核中将 pl.program_id(0) 替换为 pl.program_id(0) * mt + pl.program_id(2)),那么很容易看出,沿着两个维度的一组程序将并发调度(而不是调度单行)。这极大地增加了加载相似数据切片的并发程序数量,通常会显著提高 L2 利用率,从而提高内核的整体性能(如果它是内存绑定的)。继续我们的例子,使用 16 个块和 mt=4,我们得到以下访问模式

Your browser does not support SVGs or scripting is disabled. This would be an image showing the access pattern of first 16 blocks with grid tiling.

请注意,即使活动块的数量没有改变,它们访问的数据的总占用空间也减少了一半!我们现在有更大的机会获得 L2 命中。

数组布局和内存引用转换#

在 Pallas 中,您处理的数据结构(数组和引用)具有**逻辑形状**(例如,一个 128x128 矩阵)。此逻辑形状必须映射到**物理表示**(数据在 GPU 内存中的实际表示方式)。具体的映射取决于数据驻留在何处

  1. 数组布局: 数组存储在寄存器内存中,我们将此映射称为布局。布局定义了数组元素如何分布在构成 Pallas 线程的 CUDA lane 可用的寄存器中。

  2. 内存引用转换: 对于指向 SMEM 的可变引用,此映射称为转换。转换描述了逻辑数据结构在该内存块内的排列方式。

这些概念对于性能至关重要,特别是在与 TensorCore 等专用硬件单元交互或优化内存访问模式时。

注意

我们正在开发一种模式,将全自动处理布局和转换的分配(尽管提供提示和更多控制的方式)。下面列出的 API 可能会继续运行,但将变为可选。

内存引用转换#

转换在内存引用首次分配时应用。对这些引用进行操作的 Pallas 原语将自动考虑其关联的转换。

def body(..., scratch_ref):
  # Asynchronous copy will reformat the GMEM data to match the SMEM transforms
  plgpu.copy_gmem_to_smem(..., scratch_ref, barrier)
  barrier.wait()
  plgpu.wgmma(..., scratch_ref)  # wgmma only accepts properly transformed refs
  ...

引用有两种分配方式,每种方式都有选择所需转换的方法

1. 使用 plgpu.BlockSpec

transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128))
f = pl.pallas_call(
  in_specs=plgpu.BlockSpec(in_block_shape, in_index_map, transforms=transforms),
  out_specs=plgpu.BlockSpec(out_block_shape, out_index_map, transforms=transforms),
  ...
)

请注意,与 plgpu.BlockSpec 不同,pl.BlockSpec 允许指定转换。

2. 在已分配的 SMEM 上指定 transforms 参数

transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128))
f = pl.pallas_call(
  scratch_shapes=plgpu.SMEM((128, 128), jnp.float16, transforms=transforms),
  ...
)

可用的转换有

  • plgpu.TileTransform(tile_shape) 将数据组织成形状为 tile_shape 的连续、不重叠的瓦片。一个瓦片的数据总是先完全线性化(行主序),然后另一个瓦片才开始(瓦片也以行主序遍历)。例如,将 TileTransform((8, 64)) 应用于 (128, 128) 引用意味着对应于逻辑切片 [0:8, 0:64] 的数据将首先存储(行主序),然后是 [0:8, 64:128], [8:16, 0:64], [8:16, 64:128],依此类推。另一种实现方式是获取输入数组 x 并以行主序遍历 x.reshape(128 // 8, 128 // 64, 8, 64).transpose(0, 2, 1, 3)

  • plgpu.SwizzleTransform(swizzle_in_bytes),它按照 PTX 文档CUDA 文档 中描述的方式转换数据。交错(Swizzling)很有用,因为它允许在寄存器和共享内存之间传输 MMA 相关布局的数据而不会发生 bank 冲突。交错后内存确切的样子不是那么重要,因为所有原语都会自动考虑它。请注意,交错量以字节为单位指定(仅支持 128、64、32 和 16),并且通常伴随着一个 TileTransform(它使用其形状中的元素!)。

  • plgpu.TransposeTransform(permutation),它在数组线性化之前置换数组的维度。这主要有用,因为它允许您在 GMEM-SMEM 复制期间更改布局(请记住,硬件不支持更改最小/最后一个维度)。

数组布局#

到目前为止,我们已经为您定义了一些有用的布局

  • plgpu.Layout.WGMMA,这是 Hopper 代 TensorCore 期望 MMA 累加器或 16 位输入操作数在寄存器中拥有的布局。

  • plgpu.Layout.WGMMA_ROW,这是在上述布局沿行缩减后获得的布局。重新广播行是免费的,并将生成一个具有 WGMMA 布局的值。

  • plgpu.Layout.WGMMA_COL,它是上述布局的模拟,只是沿列而不是行缩减。

  • plgpu.Layout.WG_STRIDED,其中值在组成 Pallas 线程的 128 个 CUDA lane 之间平均分配。连续的元素(向量化后)以循环方式分配给 lane。当不需要与 TensorCore 交互时,这种方法非常简单有效。

  • plgpu.Layout.WG_SPLAT,表示该值是常量。每个 CUDA lane 将持有一个包含该值的寄存器。您通常不需要与此布局交互,因为它在创建常量值时隐式使用,并且总是可以隐式转换为其他布局。

目前,在默认操作模式下,数组布局传播仅向前进行,并且对协调布局冲突的隐式支持很少:只有 splat 布局可以隐式转换为任何其他布局。例如,如果您尝试添加两个具有不同布局的数组,降级将报错并失败。在布局之间进行转换的功能非常有限,我们通常建议将值存储到 SMEM 中,然后以目标布局读回。

MMA (TensorCore)#

在本节中,我们重点介绍 Pallas:MGPU 内核如何利用 TensorCore 单元。TensorCore 的编程接口在不同的 NVIDIA GPU 代之间差异很大,这就是 Pallas:MGPU 中最低级别接口也不同的原因。

每个 MMA 操作都与三个操作数相关联

  • 形状为 (M, N) 的累加器 D

  • 形状为 (M, K) 的左输入 A

  • 形状为 (K, N) 的右输入 B。所有操作数必须具有相同的元素类型。

每次使用 MMA 都涉及几个步骤

  1. 分配累加器的空间(MMA 隐式执行 D += A @ B

  2. 准备 A 和 B 操作数

  3. 发出操作

  4. 等待操作完成

  5. 读出结果

步骤 2-4 通常在收缩维度 (K) 上循环执行。

A 和 B 操作数的内存空间#

AB 操作数通常最好通过 SMEM 传入,在那里可以使用 plgpu.copy_gmem_to_smem 方便地加载。为了使这些操作数与 MMA 操作兼容,需要在分配时指定适当的 tiling 和 swizzling 转换。对于目前所有支持的代,TensorCore 要求数据以行主序 2D 瓦片形式布局,形状为 (8, swizzle_elems),其中 swizzle_elems 是通过将 swizzle 除以元素类型字节宽度得出的。目前支持的 swizzles 有:128、64 和 32。更大的 swizzles 更受青睐,因为它们能提高 GMEM 到 SMEM 复制的性能。

def mma_transforms(shape_dtype: jax.ShapeDtypeStruct):
  assert len(shape_dtype.shape) == 2
  if shape_dtype.shape[0] % 8:
    raise ValueError("Number of rows must be divisible by 8")
  for swizzle_bytes in (128, 64, 32):
    swizzle_elems = swizzle_bytes // shape_dtype.dtype.itemsize
    if shape_dtype.shape[-1] % swizzle_elems == 0:
      return (plgpu.TilingTransform((8, swizzle_elems)),
              plgpu.SwizzleTransform(swizzle_bytes))
  raise ValueError("Failed to find transforms for the specified window type")

如果操作数需要转换,A 操作数可以通过不同的内存空间传入(取决于架构,见下文)。B 操作数必须位于 SMEM 中。

转置操作数#

当对 16 位操作数执行 MMA 时,TensorCore 可以自动转置输入数据。例如,A 引用允许的形状为 (K, M),但在将其传递给 mma 函数之前,必须对其进行转置。例如

assert acc_ref.shape == (M, N) and a_ref.shape == (K, M) and b_ref.shape == (K, N)
a_ref_t = plgpu.transpose_ref(a_ref, (1, 0))
assert a_ref_t.shape == (M, K)  # The shape expected by plgpu.wgmma
plgpu.wgmma(acc, a_ref_t, b_ref)

在这种情况下,B 引用也允许进行类似的操作。

Hopper (wgmma)#

在本节中,我们介绍使用 Hopper 代 TensorCore 的基础知识,它在 PTX 中以 wgmma.mma_async 指令 的形式公开。

分配累加器#

在 Hopper 硬件架构中,累加器在寄存器中分配,但在 Pallas 中,它被建模为可变引用,因为每个 MMA 操作都原地累积。有两种方式可以分配累加器。

要创建一个零初始化的累加器,可以使用 pl.run_scoped 并指定 plgpu.ACC((m, n), dtype) 类型。

def compute(acc_ref):
  ...
  return acc_ref[...]
output = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32))

对累加器引用进行解引用,如 compute 函数末尾所示,将隐式等待所有未完成的 WGMMA 操作。

如果您想用现有数组初始化它,可以使用 pl.run_stateplgpu.ACC.init(init_array)

def compute(acc_ref):
  ...
  return # pl.run_state only returns the final value of the accumulator
output = pl.run_state(compute)(plgpu.ACC.init(init_array))

如果 pl.run_state 具有累加器操作数,它会在返回最终值之前隐式等待所有未完成的 WGMMA 操作。

准备 A 和 B 操作数#

如上所述,我们建议通过共享内存传入 AB。在这种情况下,必须指定正确的 tiling 和 swizzling 转换。

plgpu.wgmma 还允许通过寄存器传入 A(即不是 SMEM 引用,而是作为常规 JAX 数组)。然而,这种模式存在许多显著的缺点,并且很难确保足够的同步以使其安全。

待办:解释在何种条件下可以这样做。

发出操作#

支持的 MMA 形状如下

  • M 可被 64 整除

  • N 可被 8 整除且小于 256

  • Kswizzle 除以元素类型字节宽度的倍数

目前支持的数据类型有:jnp.float32jnp.bfloat16jnp.float16。累加器 D 必须是 jnp.float32,但 jnp.float16 输入除外,在这种情况下也允许是 jnp.float16

等待操作完成#

每个 plgpu.wgmma 调用都会隐式地与所有先前的 plgpu.wgmma 调用同步,以便一旦控制权从它返回,我们就能保证除了最后发出的 WGMMA 之外,没有其他 WGMMA 仍在运行。因此,任何先前发出的 WGMMA 指令读取的 SMEM 区域都可以被重用。这对于 WGMMA 与异步内存复制的流水线化尤其重要

buffers = 3  # In reality you might want even more
assert a_smem.shape == (buffers, m, k)
assert b_smem.shape == (buffers, k, n)
assert acc_ref.shape == (m, n)

def fetch_a_b(ki, slot):
  a_slice = ... # Replace with the right M/K slice
  b_slice = ... # Replace with the right K/N slice
  plgpu.copy_gmem_to_smem(a_gmem.at[a_slice], a_smem.at[slot], a_loaded.at[slot])
  plgpu.copy_gmem_to_smem(b_gmem.at[b_slice], b_smem.at[slot], b_loaded.at[slot])

def loop_body(i, _):
  slot = jax.lax.rem(i, buffers)
  plgpu.barrier_wait(a_loaded.at[slot])
  plgpu.barrier_wait(b_loaded.at[slot])
  plgpu.wgmma(acc_ref, a_smem.at[slot], b_smem.at[slot])
  # We know that only the last issued WGMMA is running, so we can issue a async load in
  # into the other buffer
  load_i = i + buffers - 1
  load_slot = jax.lax.rem(load_i, buffers)
  @pl.when(jnp.logical_and(load_i >= buffers, load_i < num_steps))
  def _do_fetch():
    fetch_a_b(load_i, slot)
for slot in range(buffers):
  fetch_a_b(slot, slot)
jax.lax.fori_loop(0, num_steps, loop_body, None)

Blackwell (tcgen05)#

Blackwell 一代显著重新设计了 TensorCore 子单元。它现在与常规 warp 调度器显著更独立,并且不再使用甚至不支持将寄存器作为其操作数。取而代之的是,引入了一个名为张量内存 (TMEM) 的新内存空间。此外,成对的 SMs 中的 TensorCore 现在可以汇集它们的资源,计算跨越两个 SMs 的更大 MMA 操作。我们称之为“集体 MMA 操作”

分配累加器 / 使用 TMEM#

TMEM 引用可以与所有其他引用以相同的方式分配——使用 pl.run_scoped

@functools.partial(pl.run_scoped, tmem_ref=plgpu.TMEM((128, 128), jnp.float32))
def barrier_scope(tmem_ref):
  ...

并非所有形状都可以在 TMEM 中分配。目前只支持 2D 引用,并且行数(第一个维度的大小)必须为 128 或 64。

此外,如果数据类型的位宽小于 32 位,则需要声明分配是否应该被打包(例如,将两个 16 位元素放入 TMEM 中的一个 32 位单元格)或不打包(每个元素填充到 32 位)。MMA 累加器(fp32 或 fp16)从不打包,但如果左操作数通过 TMEM 传入,则必须始终打包。

@functools.partial(pl.run_scoped,
                   acc_ref=plgpu.TMEM((128, 128), jnp.float16, packed=False),
                   lhs_ref=plgpu.TMEM((128, 128), jnp.float16, packed=True))
def barrier_scope(acc_ref, lhs_ref):
  plgpu.tcgen05_mma(acc_ref, lhs_ref, rhs_smem_ref, ...)
  ...

TMEM 的另一个有趣的复杂性是,所有对其进行的操作都是异步的。因此,通常用于 SMEM 的 Python 下标语法进行的读写操作不适用于 TMEM。

加载#

加载可以使用 plgpu.async_load_tmem 执行,并使用 plgpu.wait_load_tmem 等待。

smem_ref[...] = plgpu.async_load_tmem(tmem_ref)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(smem_ref, gmem_ref)
plgpu.wait_smem_to_gmem(0)
plgpu.wait_load_tmem()  # Wait for the read to fully complete before we overwrite tmem_ref again.

加载语义相当令人困惑,因为从加载返回的数组可以在没有任何额外同步的情况下安全使用。然而,如果读取的 TMEM 区域再次被覆盖(例如通过存储或 MMA 操作),发出加载的线程必须首先调用 plgpu.wait_load_tmem() 以确保程序不会出现竞争条件。

注意

一种解决这种看似违反因果关系的行为(数据在完全从 TMEM 读取之前到达寄存器)的方法是,将其视为 PTX 编译器中限制和便利特性相互作用的结果。我们不知道这是否属实,但至少它说得通。

便利功能是,编译器可以可靠地跟踪 TMEM 加载产生的寄存器使用情况,并将插入最少数量的必要延迟,以确保数据在使用前从 TMEM 到达。读取操作被展开成许多指令,这意味着在开始使用第一个加载填充的寄存器之前,不必等待所有这些指令。这就是为什么我们不需要保护结果的使用。

限制在于编译器无法可靠地对 TMEM 加载和存储执行别名分析,这就是为什么任何未被显式等待分隔的加载和存储都被认为是安全的并发执行。否则,将不必要地降低真正不相关的加载和存储的性能。这就是为什么我们需要在再次使用 TMEM 之前显式等待。

存储#

相反,存储使用 plgpu.async_store_tmem 执行,并使用 plgpu.commit_tmem 等待。

plgpu.async_store_tmem(tmem_ref, smem_ref[...])
plgpu.commit_tmem()
smem_ref2[...] = plgpu.async_load_tmem(tmem_ref)  # Safe to read from tmem_ref now

准备 A 和 B 操作数#

我们建议通过共享内存传入 AB。在这种情况下,必须指定正确的 tiling 和 swizzling 转换A 操作数也可以作为 TMEM 引用传入,但必须打包。

发出操作#

支持的**非集体** MMA 形状如下

  • M 为 64 或 128

  • N 可被 8 整除且小于 512

  • K8 * swizzle 除以元素类型位宽的倍数

支持的**集体** MMA 形状如下

  • M 为 128 或 256(每个块的一半)

  • N 可被 8 整除且小于 256(每个块中小于 128)

  • K8 * swizzle 除以元素类型位宽的倍数

目前支持的浮点数据类型有:jnp.bfloat16jnp.float16jnp.float8_e5m2jnp.float8_e4m3fn。累加器可以是 jnp.float32jnp.float16,但 jnp.bfloat16 时必须是 jnp.float32

目前唯一支持的整数数据类型是 jnp.int8,使用 jnp.int32 累加器。

注意

根据我们的基准测试,以下是一些性能经验法则

  • 非集体 MMA 应始终使用 M=128 和 N >= 128。

    • M=64 会导致显著的性能下降。

    • N=64 会导致明显的性能下降,但不如 M=64 显著。

  • 集体 MMA 总是相当快,但并不比非集体 MMA 快。

    • 集体 MMA 的最大好处不是更高的 TensorCore 吞吐量,而是 SM 之间共享数据的能力,从而可以提高内核的算术强度。

  • 交错和转置似乎不会显著影响性能。

等待操作完成#

等待 plgpu.tcgen05_mma 调用的结果需要使用 Barrier。我们建议阅读 Barriers 的参考文档,特别是其Blackwell 相关子部分以获取更多信息。

如果 barrier 直接传递给 plgpu.tcgen05_mma,则在该 barrier 上完成等待将指示最终累加器已写入 TMEM。例如

@functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True))
def barrier_scope(barrier_ref):
  plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, barrier_ref, accumulate=False)
  plgpu.barrier_wait(barrier_ref)
  # We can read the result now.
  result = plgpu.async_load_tmem(acc_tmem)
  ...

如果没有 barrier 传递给 plgpu.tcgen05_mma,其完成将在调用 plgpu.tcgen05_commit 后才被跟踪。

@functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True))
def barrier_scope(barrier_ref):
  plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, accumulate=False)
  plgpu.tcgen05_mma(acc_tmem, lhs_ref2, rhs_ref2)
  plgpu.tcgen05_commit(barrier_ref)
  plgpu.barrier_wait(barrier_ref)
  # We can read the result now. Both MMAs have completed.
  result = plgpu.async_load_tmem(acc_tmem)
  ...

集体 MMA#

Blackwell 一代获得了一种执行 MMA 操作的新方式,即集群中 2 个 SM 的 TensorCore 协作执行单个 MMA 操作。每个 SM 的 B 操作数与其他 SM 共享。DA 操作数对每个 SM 来说是本地的,不共享。

A diagram showing the partitioning of operands in a collective MMA

这意味着,要执行形状为 M、N、K 的集体 MMA,两个 Pallas 线程中每个线程的操作数大小应为:(M // 2, K) 用于 A(K, N // 2) 用于 B(M // 2, N) 用于 D(累加器)。将这两个累加器堆叠起来将恢复执行 MxNxK 矩阵乘法的结果。

为了简化 B 操作数的加载,plgpu.copy_gmem_to_smem 可以与 collective_axespartitioned_axis 一起使用,以指示沿着集体轴的两个 Pallas 线程应该加载相同的切片,但每个线程只获得其一半。与仅使用 collective_axes 的复制不同,它不利用 TMA 组播(因为每个线程加载的是不同的数据切片),但它可以稍微简化索引逻辑。

plgpu.copy_gmem_to_smem(
    b_gmem,  # [K, N]
    b_smem,  # [K, N // 2]
    b_tma_barrier,
    collective_axes="x",
    partitioned_axis=1,
)

使用 core_map#

pl.pallas_call 适用于单个 Pallas 线程可以执行整个 CUDA 块的全部计算的内核。pl.core_map 函数放宽了这一限制,允许在单个块内使用多个线程(例如用于 warp 专用化)或在块集群中的多个块之间使用多个线程(例如利用组播 TMA)。

pl.pallas_call 替换为 pl.core_mapplgpu.kernel#

让我们从一个简单的 Pallas 内核开始,它用于递增数组

@functools.partial(
  pl.pallas_call,
  grid=(2,),
  in_specs=[pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,))],
  out_specs=pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,))
  out_shape=jax.ShapeDtypeStruct((256,), jnp.float32), # Total output shape
)
def run_kernel(x_ref, y_ref):
  # x_ref and y_ref are in SMEM!
  y_ref[...] = x_ref[...] + 1

x = jnp.arange(256, jnp.float32)
y = run_kernel(x)
np.testing.assert_array_equal(y, x + 1)

我们可以使用 pl.core_map 编写类似的内核。一个很大的区别是,与 pl.pallas_call 不同,不会自动插入 GMEM<->SMEM 副本。如果您需要它们,可以自己插入或使用 plgpu.emit_pipeline 助手。我们建议您查阅软件流水线指南

@pl.run_state
def run_kernel(refs):
  x_ref, y_ref = refs
  # Here, we're not in the kernel yet! pl.run_state simply changes the JAX
  # immutable arrays into mutable GMEM (not SMEM!) references.

  # Define the mesh: 2 CUDA blocks over 1 axis called "x"
  mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))

  @pl.core_map(mesh)  # core_map executes the body
  def kernel_body():
    # Once we enter the pl.core_map scope, we are in the body of the kernel.
    block_slice = pl.ds(lax.axis_index("x") * 128, 128)
    y_ref[block_slice] = x_ref[block_slice] + 1

x = jnp.arange(256, jnp.float32)
y_init = jnp.zeros_like(x)
_, y = run_kernel(x, y_init)
np.testing.assert_array_equal(y, x + 1)

虽然 pl.core_map 是一个强大的 API,但它也相当底层,并且几乎总是用于 pl.run_state(将 JAX 数组转换为引用)或 pl.run_scoped(为临时引用分配空间)之下。因此,我们还提供了一个方便的 API plgpu.kernel

mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))

@functools.partial(
    plgpu.kernel,
    out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
    mesh=mesh
)
def run_kernel(x_ref, y_ref):
  # x_ref and y_ref are in GMEM!
  block_slice = pl.ds(lax.axis_index("x") * 128, 128)
  y_ref[block_slice] = x_ref[block_slice] + 1

x = jnp.arange(256, jnp.float32)
y = run_kernel(x)  # No need to preallocate outputs as in pl.core_map.
np.testing.assert_array_equal(y, x + 1)

注意

pl.core_map 一起使用的 plgpu.Mesh 定义了单个 GPU 内的计算拓扑,指定了工作如何在 CUDA 块(grid)、块内的 Pallas 线程(num_threads)以及可能在 CUDA 块集群(cluster)中分布。这类似于 jax.sharding.Mesh 如何在 JAX 中定义跨多个设备的分布式计算拓扑。两者都涉及 SPMD 程序在定义的拓扑上执行。此外,您可以在 Pallas 线程和集群上运行“集体操作”(例如,使用 plgpu.ClusterBarrier 或集体异步复制),这类似于 JAX 集体操作(psumall_gather 等)如何在 JAX Mesh 中的设备之间操作。两者也都使用命名轴,并且可以使用 lax.axis_index(axis_name) 来获取线程或块的坐标。

每个 CUDA 块使用多个 Pallas 线程#

下面,您可以找到一个示例,其中单个块内的两个 Pallas 线程通过屏障同步,甚至通过 SMEM 交换数据。

mesh = plgpu.Mesh(num_threads=2, thread_name="pallas_thread")
x = jnp.arange(128, jnp.float32)

@functools.partial(
  plgpu.kernel, out_shape=x, mesh=mesh,
  scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier()]
)
def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
  thread_id = jax.lax.axis_index("pallas_thread")

  @pl.when(thread_id == 0)
  def producer_thread():
    smem_ref[...] = x_ref[...] + 1
    plgpu.barrier_arrive(barrier_ref)  # Signal the consumer thread

  @pl.when(thread_id == 1)
  def consumer_thread():
    plgpu.barrier_wait(barrier_ref)  # Wait for the producer thread
    out_ref[...] = smem_ref[...] + 1

y = run_kernel(x)  # There's no need to preallocate the input anymore.
np.testing.assert_array_equal(y, x + 2)

虽然这个例子很简单,但您可以在同步部分找到一个更复杂的例子。

多个线程常用于高性能内核,例如最新的 Flash Attention 变体或乒乓矩阵乘法。在这两种情况下,程序中有 2 个计算线程以交替方式使用 SM 的 ALU 和 TensorCore,以确保没有执行冲突。

另一种常见技术是分配一个 Pallas 线程,并将其完全用于调度其他线程使用的数据的异步复制。虽然从头开始实现此方案可能很复杂,但我们提供了一个方便的辅助 API:plgpu.emit_pipeline_warp_specialized

使用 CUDA 块集群#

下面的内核启动一个包含 2 个 CUDA 块的单个集群,并使用 TMA 组播功能共同将 GMEM 复制到两个块的 SMEM 中。参与集体复制的所有块必须调度完全相同的复制,程序才能有效。

mesh = plgpu.Mesh(cluster=(2,), cluster_names=("cluster",))

@functools.partial(
  plgpu.kernel,
  out_shape=jax.ShapeDtypeStruct((2, 128), jnp.float32),
  mesh=mesh,
  scratch_shapes=[plgpu.SMEM((128,), jnp.float32), plgpu.Barrier()]
)
def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
  # Specifying collective_axes will enable TMA multicast automatically.
  plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref, collective_axes="cluster")
  plgpu.barrier_wait(barrier_ref)
  plgpu.copy_smem_to_gmem(smem_ref, o_ref.at[lax.axis_index("cluster")])
  plgpu.wait_smem_to_gmem(0)

x = jnp.arange(128, jnp.float32)
y = run_kernel(x)
# Each block gets the same data and writes it out.
np.testing.assert_array_equal(y, jnp.stack([x, x], axis=0))

pl.run_scoped 中的集体分配#

当使用具有多个 Pallas 线程(即 plgpu.Mesh 中的 num_threads > 1)的 pl.core_map 时,通过 pl.run_scoped 进行的分配(针对 SMEM 或 Barriers)必须由所有线程集体执行。这通过向 run_scoped 指定 collective_axis 参数来指示,这有两个效果

  1. 它保证所有线程都将调用相同的分配,并且

  2. 所有线程都将收到完全相同的分配。

如果未指定 collective_axes 或其不包含 Pallas 线程轴,则每个线程将获得其自己的 scratch 变量的私有副本。这通常是不希望的,并且目前不支持。

同步结构和原语#

在本节中,我们将介绍用于线程间同步以及一些异步操作的最重要的函数和数据结构。

commit_smem#

对引用的常规读/写操作保证产生与顺序程序顺序一致的值。例如,在以下程序中,value 保证等于 value2

ref[...] = value
value2 = ref[...]

然而,此保证不适用于异步原语,例如异步复制或 MMA 操作。为了使 SMEM 写入对这些原语可见,您需要使用 plgpu.commit_smem() 函数显式与它们同步。

例如

smem_ref[...] = value
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(smem_ref, ...)

smem_ref[...] = value
plgpu.commit_smem()
plgpu.wgmma(smem_ref, ...)

未能调用此函数很可能导致微妙的数据竞争,因为这些异步硬件单元会从 SMEM 读取过时数据。不幸的是,此函数相对昂贵,这就是为什么我们依靠您(用户)在必要的最少位置插入它的原因。

Barrier#

这本质上是 PTX mbarrier 类型数组的薄包装,并作为引用传入。所有涉及 barrier 的函数都期望只获得一个 barrier 参数,因此如果引用包含多个,您必须使用 barriers.at[index] 显式提取其中一个。Barriers 始终在 SMEM 中分配,因此开销相对较低。每个 barrier 可以配置为在固定数量的“到达”(默认为 1)后完成。

要阻塞一个线程直到 barrier 完成,请使用以下函数

plgpu.barrier_wait(barrier)

警告

至关重要的是,要确保同步方案使得在两次 barrier 完成之间没有调用 plgpu.barrier_wait 的情况下,不可能发生两次 barrier 完成。例如,如果您使用 Barriers 同步两个生产者/消费者线程,您需要双向执行 barrier 同步以引入“背压”,这将阻止一个线程在另一个线程有机会等待之前到达两次。未能满足这一点将损坏数据结构并可能导致意外故障(包括 CUDA 运行时错误)。请参阅下面带有两个线程的有效程序示例。

警告

另一个关键限制是,在 barrier 的整个生命周期中,barrier 完成的数量必须等于 barrier 等待的数量。不允许在 barrier 存在未等待的完成时结束其作用域分配。否则,当编译器重用它时,将其保留在此状态可能会导致下游问题。

警告

最后,至关重要的是,要确保每个曾经等待 Barrier 的线程都参与到其所有 wait 操作中。不允许例如一个线程等待 barrier 的每隔一次完成,而另一个线程等待所有其他完成。这样做将导致死锁。总结一下:当 Barrier 在某个线程中用于等待时,它必须观察该 barrier 的每一个完成(通过等待它)。

请注意,Barrier 可以接收来自任何来源的到达,没有任何限制。

有三个操作可以完成一个 barrier

异步 GMEM 到 SMEM 复制#

当 TMA 引擎正在执行异步 GMEM 到 SMEM 复制时,它会将进度更新发布到传递给 plgpu.copy_gmem_to_smem 的 barrier。一旦复制完成,该 barrier 也会完成一次到达。

显式到达(跨线程同步)#

任何线程都可以使用以下函数显式地到达一个 barrier

plgpu.barrier_arrive(barrier)

这在同步处于生产者/消费者角色的两个线程时特别有用。在这种情况下,我们建议分配两个 Barrier 数组,其大小等于用于在两个线程之间传递数据的“队列”的大小。例如,假设一个线程继续将数组的瓦片写入 SMEM,而另一个线程读取它们。我们对 SMEM 区域进行三重缓冲,以允许两个线程之间有更多的异步性

tid = jax.lax.axis_index("thread")
assert queue.shape == (buffering, *item_shape)
assert produced.shape == consumed.shape == (buffering,)

def thread0_body(i, _):
  slot = jax.lax.rem(i, buffering)
  @pl.when(i >= buffering)
  def _await_consumed():
    plgpu.barrier_wait(consumed.at[slot])  # Wait for consumption of the value before overwriting it
  # Option 1: Compute the next value
  queue[slot] = produce()
  plgpu.barrier_arrive(produced.at[slot])  # Signal the value is ready
  # Option 2: Produce the value through async_copy
  # plgpu.copy_gmem_to_smem(..., queue.at[slot], barrier=produced.at[slot])
pl.when(tid == 0)(lambda: jax.lax.fori_loop(0, steps, thread0_body, None))

def thread1_body(i, _):
  slot = jax.lax.rem(i, buffering)
  plgpu.barrier_wait(produced.at[slot])  # Wait for the value to be ready
  consume(queue[slot])  # Load and compute
  plgpu.barrier_arrive(consumed.at[slot])  # Signal that the value is consumed
pl.when(tid == 1)(lambda: jax.lax.fori_loop(0, steps, thread1_body, None))

等待 tcgen05 TensorCore 指令#

在我们开始之前,一个重要的警告

警告

在 Blackwell 代 GPU 上,Barrier 操作默认对 TensorCore 操作具有宽松的语义。这意味着默认情况下,任何与 TensorCore 相关的操作(包括 TMEM 操作)都可以被编译器移动到屏障信号之后。类似地,任何与 TensorCore 相关的操作都可以被移动到屏障等待之前

如果您打算使用 Barriers 来向其他线程指示 TensorCore 操作已完成,请使用 orders_tensor_core=True 分配 barrier。此参数将插入必要的指令,以防止上述有问题的重新排序。

与旧版 GPU 不同,观察 Blackwell 代 TensorCore 指令完成的唯一方法是将 Barrier 引用传递给 plgpu.tcgen05_mma 函数。一旦 MMA 完成,TensorCore 将到达 barrier。

请注意,Barrier 的这种使用方式要求它们在创建时指定 orders_tensor_core=True,因为它们用于与 TensorCore 操作同步。

@functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True))
def barrier_scope(barrier_ref):
  plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, barrier_ref, accumulate=False)
  plgpu.barrier_wait(barrier_ref)
  # We can read the result now
  result = plgpu.async_load_tmem(acc_tmem)
  ...

ClusterBarrier

待办

Semaphore

待办

异步复制

待办

内联 Mosaic GPU

待办

编译器参数

待办