使用 Pallas 编写 Mosaic GPU 内核#

本页是 Pallas:MGPU 后端最重要的特性的参考。这不是教程,因此我们不期望每个人都从头到尾阅读它。不过,还是值得浏览一下,以便熟悉其他教程中可能找到的一些模式。

在接下来的示例中,我们将假定以下导入已在作用域内

import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu

什么是 GPU?#

从技术上讲,NVIDIA GPU 架构如下:GPU 被划分为流式多处理器(SM)。这在 CUDA 编程模型中的体现是,每个CUDA 线程块(或 CTA)被调度到恰好一个 SM 上,但多个块可以同时调度到同一个 SM 上。

每个 SM 包含一个快速内存块,称为共享内存(SMEM),以及 4 个子分区,每个子分区包含一个warp 调度器和计算单元(ALU、TensorCore 等)。这也在 CUDA 程序中有所体现:每个warp(块中连续的 32 个 CUDA 线程的组)以轮询方式分配到这些子分区之一。与块类似,每个 warp 被分配到恰好一个子分区(它永远不会迁移),但多个 warp 可以分配到同一个 SM 子分区。在每个时钟周期,每个子分区的 warp 调度器会尝试选择一个其驻留 warp 来执行下一条指令。

A diagram of one NVIDIA SM

进一步来说,最近的 CUDA 版本还概述了warpgroup的概念,即 4 个连续的 warp。了解硬件的结构,我们可以看到这是从哪里来的:4 个连续的 warp 占据了 SM 的 4 个四分之一,并允许我们发出利用整个 SM 的指令。

注意

GPU 可以从多种不同的角度来看待,在这里我们想关注一个稍微简化的模型,该模型非常以 TensorCore 为中心。这应该有助于您导航编写涉及 TensorCore 的内核的复杂性,但请记住,实际情况更为复杂。

对我们来说,TensorCore 操作已经变得如此庞大,以至于不再有太大意义遵循 CUDA 模型。因此,对我们而言,GPU 是一个由单线程核心(SM)组成的集合,Pallas:MGPU 的一个线程对应一个 CUDA warpgroup。在此模型中,您在内核中执行的每个操作都占据了整个 CUDA warpgroup,其组成 warp 始终同步运行(除了硬件调度引起的抖动),并且永远不会通过控制流走不同的路径(除了稍后将讨论的 core_map 的小例外)。这里一个值得注意的补充是,我们仍然允许您在同一个 SM 上共同调度多个此类 Pallas 线程,以便它们可以通过共享内存进行协作和通信(我们通过将它们放在同一个 CUDA 块中来实现这一点)。

注意

从现在开始,每当我们说到“线程”时,我们指的是 Pallas 线程,而不是 CUDA 线程/车道。

注意

这与 Triton 推广的编程模型非常相似,但正如您将看到的,也有一些区别。Mosaic GPU 倾向于更底层,这通常意味着您需要付出更多努力,但它也让您拥有更多的控制权。在我们看来,这两种方法都有其优点,我们鼓励您选择最适合您需求的后端!Pallas 支持并将继续支持 Triton 作为替代 GPU 后端。

顺序执行与使用多个硬件单元#

与更复杂的 CPU 架构不同,GPU 仅支持顺序执行。然而,这并不意味着任何时候都只运行一条指令!每个 SM 四分之一都有多个独立的函数单元:TensorCore、算术逻辑单元(ALU)、加载/存储(LSU)、特殊函数单元(SFU)。如果第一条指令针对其中一个单元,并且后面跟着另一条指令(不使用第一条指令的结果),那么 warp 调度器就可以在第一条指令完成之前发出第二条指令。这通常被称为指令级并行(ILP),是现代 TensorCore 内核中的一个常见主题:TensorCore 操作非常庞大且需要很多周期才能完成,因此在此期间不利用其他单元是一种浪费。

为了进一步扩展这一点,我们可以通过允许多个 Pallas 线程并发运行来利用这种硬件单元级并行性。如果一个线程主要占用 ALU,而另一个线程主要发出与 TensorCore 相关的指令,我们可以利用 warp 调度器内置的高效上下文切换来保持两个单元的忙碌。这是 FlashAttention 3CUTLASS ping-pong matmul 内核 等算法背后的核心思想之一。

有关 warp 调度和指令发出的更多信息,我们建议阅读 Analyzing Modern NVIDIA GPU cores

内存空间#

GPU 具有几种不同的内存空间,可以按容量从大到小、从慢到快(在总带宽和单次访问延迟方面)完全排序。

A diagram of memory spaces of an NVIDIA GPU

最大的内存空间是 plgpu.GMEM,即全局内存。在最近的数据中心级 GPU 中,此内存空间的容量通常以几十甚至上百 GB 为单位,但它也是最慢的。

下一个内存空间用于 L2 缓存,它在某种程度上也是全局的,因为它由整个 GPU 共享,但其使用只能通过缓存提示间接影响。因此,没有办法手动将值放入其中,因此该内存空间在 Pallas:MGPU 中未暴露。虽然它只有大约 100MB 大小,但其带宽明显高于 GMEM,因此在编写高性能内核时仍然常常建议利用它。

接下来是共享内存,或 plgpu.SMEM。此内存直接位于每个 SM 内部,因此是分区化的。除非使用块集群(请参阅下面的集群部分),否则每个块只能访问其自己的 SMEM 分配。

最后,最低级别的内存空间是寄存器内存。Pallas 内核中的每个值(即 JAX 数组)都将位于此处。如果编译器用完了存储这些数组的寄存器,它将插入溢出,这意味着它将定期将值存储到内存并重新加载。这些溢出通常会导致其他严重的性能下降,因此我们建议避免它们。在内核编译过程中 ptxas 的消息中可以清楚地看到有关溢出的警告消息。要使其可见,请在环境中运行 MOSAIC_GPU_DUMP_PTXAS=1

Blackwell GPU 代有一个额外的内存空间,称为张量内存plgpu.TMEM。TMEM 与寄存器内存非常相似,只是由您显式分配和管理。它用于存储 MMA 累加器、操作数元数据(用于稀疏性或缩放)以及可选的左 MMA 操作数。有关 TMEM 的更多信息,请参阅 Blackwell MMA 部分。

在特定内存空间中请求/分配内存#

内核输入或输出默认放在 SMEM 中。如果您希望将其作为 GMEM 引用访问,请在它们的 BlockSpec 中添加 memory_space=plgpu.GMEM。如果您希望使用整个输入或输出数组在 GMEM 中调用内核,只需指定 BlockSpec(memory_space=plgpu.GMEM) 即可。

SMEMTMEM 可以在 pl.pallas_callscratch_shapes 参数中显式分配,或使用 pl.run_scoped 进行分配。要分配引用,只需使用请求的形状和 dtype 调用内存空间对象即可。例如:plgpu.SMEM((128, 128), jnp.float16) 将在共享内存中分配一个 128x128 的 float16 元素数组。

利用 L2 缓存#

虽然 L2 缓存无法手动管理,但其明显高于全局内存的带宽使其值得考虑。利用它的最简单方法是重新排序并行网格维度,以便在相似时间段内调度的调用也访问相同的数据输入。

虽然 CUDA 编程模型不保证块分配给 SM 的顺序,但在最近的 GPU 上,启发式方法似乎只是以列主序(即 x 是变化最快的维度,z 是变化最慢的维度)迭代 (x, y, z) CUDA 网格。类似地,Pallas:MGPU 不保证用户指定的网格如何映射到 CUDA 网格(Pallas 支持任意秩的网格,不只是最多 3D)。但是,您可以假定迭代将以行主序进行。也就是说,如果一个网格的维度是 (a, b),那么 b 将是变化最快的维度,而 a 将是较慢的维度。

以一个实际示例为例,考虑一个普通的矩阵乘法内核。在那里,通常使用两个并行网格维度 (m, n),这对应于平铺两个非收缩维度。如果我们使用这个简单的方案,在 Pallas:MGPU 中,所有 id 为 (0, ...) 的程序都将在任何 id 为 (1, ...) 的块之前调度。而且,集体来说,id 为 m=0 的程序必须读取 B 操作数!如果 nk 维度非常大,那么 (1, ...) 程序不可能从 (0, ...) 程序进行的访问中获得缓存命中。为简单起见,假设我们一次只能运行 16 个块,我们从第一个调度的波次看到此访问模式

Your browser does not support SVGs or scripting is disabled. This would be an image showing the access pattern of first 16 blocks without grid tiling.

然而,如果我们只是将网格重新排列为 (m // mt, n, mt)(然后在内核中用 pl.program_id(0) * mt + pl.program_id(2) 替换 pl.program_id(0)),我们可以轻松地看到沿两个维度的一段程序将并发调度(而不是调度单行)。这大大增加了访问相似数据切片的并发程序数量,通常会显著提高 L2 利用率,从而提高内核的整体性能(如果它受内存限制)。继续我们的 16 个块的示例,并使用 mt=4,我们得到以下访问模式

Your browser does not support SVGs or scripting is disabled. This would be an image showing the access pattern of first 16 blocks with grid tiling.

请注意,即使活动块的数量没有改变,它们访问的数据的总占用空间也减半了!我们现在获得 L2 命中的机会大大增加。

数组布局和内存引用变换#

在 Pallas 中,您处理的数据结构(数组和引用)具有逻辑形状(例如,128x128 矩阵)。此逻辑形状必须映射到物理表示(数据在 GPU 内存中实际如何表示)。具体的映射取决于数据的位置。

  1. 数组布局: 数组存储在寄存器内存中,我们将此映射称为布局。布局定义了数组的元素如何在构成 Pallas 线程的 CUDA 车道可用的寄存器中分布。

  2. 内存引用变换: 对于指向 SMEM 的可变引用,此映射称为变换。变换描述了逻辑数据结构在该内存块中是如何排列的。

这些概念对于性能至关重要,尤其是在与 TensorCore 等专用硬件单元交互或优化内存访问模式时。

注意

我们正在开发一种模式,用于完全自动地处理布局和变换的分配(尽管可以通过提示提供更多控制)。下面列出的 API 可能会继续运行,但将成为可选的。

内存引用变换#

变换在内存引用首次分配时应用。操作这些引用的 Pallas 基元将自动考虑其关联的变换。

def body(..., scratch_ref):
  # Asynchronous copy will reformat the GMEM data to match the SMEM transforms
  plgpu.copy_gmem_to_smem(..., scratch_ref, barrier)
  plgpu.barrier_wait(barrier)
  plgpu.wgmma(..., scratch_ref)  # wgmma only accepts properly transformed refs
  ...

有两种分配引用的方式,每种方式都有选择所需变换的方法

1. 使用 plgpu.BlockSpec

transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128))
f = pl.pallas_call(
  in_specs=plgpu.BlockSpec(in_block_shape, in_index_map, transforms=transforms),
  out_specs=plgpu.BlockSpec(out_block_shape, out_index_map, transforms=transforms),
  ...
)

请注意,与 plgpu.BlockSpec 不同,pl.BlockSpec *不允许*指定变换。

2. 在分配的 SMEM 上指定 transforms 参数

transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128))
f = pl.pallas_call(
  scratch_shapes=plgpu.SMEM((128, 128), jnp.float16, transforms=transforms),
  ...
)

可用的变换是

  • plgpu.TileTransform(tile_shape),它将数据组织成形状为 tile_shape 的连续、不重叠的图块。一个图块的数据在另一个图块开始之前总是完全线性化的(行主序)(图块也以行主序遍历)。例如,将 TileTransform((8, 64)) 应用于 (128, 128) 引用意味着对应于逻辑切片 [0:8, 0:64] 的数据将首先存储(行主序),然后是 [0:8, 64:128], [8:16, 0:64], [8:16, 64:128],依此类推。另一种实现此目的的方法是获取输入数组 x 并按行主序遍历 x.reshape(128 // 8, 128 // 64, 8, 64).transpose(0, 2, 1, 3)

  • plgpu.SwizzleTransform(swizzle_in_bytes),它按照 PTX 文档CUDA 文档中的描述转换数据。抖动很有用,因为它允许在寄存器和共享内存之间以与 MMA 相关的布局传输数据而不会发生银行冲突。抖动后的内存实际外观的精确细节并不重要,因为所有基元都会自动考虑它。请注意,抖动量以字节为单位指定(仅支持 128、64、32 和 16),并且通常伴随 TileTransform(它使用其形状中的元素!)。

  • plgpu.TransposeTransform(permutation),它在数组被线性化之前排列其维度。这主要很有用,因为它允许您在 GMEM-SMEM 复制期间更改布局(但请记住,更改次小/最后一个维度不受硬件支持)。

数组布局#

我们目前已为您定义了几种有用的布局

  • plgpu.Layout.WGMMA,这是 Hopper 代 TensorCore 期望 MMA 累加器或 16 位输入操作数在寄存器中拥有的布局。

  • plgpu.Layout.WGMMA_ROW,这是上述布局沿行减少后的布局。行的广播是免费的,并将产生具有 WGMMA 布局的值。

  • plgpu.Layout.WGMMA_COL,这是上述的类似操作,只是沿列而不是行减少。

  • plgpu.Layout.WG_STRIDED,其中值平均分配给构成 Pallas 线程的 128 个 CUDA 车道。连续元素(在向量化后)以轮询方式分配给车道。当不需要与 TensorCores 交互时,它非常简单且有效。

  • plgpu.Layout.WG_SPLAT,表示该值是恒定的。每个 CUDA 车道将持有一个包含该值的寄存器。您通常无需与此布局交互,因为它在创建常量值时隐式使用,并且始终隐式可转换为其他布局。

目前,在默认操作模式下,数组布局传播仅在向前方向上发生,并且对协调布局冲突的隐式支持很少:只有 splat 布局可以隐式转换为任何其他布局。例如,如果您尝试添加两个具有不同布局的数组,则降低将发出抱怨并失败。有一些非常有限的设施允许您在布局之间进行转换,我们通常建议将值存储到 SMEM 并以目标布局读回。

MMA (TensorCore)#

在本节中,我们将重点介绍 Pallas:MGPU 内核如何利用 TensorCore 单元。TensorCore 的编程接口在不同的 NVIDIA GPU 代之间发生显着变化,这就是为什么 Pallas:MGPU 中的最低级接口也不同。

每个 MMA 操作都与三个操作数相关联

  • 累加器 D,形状为 (M, N)

  • 左输入 A,形状为 (M, K)

  • 右输入 B,形状为 (K, N)。所有操作数必须具有相同的元素类型。

每次使用 MMA 都涉及几个步骤

  1. 分配累加器的空间(MMA 隐式执行 D += A @ B

  2. 准备 AB 操作数

  3. 发出操作

  4. 等待操作完成

  5. 读出结果

步骤 2-4. 通常在一个关于收缩维度(K)的循环中执行。

AB 操作数的内存空间#

通常最好通过 SMEM 传递 AB 操作数,在那里可以使用 plgpu.copy_gmem_to_smem 方便地加载它们。为了使这些操作数与 MMA 操作兼容,它们需要在分配时指定适当的平铺和抖动变换。对于所有当前支持的代,TensorCore 要求数据被平铺成行主序的 2D 图块,形状为 (8, swizzle_elems),其中 swizzle_elems 是通过将抖动除以元素类型的字节宽度派生的。目前支持的抖动是:128、64 和 32。较大的抖动更可取,因为它们可以提高 GMEM 到 SMEM 复制的性能。

def mma_transforms(shape_dtype: jax.ShapeDtypeStruct):
  assert len(shape_dtype.shape) == 2
  if shape_dtype.shape[0] % 8:
    raise ValueError("Number of rows must be divisible by 8")
  for swizzle_bytes in (128, 64, 32):
    swizzle_elems = swizzle_bytes // shape_dtype.dtype.itemsize
    if shape_dtype.shape[-1] % swizzle_elems == 0:
      return (plgpu.TilingTransform((8, swizzle_elems)),
              plgpu.SwizzleTransform(swizzle_bytes))
  raise ValueError("Failed to find transforms for the specified window type")

如果需要转换操作数,A 操作数可以通过不同的内存空间传递(依赖于架构,见下文)。B 操作数*必须*位于 SMEM 中。

转置操作数#

在对 16 位操作数执行 MMA 时,TensorCore 可以自动转置输入数据。例如,A 引用可以具有形状 (K, M),但必须在将其传递到 mma 函数之前进行转置。例如

assert acc_ref.shape == (M, N) and a_ref.shape == (K, M) and b_ref.shape == (K, N)
a_ref_t = plgpu.transpose_ref(a_ref, (1, 0))
assert a_ref_t.shape == (M, K)  # The shape expected by plgpu.wgmma
plgpu.wgmma(acc, a_ref_t, b_ref)

在这种情况下,对 B 引用也允许进行类似的转置操作。

Hopper (wgmma)#

在本节中,我们将介绍使用 Hopper 代 TensorCore 的基础知识,这些 TensorCore 在 PTX 中暴露为 wgmma.mma_async 指令

分配累加器#

在 Hopper 硬件架构中,累加器分配在寄存器中,但在 Pallas 中,它被建模为一个可变引用,因为每个 MMA 操作都就地累加。有两种分配累加器的方法。

要创建零初始化的累加器,您可以使用 pl.run_scopedplgpu.ACC((m, n), dtype) 类型。

def compute(acc_ref):
  ...
  return acc_ref[...]
output = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32))

解引用累加器引用,如 compute 函数的结尾所示,将隐式等待所有未完成的 WGMMA 操作。

如果您想用现有数组初始化它,您可以使用 pl.run_stateplgpu.ACC.init(init_array)

def compute(acc_ref):
  ...
  return # pl.run_state only returns the final value of the accumulator
output = pl.run_state(compute)(plgpu.ACC.init(init_array))

如果 pl.run_state 具有累加器操作数,它将在返回最终值之前隐式等待所有未完成的 WGMMA 操作。

准备 AB 操作数#

如上所述,我们建议通过共享内存传递 AB。在这种情况下,必须指定正确的平铺和抖动变换。

plgpu.wgmma 此外还允许通过寄存器传递 A(即不是 SMEM 引用,而是常规 JAX 数组)。然而,这种模式存在许多重大缺点,并且很难确保充分的同步以使其安全。

TODO:解释可以执行此操作的条件。

发出操作#

支持的 MMA 形状是:

  • M 可被 64 整除

  • N 可被 8 整除且不大于 256

  • Kswizzle 除以操作数的元素类型字节宽度后的倍数

目前支持的数据类型为:jnp.float32jnp.bfloat16jnp.float16。累加器 D 必须是 jnp.float32,除了 jnp.float16 输入,在这种情况下它也可以是 jnp.float16

等待操作完成#

每个 plgpu.wgmma 调用隐式地与所有之前的 plgpu.wgmma 调用同步,因此一旦控制权从它返回,我们保证除了最后发出的 WGMMA 之外,没有其他 WGMMA 仍在运行。因此,先前发出的 WGMMA 指令读取的任何 SMEM 区域都可以被重用。这对于将 WGMMA 与异步内存复制进行流水线处理尤为重要。

buffers = 3  # In reality you might want even more
assert a_smem.shape == (buffers, m, k)
assert b_smem.shape == (buffers, k, n)
assert acc_ref.shape == (m, n)

def fetch_a_b(ki, slot):
  a_slice = ... # Replace with the right M/K slice
  b_slice = ... # Replace with the right K/N slice
  plgpu.copy_gmem_to_smem(a_gmem.at[a_slice], a_smem.at[slot], a_loaded.at[slot])
  plgpu.copy_gmem_to_smem(b_gmem.at[b_slice], b_smem.at[slot], b_loaded.at[slot])

def loop_body(i, _):
  slot = jax.lax.rem(i, buffers)
  plgpu.barrier_wait(a_loaded.at[slot])
  plgpu.barrier_wait(b_loaded.at[slot])
  plgpu.wgmma(acc_ref, a_smem.at[slot], b_smem.at[slot])
  # We know that only the last issued WGMMA is running, so we can issue a async load in
  # into the other buffer
  load_i = i + buffers - 1
  load_slot = jax.lax.rem(load_i, buffers)
  @pl.when(jnp.logical_and(load_i >= buffers, load_i < num_steps))
  def _do_fetch():
    fetch_a_b(load_i, slot)
for slot in range(buffers):
  fetch_a_b(slot, slot)
jax.lax.fori_loop(0, num_steps, loop_body, None)

Blackwell (tcgen05)#

Blackwell 代对 TensorCore 子单元进行了重大重新设计。它现在与常规 warp 调度器明显分离,并且不再使用或甚至支持使用寄存器作为其操作数。取而代之的是,引入了一个名为张量内存(TMEM)的新内存空间。更重要的是,成对 SM 的 TensorCore 现在可以合并它们的资源并计算跨越这两个 SM 的更大 MMA 操作。我们将此称为“集体 MMA 操作”

分配累加器/使用 TMEM#

TMEM 引用可以与分配所有其他引用的方式相同——使用 pl.run_scoped

@functools.partial(pl.run_scoped, tmem_ref=plgpu.TMEM((128, 128), jnp.float32))
def barrier_scope(tmem_ref):
  ...

并非所有形状都可以分配到 TMEM 中。目前仅支持 2D 引用,并且行数(第一维的大小)必须是 128 或 64。

此外,如果数据类型具有小于 32 位的位宽,则必须声明分配是打包的(例如,将两个 16 位元素放入 TMEM 中的单个 32 位单元)还是未打包的(每个元素填充到 32 位)。MMA 累加器(fp32 或 fp16)从不打包,但如果左操作数通过 TMEM 传递,则它必须始终打包。

@functools.partial(pl.run_scoped,
                   acc_ref=plgpu.TMEM((128, 128), jnp.float16, packed=False),
                   lhs_ref=plgpu.TMEM((128, 128), jnp.float16, packed=True))
def barrier_scope(acc_ref, lhs_ref):
  plgpu.tcgen05_mma(acc_ref, lhs_ref, rhs_smem_ref, ...)
  ...

TMEM 的另一个有趣的复杂性是它上的所有操作都是异步的。因此,使用 Python 下标语法进行的读写(通常用于例如 SMEM)不允许用于 TMEM。

加载#

可以使用 plgpu.async_load_tmem 进行加载,并使用 plgpu.wait_load_tmem 进行等待。

smem_ref[...] = plgpu.async_load_tmem(tmem_ref)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(smem_ref, gmem_ref)
plgpu.wait_smem_to_gmem(0)
plgpu.wait_load_tmem()  # Wait for the read to fully complete before we overwrite tmem_ref again.

加载语义非常令人困惑,因为从加载返回的数组可以安全使用而无需任何额外的同步。但是,如果读取的 TMEM 区域再次被覆盖(例如,通过存储或 MMA 操作),则发出加载的线程必须首先调用 plgpu.wait_load_tmem() 以确保程序保持无竞争。

注意

使人能够理解这种看似因果关系断裂的行为(数据在完全从 TMEM 读取之前就已到达寄存器)的一种方法是考虑它可能是 PTX 编译器的一个限制和一个便利功能的交互作用。我们不知道是否属实,但至少它是有意义的。

便利功能是编译器可以可靠地跟踪 TMEM 加载产生的寄存器使用情况,并插入必要的最小延迟以确保数据在被使用之前从 TMEM 到达。读取操作被展开成许多指令,这意味着在开始消耗由第一个加载填充的寄存器之前,不必等待所有这些指令。这就是为什么我们不需要保护结果的使用。

限制是编译器无法可靠地对 TMEM 加载和存储执行别名分析,这就是为什么任何未被显式等待分隔的加载和存储都被认为可以安全地并发执行。另一种选择将不必要地低估真正不相关的加载和存储的性能。这就是为什么我们需要在再次重用 TMEM 之前显式等待。

存储#

相反,使用 plgpu.async_store_tmem 进行存储,并使用 plgpu.commit_tmem 进行等待。

plgpu.async_store_tmem(tmem_ref, smem_ref[...])
plgpu.commit_tmem()
smem_ref2[...] = plgpu.async_load_tmem(tmem_ref)  # Safe to read from tmem_ref now

准备 AB 操作数#

我们建议通过共享内存传递 AB。在这种情况下,必须指定正确的平铺和抖动变换A 操作数也可以作为 TMEM 引用传递,但它必须是打包的。

发出操作#

支持的**非集体** MMA 形状是:

  • M 是 64 或 128

  • N 可被 8 整除且不大于 512

  • K8 * swizzle 除以元素类型的位宽后的倍数

支持的 **集体 MMA** 形状是:

  • M 是 128 或 256(每个块占一半)

  • N 可被 8 整除且不大于 256(每个块不大于 128)

  • K8 * swizzle 除以元素类型的位宽后的倍数

目前支持的浮点数据类型为:jnp.bfloat16jnp.float16jnp.float8_e5m2jnp.float8_e4m3fn。累加器可以是 jnp.float32jnp.float16,除了 jnp.bfloat16 时,它必须是 jnp.float32

目前支持的整数数据类型是 jnp.int8,累加器为 jnp.int32

注意

根据我们的基准测试,以下是一些性能经验法则:

  • 非集体 MMA 应始终使用 M=128 且 N >= 128。

    • M=64 会导致显着性能下降。

    • N=64 会导致明显的性能下降,但不如 M=64 那么显着。

  • 集体 MMA 始终具有合理的速度,但不如非集体 MMA 快。

    • 集体 MMA 的最大好处不是更高的 TensorCore 吞吐量,而是能够跨 SM 共享数据,从而提高内核的算术强度。

  • 抖动和转置似乎不会对性能产生显着影响。

等待操作完成#

等待 plgpu.tcgen05_mma 调用的结果需要使用 Barrier。我们建议阅读 Barrier 的参考文档,特别是其 Blackwell 相关子节 以获取更多信息。

如果 barrier 直接传递给 plgpu.tcgen05_mma,完成对该 barrier 的等待将表明最终累加器已写入 TMEM。例如

@functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True))
def barrier_scope(barrier_ref):
  plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, barrier_ref, accumulate=False)
  plgpu.barrier_wait(barrier_ref)
  # We can read the result now.
  result = plgpu.async_load_tmem(acc_tmem)
  ...

如果未向 plgpu.tcgen05_mma 提供 barrier,则只有在调用 plgpu.tcgen05_commit 时才会跟踪其完成。

@functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True))
def barrier_scope(barrier_ref):
  plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, accumulate=False)
  plgpu.tcgen05_mma(acc_tmem, lhs_ref2, rhs_ref2)
  plgpu.tcgen05_commit(barrier_ref)
  plgpu.barrier_wait(barrier_ref)
  # We can read the result now. Both MMAs have completed.
  result = plgpu.async_load_tmem(acc_tmem)
  ...

集体 MMA#

Blackwell 代提供了一种新的 MMA 操作方式,其中集群中 2 个 SM 的 TensorCore 协同执行一个 MMA 操作。每个 SM 的 B 操作数会与其他 SM 共享。DA 操作数是每个 SM 的本地操作数,不共享。

A diagram showing the partitioning of operands in a collective MMA

这意味着要执行形状为 M、N 和 K 的集体 MMA,每个 Pallas 线程中的操作数的大小应分别为:A(M // 2, K)B(K, N // 2)D(累加器)为 (M // 2, N)。将两个累加器堆叠起来就可以恢复执行 MxNxK 矩阵乘法的结果。

为了简化 B 操作数的加载,可以使用 plgpu.copy_gmem_to_smem,并结合 collective_axespartitioned_axis 来指示集体轴上的两个 Pallas 线程加载相同的切片,但每个线程只会获得其中一半。与仅使用 collective_axes 的复制不同,它不使用 TMA 组播(因为每个线程加载不同的数据切片),但它可以简化索引逻辑。

plgpu.copy_gmem_to_smem(
    b_gmem,  # [K, N]
    b_smem,  # [K, N // 2]
    b_tma_barrier,
    collective_axes="x",
    partitioned_axis=1,
)

使用 core_map#

pl.pallas_call 适用于单个 Pallas 线程可以执行整个 CUDA 块计算的内核。 pl.core_map 函数放宽了这一限制,允许在一个块内使用多个线程(例如,用于 warp 特化)或跨块集群使用(例如,利用组播 TMA)。

pl.pallas_call 替换为 pl.core_mapplgpu.kernel#

让我们从一个简单的 Pallas 内核开始,该内核递增数组

@functools.partial(
  pl.pallas_call,
  grid=(2,),
  in_specs=[pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,))],
  out_specs=pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,)),
  out_shape=jax.ShapeDtypeStruct((256,), jnp.float32), # Total output shape
)
def run_kernel(x_ref, y_ref):
  # x_ref and y_ref are in SMEM!
  y_ref[...] = x_ref[...] + 1

x = jnp.arange(256, dtype=jnp.float32)
y = run_kernel(x)
np.testing.assert_array_equal(y, x + 1)

我们可以使用 pl.core_map 编写一个类似的内核。一个很大的区别是,与 pl.pallas_call 不同,不会自动插入任何 GMEM<->SMEM 复制。如果需要,您可以自己插入它们,或者使用 plgpu.emit_pipeline 辅助函数。我们建议阅读软件流水线指南

@pl.run_state
def run_kernel(refs):
  x_ref, y_ref = refs
  # Here, we're not in the kernel yet! pl.run_state simply changes the JAX
  # immutable arrays into mutable GMEM (not SMEM!) references.

  # Define the mesh: 2 CUDA blocks over 1 axis called "x"
  mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))

  @pl.core_map(mesh)  # core_map executes the body
  def kernel_body():
    # Once we enter the pl.core_map scope, we are in the body of the kernel.
    block_slice = pl.ds(jax.lax.axis_index("x") * 128, 128)
    y_ref[block_slice] = x_ref[block_slice] + 1

x = jnp.arange(256, dtype=jnp.float32)
y_init = jnp.zeros_like(x)
_, y = run_kernel((x, y_init))
np.testing.assert_array_equal(y, x + 1)

虽然 pl.core_map 是一个强大的 API,但它也非常底层,并且几乎总是用在 pl.run_state(将 JAX 数组转换为 ref)或 pl.run_scoped(为 scratch ref 分配)之下。因此,我们还提供了一个便捷 API plgpu.kernel

@functools.partial(
    plgpu.kernel,
    out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
    grid=(2,),
    grid_names=("x",),
)
def run_kernel(x_ref, y_ref):
  # x_ref and y_ref are in GMEM!
  block_slice = pl.ds(jax.lax.axis_index("x") * 128, 128)
  y_ref[block_slice] = x_ref[block_slice] + 1

x = jnp.arange(256, dtype=jnp.float32)
y = run_kernel(x)  # No need to preallocate outputs as in pl.core_map.
np.testing.assert_array_equal(y, x + 1)

注意

jax.sharding.Mesh 类似,用于定义*单个 GPU 内*的计算拓扑,指定工作如何在 CUDA 块(grid)、块内的 Pallas 线程(num_threads)以及可能存在的 CUDA 块集群(cluster)之间分配。plgpu.Mesh 用于 pl.core_map。这类似于 jax.sharding.Mesh 如何定义 JAX 中*跨多个设备*的分布式计算拓扑。两者都涉及在定义的拓扑上执行 SPMD 程序。此外,您可以对 Pallas 线程和集群运行“集合操作”(例如,使用 plgpu.ClusterBarrier 或集体异步复制),这类似于 JAX 集合操作(psumall_gather 等)在 JAX Mesh 中的设备之间运行。两者都使用命名轴,并且 jax.lax.axis_index(axis_name) 可用于获取线程或块的坐标。

每个 CUDA 块使用多个 Pallas 线程#

下面是一个示例,其中同一块中的两个 Pallas 线程通过 barrier 同步,甚至通过 SMEM 交换数据。

x = jnp.arange(128, dtype=jnp.float32)

@functools.partial(
    plgpu.kernel,
    out_shape=x,
    scratch_shapes=dict(
        smem_ref=plgpu.SMEM(x.shape, x.dtype),
        barrier_ref=plgpu.Barrier(),
    ),
    num_threads=2,
    thread_name="pallas_thread",
)
def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
  thread_id = jax.lax.axis_index("pallas_thread")

  @pl.when(thread_id == 0)
  def producer_thread():
    smem_ref[...] = x_ref[...] + 1
    plgpu.barrier_arrive(barrier_ref)  # Signal the consumer thread

  @pl.when(thread_id == 1)
  def consumer_thread():
    plgpu.barrier_wait(barrier_ref)  # Wait for the producer thread
    out_ref[...] = smem_ref[...] + 1

y = run_kernel(x)  # There's no need to preallocate the input anymore.
np.testing.assert_array_equal(y, x + 2)

虽然这个例子很简单,但在同步部分可以找到一个更复杂的例子。

多个线程经常用于高性能内核,例如最新的 flash attention 变体或 ping-pong 矩阵乘法。在这两者中,程序中有 2 个计算线程,它们交替使用 SM 的 ALU 和 TensorCore,以确保没有执行冲突。

另一种常用技术是分配一个 Pallas 线程,并将其完全用于调度由其他线程消耗的数据的异步复制。虽然从头开始实现这种方案可能很复杂,但我们提供了一个便捷的辅助 API:plgpu.emit_pipeline_warp_specialized

使用 CUDA 块集群#

下面的内核启动了一个由 2 个 CUDA 块组成的单个集群,并使用 TMA 组播功能为两个块集体地将 GMEM 复制到 SMEM。参与集体复制的所有块都必须调度完全相同的复制才能使程序有效。

@functools.partial(
    plgpu.kernel,
    out_shape=jax.ShapeDtypeStruct((2, 128), jnp.float32),
    scratch_shapes=dict(
        smem_ref=plgpu.SMEM((128,), jnp.float32),
        barrier_ref=plgpu.Barrier(),
    ),
    cluster=(2,),
    cluster_names=("cluster",),
)
def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
  # Specifying collective_axes will enable TMA multicast automatically.
  plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref, collective_axes="cluster")
  plgpu.barrier_wait(barrier_ref)
  plgpu.copy_smem_to_gmem(smem_ref, o_ref.at[jax.lax.axis_index("cluster")])
  plgpu.wait_smem_to_gmem(0)

x = jnp.arange(128, dtype=jnp.float32)
y = run_kernel(x)
# Each block gets the same data and writes it out.
np.testing.assert_array_equal(y, jnp.stack([x, x], axis=0))

pl.run_scoped 中的集体分配#

当使用具有多个 Pallas 线程(即 plgpu.Mesh 中的 num_threads > 1)的 pl.core_map 时,通过 pl.run_scoped(用于 SMEM 或 Barriers)进行的分配*必须由所有线程集体执行*。这是通过为 run_scoped 指定 collective_axis 参数来指示的,该参数有两个效果:

  1. 它保证所有线程将调用相同的分配,并且

  2. 所有线程将获得完全相同的分配。

如果未指定 collective_axes 或 collective_axes 不包含 Pallas 线程轴,则每个线程将获得其自己的私有 scratch 变量副本。这通常是不期望的,并且目前不支持。

使用 pl.get_global 进行全局(整个网格)分配#

有时,将信号量分配到可以与所有并行程序实例共享的方式很有用。例如,当并行实例的数量足够少以至于内核是持久的时。此类分配可以使用 pl.get_global 进行。

def body(out_ref):
  sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR)
  block_id = lax.axis_index("x")
  @pl.when(block_id == 0)
  def _():
    pl.semaphore_signal(sem_ref)  # Block 0 signals
  @pl.when(block_id == 1)
  def _():
    pl.semaphore_wait(sem_ref)  # Block 1 waits
    out_ref[...] = jnp.ones_like(out_ref)

out_shape = jax.ShapeDtypeStruct((128,), jnp.float32)
plgpu.kernel(body, out_shape=out_shape, grid=(2,), grid_names=("x",))()

同步结构和原语#

在本节中,我们将介绍用于线程同步以及一些异步操作的最重要的函数和数据结构。

commit_smem#

对引用的常规读/写保证产生与顺序程序顺序一致的值。例如,在下面的程序中,保证 value 等于 value2

ref[...] = value
value2 = ref[...]

然而,此保证不适用于异步原语,如异步复制或 MMA 操作。要使 SMEM 写入对这些原语可见,您需要使用 plgpu.commit_smem() 函数显式同步。

例如

smem_ref[...] = value
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(smem_ref, ...)

smem_ref[...] = value
plgpu.commit_smem()
plgpu.wgmma(smem_ref, ...)

反之,这种显式同步也是必需的,例如

v = plgpu.load(smem_ref, ())
plgpu.commit_smem()
plgpu.copy_gmem_to_smem(..., smem_ref, ...)

未能调用此函数很可能会导致微妙的数据竞争,因为这些异步硬件单元会读取 SMEM 中的陈旧数据。不幸的是,此函数相对昂贵,因此我们依赖您,用户,将其插入到其必需的最少位置。

Barrier#

这本质上是一个 PTX mbarrier 类型数组的薄包装器,并作为引用传递。所有涉及 barrier 的函数都期望只获取一个 barrier 参数,因此如果引用包含多个 barrier,您必须显式提取一个,使用 barriers.at[index]Barriers 始终在 SMEM 中分配,因此开销相对较低。每个 barrier 可以配置为在固定数量的“到达”后完成(默认为 1)。

要阻止线程直到 barrier 完成,请使用以下函数:

plgpu.barrier_wait(barrier)

警告

确保同步方案使得两个 barrier 完成之间不会发生两次 barrier 完成而没有调用 plgpu.barrier_wait,这一点至关重要。例如,如果您使用 Barriers 来同步两个生产者/消费者线程,您需要进行双向 barrier 同步以引入“反压”,从而阻止一个线程在另一个线程有机会等待之前到达两次。未能满足此要求将损坏数据结构并可能导致意外的失败(包括 CUDA 运行时错误)。请参阅下面的示例,了解一个有效的具有两个线程的程序。

警告

另一个关键限制是,在 barrier 的整个生命周期中,barrier 完成的数量必须等于 barrier 等待的数量。不允许在 barrier 的作用域分配结束时它有未等待的完成。否则,当它被编译器重用时,将其置于此状态可能会导致后续问题。

警告

最后,确保每个等待 Barrier 的线程都参与其所有 wait 操作至关重要。不允许例如从一个线程等待 barrier 的每次偶数完成,而从另一个线程等待所有其他完成。这样做会导致死锁。总结:当 Barrier 在某个线程中用于等待时,它必须观察该 barrier 的所有完成(通过等待它)。

请注意,Barrier 可以从任何源接收到达,没有限制。

完成 barrier 的操作有三种:

异步 GMEM 到 SMEM 复制#

当 TMA 引擎执行异步 GMEM 到 SMEM 复制时,它会将进度更新发布到 plgpu.copy_gmem_to_smem 中提供的 barrier。复制完成后,barrier 也会完成一次到达。

显式到达(跨线程同步)#

任何线程都可以使用以下函数显式到达 barrier:

plgpu.barrier_arrive(barrier)

这在同步两个生产者/消费者角色的线程时特别有用。在这种情况下,我们建议分配两个 Barrier 数组,其大小等于用于在两个线程之间传递数据的“队列”的大小。例如,假设一个线程继续将数组的切片写入 SMEM,而另一个线程从中读取。我们对 SMEM 区域进行三缓冲,以实现两个线程之间更多的异步。

tid = jax.lax.axis_index("thread")
assert queue.shape == (buffering, *item_shape)
assert produced.shape == consumed.shape == (buffering,)

def thread0_body(i, _):
  slot = jax.lax.rem(i, buffering)
  @pl.when(i >= buffering)
  def _await_consumed():
    plgpu.barrier_wait(consumed.at[slot])  # Wait for consumption of the value before overwriting it
  # Option 1: Compute the next value
  queue[slot] = produce()
  plgpu.barrier_arrive(produced.at[slot])  # Signal the value is ready
  # Option 2: Produce the value through async_copy
  # plgpu.copy_gmem_to_smem(..., queue.at[slot], barrier=produced.at[slot])
pl.when(tid == 0)(lambda: jax.lax.fori_loop(0, steps, thread0_body, None))

def thread1_body(i, _):
  slot = jax.lax.rem(i, buffering)
  plgpu.barrier_wait(produced.at[slot])  # Wait for the value to be ready
  consume(queue[slot])  # Load and compute
  plgpu.barrier_arrive(consumed.at[slot])  # Signal that the value is consumed
pl.when(tid == 1)(lambda: jax.lax.fori_loop(0, steps, thread1_body, None))

等待 tcgen05 TensorCore 指令#

在我们开始之前,一个重要的警告:

警告

在 Blackwell 代 GPU 上,Barrier 操作默认对 TensorCore 操作具有宽松的语义。这意味着默认情况下,任何与 TensorCore 相关的操作(包括 TMEM 操作)都可以被编译器移动到*barrier 信号之后*。类似地,任何与 TensorCore 相关的操作都可以移动到*barrier 等待之前*。

如果您打算使用 Barriers 来向其他线程指示 TensorCore 操作已完成,请使用 orders_tensor_core=True 分配 barrier。此参数将插入必要的指令以防止上述问题重排。

与旧 GPU 不同,观察 Blackwell 代 TensorCore 指令完成的唯一方法是将 Barrier 引用传递给 plgpu.tcgen05_mma 函数。MMA 完成后,TensorCore 将在 barrier 上到达。

请注意,此 Barriers 的用法要求它们使用 orders_tensor_core=True 创建,因为它们用于与 TensorCore 操作同步。

@functools.partial(pl.run_scoped, barrier_ref=plgpu.Barrier(orders_tensor_core=True))
def barrier_scope(barrier_ref):
  plgpu.tcgen05_mma(acc_tmem, lhs_ref, rhs_ref, barrier_ref, accumulate=False)
  plgpu.barrier_wait(barrier_ref)
  # We can read the result now
  result = plgpu.async_load_tmem(acc_tmem)
  ...

ClusterBarrier#

ClusterBarrierBarrier 非常相似,只是用于跨块集群进行同步,而不是在单个块内的线程之间进行同步。当集群中的块在共享资源上协作时,这始终是必要的。下面我们概述一些 ClusterBarrier 对于确保正确性至关重要的更常见情况。

重复使用 SMEM 进行集体异步复制#

在以下示例中,ClusterBarrier 确保两个块都完成了对 x_smem 的使用,然后才被覆盖。没有 barrier,一个块就可以超前运行,并在另一个块完成读取 x_smem 之前通过进入集体复制来开始覆盖它。

def collective_smem_reuse(x_gmem, x_gmem2, y_gmem, x_smem, local_barrier, cluster_barrier):
  plgpu.copy_gmem_to_smem(x_gmem, x_smem, local_barrier, collective_axes="cluster")
  plgpu.barrier_wait(local_barrier)  # x_smem is ready to be used once the local wait completes
  y_gmem[0] = x_smem[...]
  plgpu.barrier_arrive(cluster_barrier)
  plgpu.barrier_wait(cluster_barrier)  # x_smem can only be reused once the cluster barrier completes
  plgpu.copy_gmem_to_smem(x_gmem2, x_smem, local_barrier, collective_axes="cluster")
  plgpu.barrier_wait(local_barrier)  # x_smem is ready to be used once the local wait completes
  y_gmem[1] = x_smem[...]

在 Blackwell 上重复使用 TMEM 进行集体 MMA#

此示例的工作方式与上一个非常相似,只是这次 TMEM 是共享资源。一个块为它们两个发出集体 MMA,但它们都需要安全地完成从 TMEM 的读取,然后才能将其重用于另一个集体 MMA。

def collective_tmem_reuse(acc_tmem, lhs_ref, rhs_ref, mma_barrier, cluster_barrier):
  leader_block = lax.axis_index("cluster") == 0
  @pl.when(leader_block)
  def _do_mma():
    plgpu.tcgen05_mma(
        acc_tmem, lhs_ref.at[0], rhs_ref.at[0], mma_barrier,
        accumulate=False, collective_axis="x",
    )
  plgpu.barrier_wait(mma_barrier)
  do_something(plgpu.async_load_tmem(acc_tmem))
  plgpu.wait_load_tmem()  # Ensure the load is complete.
  plgpu.barrier_arrive(cluster_barrier)
  plgpu.barrier_wait(cluster_barrier)  # acc_tmem can only be reused once the cluster barrier completes
  @pl.when(leader_block)
  def _do_mma():
    plgpu.tcgen05_mma(
        acc_tmem, lhs_ref.at[1], rhs_ref.at[1], mma_barrier,
        accumulate=False, collective_axis="x",
    )
  ...

Semaphore#

信号量是强大的同步结构,主要用于跨不同的块进行同步,可能在不同的设备上运行。对于单个块内的线程同步,首选使用 Barriers,而对于集群同步,首选使用 ClusterBarriers。信号量实现为位于 GMEM 中的 32 位原子计数器,支持以下操作:

  • pl.semaphore_signal,它原子地增加信号量。信号量在目标设备上可见之前,线程在信号之前执行的任何操作(包括通过 NVLINK 对远程内存的读写)都保证完成。

  • pl.semaphore_wait,它会阻止线程直到信号量达到*至少*所需值,此时该值会被原子地递减,并且线程会被唤醒。该函数可以选择性地调用 decrement=False,它将在值至少达到请求值时唤醒线程,但信号量的值不会被减少。非递减版本效率稍高。

这里我们展示了一个交换两个小分片于两个设备之间的小示例内核。

def exchange_shards(x_ref, y_ref, done_sem):
  other_dev_id = 1 - lax.axis_index("x")  # We assume two devices
  neighbor_ref = plgpu.remote_ref(y_ref, other_dev_id)
  neighbor_ref[...] = x_ref[...]  # This will write over NVLINK
  pl.semaphore_signal(done_sem, device_id=other_dev_id)  # Signal that the write is complete
  pl.semaphore_wait(done_sem)  # Wait for the other device to write to our memory

mesh = jax.make_mesh((2,), ("x",))
y = jax.jit(
    jax.shard_map(
      lambda x: plgpu.kernel(exchange_shards, out_shape=x,
                             scratch_shapes=[plgpu.Semaphore.REGULAR])(x),
      mesh=mesh, in_specs=P("x"), out_specs=P("x"), check_vma=False,
    )
)(x)

集群启动控制#

集群启动控制是 Blackwell GPU(SM100A+)中引入的一项功能,支持工作窃取或 CUDA 网格的动态调度。这允许已完成其工作的 SM(或 SM 集群)取消为另一个 SM 预定的块的启动,并为自己执行工作。最终结果是改进了 SM 之间的负载均衡,并且您应该会看到 GPU 在内核的后期得到更好的利用。Mosaic GPU 暴露了低级集群启动控制命令以及一个抽象大部分实现细节的辅助 API。

直接使用集群启动控制 API#

Mosaic GPU 直接公开了两个低级集群启动控制 API 函数: plgpu.try_cluster_cancelplgpu.query_cluster_canceltry_cluster_cancel 是一个异步操作,它将原子地尝试取消可用块的启动,并将结果放入一个 Ref 中。结果 Ref 应该是通过 plgpu.TryClusterCancelResult() 分配的一个 scratch Ref(底层是一个 16 字节的 SMEM Ref)。 query_cluster_cancel 将读取结果并返回两个值:一个包含请求的网格轴索引的元组,以及一个指示取消是否成功的布尔值。如果 query_cluster_cancel 不成功,则网格索引的结果是未定义的,不应使用。

当与集群一起使用时,同一集群内的所有块将从 query_cluster_cancel 获得相同的结果。

以下示例演示了如何使用这些函数调用内核

@functools.partial(
    plgpu.kernel,
    grid=grid,
    grid_names=grid_names,
    scratch_shapes=dict(
        result_ref=plgpu.TryCancelResultRef(),
        barrier_ref=plgpu.Barrier()
    )
)
def kernel(result_ref, barrier_ref):
  plgpu.try_cluster_cancel(result_ref, barrier_ref)
  # ... do work
  plgpu.barrier_wait(barrier_ref)
  grid_idxs, success = plgpu.query_cluster_cancel(result_ref, grid_names)

警告

重要的是要确保集群中的所有线程都进行适当的同步。在大多数情况下,当取消多个块时,您可能需要双缓冲结果和屏障,以确保不会发生竞态条件。因此,我们建议使用 plgpu.dynamic_scheduling_loop 辅助函数。

使用 plgpu.dynamic_scheduling_loop 辅助函数#

使用动态工作调度的一个常见模式是持续轮询并在内核体内执行工作,直到没有更多工作为止,然后退出内核。 plgpu.dynamic_scheduling_loop 辅助函数实现了这个精确的模式。

@plgpu.dynamic_scheduling_loop(
  grid_names=grid_names,
  thread_axis=thread_name  # Required if using multiple threads in a kernel.
)
def body(loop_info):
  grid_indices = loop_info.index
  # ... do work

使用此模式时,应使用等于待处理工作逻辑总量的网格来实例化内核(而不是将网格设置为核心数量的持久性内核)。运行此循环的每个核心将持续查询下一个可用的工作块,当整个网格已调度时,循环将终止。主体函数的签名与 plgpu.nd_loop(用于普通持久性内核)中使用的签名相同,并且接收一个 loop_info 数据类,该数据类包含迭代信息,并且可以选择支持携带值。

异步复制#

现代 GPU 可以直接异步地在 GMEM 和 SMEM 之间复制数据,而无需寄存器。从 Hopper 代开始,复制甚至可以卸载到称为 Tensor Memory Accelerator (TMA) 的特殊硬件单元,Mosaic 使用它来实现这些功能。

GMEM 到 SMEM 的复制#

要调度异步 GMEM 到 SMEM 的复制,请使用 plgpu.copy_gmem_to_smem。该函数接受三个操作数:源引用、目标引用和一个 Barrier。复制完成后,将在屏障上观察到一个到达,就好像后台线程调用了 plgpu.barrier_arrive(barrier) 一样。

def body(in_gmem_ref, out_gmem_ref, smem_ref, barrier):
  plgpu.copy_gmem_to_smem(in_gmem_ref, smem_ref, barrier)
  plgpu.barrier_wait(barrier)
  ...

plgpu.kernel(
  body,
  out_shape=...,
  scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier()],
)

一个屏障可以用于同步多个复制,但它必须以更高的 arrival_count 分配。

def body(in_gmem_ref, in_gmem_ref2, out_gmem_ref, smem_ref, smem_ref2, barrier):
  plgpu.copy_gmem_to_smem(in_gmem_ref, smem_ref, barrier)
  plgpu.copy_gmem_to_smem(in_gmem_ref2, smem_ref2, barrier)
  plgpu.barrier_wait(barrier)  # Awaits both copies
  ...

plgpu.kernel(
  body,
  out_shape=...,
  # Barrier is allocated with 2 arrivals.
  scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier(num_arrivals=2)],
)

集体复制#

当使用块集群时,异步传输具有 *组播* 选项,这意味着集群中的多个块可以集体加载相同的输入。在某种程度上,这可以被看作是对所有参与块的 L2 命中保证,因为它允许更好地共享有限的 HBM 带宽。

警告

使用集体复制时,沿指定集群轴的所有块都必须发出相同的集体复制,程序才有效。不允许仅从一个块发出,而不从其他块发出,这将导致未定义行为(很可能是死锁)。

警告

使用集体复制时,您需要格外小心重复使用 SMEM 缓冲区。集群中的不同块可能会在不同的时间完成使用它们,但第一个发出下一个集体复制的块可能会覆盖其他块仍在使用的数据。有关如何安全地执行此操作的示例,请参阅 ClusterBarrier 部分

def body(in_gmem_ref, in_gmem_ref2, out_gmem_ref, smem_ref, smem_ref2, barrier):
  block_id = lax.axis_index("cluster")
  # Both blocks in the cluster load the same data into smem_ref, so we can use
  # a collective copy here.
  plgpu.copy_gmem_to_smem(in_gmem_ref, smem_ref, barrier, collective_axes="cluster")
  # Each block in the cluster loads a different slice of in_gmem_ref2, so we
  # are not allowed to use collective copies.
  plgpu.copy_gmem_to_smem(in_gmem_ref2.at[block_id], smem_ref2, barrier)
  plgpu.barrier_wait(barrier)  # Awaits both copies
  ...

plgpu.kernel(
  body,
  out_shape=...,
  # Barrier is allocated with 2 arrivals.
  scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier(num_arrivals=2)],
)

集体分区复制(仅限 Blackwell)#

在 Blackwell 代中,涉及两个块集群的集体复制可以通过传递额外的 partitioned_axis 参数进行 *分区*。指定后,GMEM 引用在指定维度上的大小预计是目标 SMEM 引用的两倍。第一个块中的目标将被 GMEM 引用的第一部分覆盖,而第二个块将接收第二部分。

这本身等同于在不同的输入切片上执行两次非集体复制,但有一个关键区别:只有第一个块中的屏障将在两个复制完成后接收到达。第二个块中的屏障参数将被忽略,第二个块不能使用它来等待传输完成。

可以说,这是一个有点令人惊讶的功能,但在 Blackwell 上的集体 MMA 的背景下,它是合理的。在那里,每个块负责将操作数加载到 SMEM 中,但只有第一个块等待传输完成并发出 MMA 指令。第二个块通常等待 MMA 的完成来指示传输已完成,并且 SMEM 数据已被读出,这意味着它可以安全地覆盖它。

SMEM 到 GMEM 的复制#

要调度异步 GMEM 到 SMEM 的复制,请使用 plgpu.copy_smem_to_gmem。与另一个方向相反,此原语仅接受源和目标引用。要等待复制完成,请使用 plgpu.wait_smem_to_gmem

SMEM 到 GMEM 复制的同步方案有点出乎意料,因为它们不能以任意顺序等待。 plgpu.wait_smem_to_gmem 接受一个参数,即 **您不想等待** 的最近复制次数,或者等同于您仍希望运行的异步 SMEM 到 GMEM 复制的次数。

def copy_out(x_smem, y_smem, x_gmem, y_gmem):
  plgpu.copy_smem_to_gmem(x_smem, x_gmem)
  plgpu.copy_smem_to_gmem(y_smem, y_gmem)
  plgpu.wait_smem_to_gmem(1, wait_read_only=True)
  # At this point we know that the data of x_smem has been read, but we don't
  # yet know that x_gmem contains the updated data.
  plgpu.wait_smem_to_gmem(1)
  # At this point we know that the x_smem -> x_gmem copy is done, but we know
  # nothing about the y_smem -> y_gmem copy.
  plgpu.wait_smem_to_gmem(0)
  # At this point we know that both copies are complete.

请注意,SMEM 到 GMEM 的复制只能在发出它的同一线程中等待。如果尚未发出复制或所有复制已完成, wait_smem_to_gmem 将立即返回。

仅等待从 SMEM 读取#

另一个选择是,您可以等待复制被提交到 GMEM。您可以选择等待复制完全写入 GMEM(以一种对后续读取可见的方式),或者您可以通过在等待函数中指定 wait_read_only 来仅等待数据从 SMEM 读取。如果您还不打算将发送到 GMEM 的数据读回,这可以更快地重用 SMEM 缓冲区。

分组多个复制#

copy_smem_to_gmem 收到 commit_group=False 作为参数时,它不能被等待,直到显式调用 plgpu.commit_group,或者发出另一个没有该参数的 copy_smem_to_gmem。自上次提交以来的所有 SMEM 到 GMEM 复制被分组为一个可等待单元。

def copy_out(x_smem, y_smem, x_gmem, y_gmem):
  plgpu.copy_smem_to_gmem(x_smem, x_gmem, commit_group=False)
  plgpu.copy_smem_to_gmem(y_smem, y_gmem)  # Implicitly commits both copies
  plgpu.wait_smem_to_gmem(1)
  # At this point we only know that no SMEM to GMEM copies other than the two
  # above are active.
  plgpu.wait_smem_to_gmem(0)
  # Only now we know that both copies above have completed.

异步收集#

在 Blackwell GPU 上,TMA 引擎有一个额外的模式,允许高效地收集 2D 矩阵的第一个维度。使用此模式实际上非常简单。索引的一维数组应加载到 plgpu.Layout.TMA_GATHER_INDICES 布局中,并且源 GMEM 引用必须使用 `.at` 运算符与该数组索引。

@functools.partial(
    self.pallas_call,
    out_shape=jax.ShapeDtypeStruct(out_shape, dtype),
    out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM, transforms=transforms),
    in_specs=(
        pl.BlockSpec(memory_space=plgpu.GMEM),
        pl.BlockSpec(memory_space=plgpu.SMEM),
    ),
    scratch_shapes=[plgpu.Barrier()],
)
def kernel(x_ref_gmem, idx_ref, o_ref, barrier_ref):
  idxs = plgpu.load(idx_ref, (), layout=plgpu.Layout.TMA_GATHER_INDICES)
  plgpu.copy_gmem_to_smem(x_ref_gmem.at[idxs], o_ref, barrier_ref)
  plgpu.barrier_wait(barrier_ref)

plgpu.copy_gmem_to_smem 会自动识别引用已被数组切片,并将使用收集 TMA 指令来实现复制。

内联 Mosaic GPU#

待办事项

编译器参数#

待办事项

调试#

Mosaic GPU 公开了一些环境变量来诊断生成的低级代码的问题。

  • 当设置时, MOSAIC_GPU_DUMP_PTXAS 允许将 ptxas 的编译日志转储到标准输出;

  • 当设置时, MOSAIC_GPU_DUMP_PTX 允许将编译期间生成的 PTX 代码转储到标准输出;

  • MOSAIC_GPU_DUMP_MLIR_PASSES 允许在编译管道中的每个 MLIR 传递后将 IR 转储到标准输出;

  • 当设置时, MOSAIC_GPU_DUMP_SASS 允许将编译结束时产生的 SASS 代码转储到标准输出;

  • MOSAIC_GPU_DUMP_SASS_CTRL 允许将遵循 NervanaSystems/maxas 的 SASS 控制代码转储到标准输出;

  • MOSAIC_GPU_DUMP_TO 允许指定一个必须存在的目录路径,其中所有上述内容都将作为文件转储。

  • MOSAIC_GPU_LLVM_DEBUG_ONLY 允许指定一个逗号分隔的 LLVM 调试类型 列表,以生成相关的 LLVM 调试日志。此环境变量仅在调试构建(即没有 NDEBUG 的构建)中可用。

  • MOSAIC_GPU_DUMP_LLVM 允许在设置时转储 LLVM IR。它等同于设置 MOSAIC_GPU_LLVM_DEBUG_ONLY=serialize-to-llvm,并且两个环境变量可以组合使用。与 MOSAIC_GPU_LLVM_DEBUG_ONLY 一样,此环境变量仅在调试构建中可用。

从 PyTorch 调用内核#

plgpu.as_torch_kernel 装饰器包装了一个 Pallas:MGPU 内核,允许使用 PyTorch 张量调用它。它接受 CUDA 张量作为输入,并返回在同一设备上新分配的 CUDA 张量。

示例

import functools
import jax
import jax.numpy as jnp
import torch

@functools.partial(
    pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32)
)
def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

x = torch.arange(128, dtype=torch.int32, device="cuda")
y = x * x
out = plgpu.as_torch_kernel(add_kernel)(x, y)

plgpu.as_torch_kernel 仅支持包含单个内核调用的函数(例如,通过 pl.pallas_callplgpu.kernel),并且不支持调用其他 JAX 操作,例如来自 jax.numpy 的操作。