使用 Pallas 编写 Mosaic GPU 内核#

此页面是 Pallas:MGPU 后端最重要功能的参考。它不是教程,因此我们不期望每个人都从头到尾阅读它。尽管如此,还是值得浏览一下,以便熟悉您可以在其他教程中找到的一些模式。

在以下示例中,我们假设以下导入在作用域内

import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu

什么是 GPU?#

从技术上讲,NVIDIA GPU 架构如下所示:GPU 被划分为流式多处理器 (SM)。这在 CUDA 编程模型中的体现是,每个 CUDA 线程块 (或 CTA) 都被调度到恰好一个 SM 上,但多个块可以同时调度到一个 SM 上。

每个 SM 包含一块称为共享内存 (SMEM) 的快速内存和 4 个细分,每个细分包含一个warp 调度器和计算单元(ALU、TensorCore、...)。这也反映在 CUDA 程序中:每个 warp(块中连续 32 个 CUDA 线程的组)以循环方式分配给这些细分之一。与块类似,每个 warp 被分配给恰好一个细分(它永远不会迁移),但多个 warp 可以被分配到同一个 SM 细分。在每个时钟周期,来自每个细分的 warp 调度器尝试选择其驻留 warp 之一来执行下一条指令。

A diagram of one NVIDIA SM

更进一步,最近的 CUDA 版本还概述了 warpgroup 的概念,它是 4 个连续的 warp。了解硬件的结构后,我们可以看到它的来源:4 个连续的 warp 占据了 SM 的 4 个象限,并让我们发出利用整个 SM 的指令。

GPU 可以从许多不同的角度来看待,在这里我们想专注于一个略微简化的模型,该模型以 TensorCore 为中心。这应该有助于您驾驭编写涉及 TensorCore 的内核的复杂性,但请记住,真实情况更加复杂。

对于我们的目的而言,TensorCore 操作已经变得如此之大,以至于继续遵循 CUDA 模型不再有太大意义。因此,对我们来说,GPU 是单线程内核(SM)的集合,其中 Pallas:MGPU 的一个线程对应于一个 CUDA warpgroup。在这个模型中,您在内核中执行的每个操作都占用整个 CUDA warpgroup,并且其组成的 warp 始终以锁定步骤运行(考虑到硬件调度的抖动),并且永远不会通过控制流采取不同的路径(除了我们稍后将讨论的 core_map 之外的极小例外)。这里一个值得注意的补充是,我们仍然允许您在同一个 SM 上共同调度多个 Pallas 级别的线程,以便它们可以通过共享内存进行协作和通信(我们通过将它们放在同一个 CUDA 块中来实现)。

从现在开始,每当我们说“线程”时,我们指的是 Pallas 线程,而不是 CUDA 线程/lane。

这与 Triton 推广的编程模型非常相似,但正如您将看到的,存在一些差异。Mosaic GPU 往往更底层,这通常意味着您将不得不投入更多工作,但它也让您更多地掌控。在我们看来,这两种方法都有其优点,我们鼓励您选择最适合您需求的后端!Pallas 支持并将继续支持 Triton 作为备用 GPU 后端。

顺序执行 & 使用多个硬件单元#

与更复杂的 CPU 架构不同,GPU 仅支持顺序执行。然而,这并不意味着在任何给定时间只有一条指令在运行!每个 SM 象限都有多个独立的 functional units:TensorCore、算术逻辑单元 (ALU)、加载/存储 (LSU)、特殊功能单元 (SFU)。如果第一条指令针对其中一个单元,并且后面跟着另一条指令(不使用第一条指令的结果),那么 warp 调度器可以在第一条指令完成之前发出第二条指令。这通常被称为指令级并行性 (ILP),并且是现代 TensorCore 内核中的一个常见主题:TensorCore 操作非常大,并且需要很多周期才能完成,因此如果不在此时尝试使用其他单元,那将是一种浪费。

为了进一步扩展这一点,我们可以通过允许多个 Pallas 线程并发运行来利用这种硬件单元级别的并行性。如果一个线程主要占用 ALU,而另一个线程主要发出 TensorCore 相关指令,我们可以利用内置于 warp 调度器中的高效上下文切换来保持两个单元都处于忙碌状态。这是 FlashAttention 3CUTLASS ping-pong matmul kernels 等算法背后的核心思想之一。

有关 warp 调度和指令发出如何工作的更多信息,我们建议阅读 Analyzing Modern NVIDIA GPU cores

内存空间#

GPU 具有几个不同的内存空间,这些空间可以根据容量大小和速度(总带宽和单次访问延迟)进行完全排序。

A diagram of memory spaces of an NVIDIA GPU

最大的内存空间是 plgpu.GMEM,代表全局内存。在最近的数据中心级 GPU 中,此内存空间通常以数十甚至数百 GB 为单位衡量,但它也是最慢的。

下一个内存空间,用于 L2 缓存,在某种程度上也是全局的,因为它由整个 GPU 共享,但其使用只能通过缓存提示间接影响。因此,没有办法手动将值放置在那里,因此 Pallas:MGPU 中未公开此内存空间。虽然只有大约 100MB 的大小,但此内存的带宽比 GMEM 高得多,因此在编写高性能内核时,仍然经常建议利用它。

接下来是共享内存,或 plgpu.SMEM。此内存直接位于每个 SM 内部,因此它是分区的。除非使用块集群(请参阅下面的集群部分),否则每个块仅允许访问其自己的 SMEM 分配。

最后,最低级别的内存空间是寄存器内存。这是 Pallas 内核中每个单个值(即 JAX 数组)将要驻留的位置。如果编译器用完寄存器来存储这些数组,它将插入spills,这意味着它将定期存储和重新加载值到内存。这些 spills 通常会引入其他显著的性能下降,因此我们建议避免它们。关于 spills 的警告消息可以在内核编译期间的 ptxas 消息中清楚地看到。要使它们可见,请在您的环境中运行 MOSAIC_GPU_DUMP_PTXAS=1

Blackwell GPU 世代有一个额外的内存空间,称为tensor memoryplgpu.TMEM。TMEM 与寄存器内存非常相似,只是它由您显式分配和管理。它用于存储 MMA 累加器、操作数元数据(用于稀疏性或缩放),以及可选的左 MMA 操作数。有关 TMEM 的更多信息,请参阅 Blackwell MMA 部分。

在特定内存空间中请求/分配内存#

内核输入或输出默认放置在 SMEM 中。如果您想将它们作为 GMEM 引用访问,请将 memory_space=plgpu.GMEM 添加到它们的 BlockSpec 中。如果您希望使用 GMEM 中的整个输入或输出数组来调用内核,则指定 BlockSpec(memory_space=plgpu.GMEM) 就足够了。

SMEMTMEM 可以在 pl.pallas_callscratch_shapes 参数中显式分配,或使用 pl.run_scoped 分配。要分配引用,只需使用请求的形状和 dtype 调用内存空间对象。例如:plgpu.SMEM((128, 128), jnp.float16) 将在共享内存中分配一个 128x128 的 float16 元素数组。

利用 L2 缓存#

虽然 L2 缓存无法手动管理,但与全局内存相比,其明显更高的带宽使其值得考虑。利用它的最简单方法是重新排序并行网格维度,以便在相似时间段内调度的调用也访问相同的输入数据。

虽然 CUDA 编程模型不保证关于块分配给 SM 的顺序的任何内容,但在最近的世代中,启发式方法似乎只是以列优先顺序迭代 (x, y, z) CUDA 网格(即 x 是变化最快的维度,而 z 是变化最慢的维度)。类似地,Pallas:MGPU 不保证用户指定的网格如何映射到 CUDA 网格(Pallas 支持任意秩的网格,而不仅仅是最多 3D)。但是,您可以假设迭代将以行优先顺序进行。也就是说,如果网格具有维度 (a, b),则 b 将是变化最快的维度,而 a 将是较慢的维度。

为了给出一个实际的例子,考虑一个普通的矩阵乘法内核。在那里,通常使用两个并行网格维度 (m, n),对应于平铺两个非收缩维度。如果我们使用这个简单的方案,在 Pallas:MGPU 中,所有程序 id 为 (0, (1, m=0 的程序必须读取所有的 B 操作数!如果 nk 维度非常大,那么我们就不可能从 (0, (1, Your browser does not support SVGs or scripting is disabled. This would be an image showing the access pattern of first 16 blocks without grid tiling.

但是,如果我们简单地将网格重新排列为 (m // mt, n, mt)(然后在内核中用 pl.program_id(0) * mt + pl.program_id(2) 替换 pl.program_id(0)),很容易看出沿两个维度的程序带将被并发调度(而不是调度单行)。这大大增加了并发程序的数量,这些程序加载相似的数据切片,通常显着提高了 L2 利用率,从而提高了内核的整体性能(如果它是内存受限的)。继续我们有 16 个块的示例并使用 mt=4,我们得到以下访问模式

Your browser does not support SVGs or scripting is disabled. This would be an image showing the access pattern of first 16 blocks with grid tiling.

请注意,即使活动块的数量没有改变,它们访问的数据的总 footprint 也减少了一半!我们现在获得 L2 命中的机会更高。

数组布局和内存引用转换#

在 Pallas 中,您使用的数据结构(数组和引用)具有逻辑形状(例如,128x128 矩阵)。此逻辑形状必须映射到物理表示(数据在 GPU 内存中的实际表示方式)。具体的映射取决于数据驻留的位置

  1. 数组布局: 数组存储在寄存器内存中,我们将此映射称为布局。布局定义了数组的元素如何在构成 Pallas 线程的 CUDA lane 可用的寄存器之间分布。

  2. 内存引用转换: 对于指向 SMEM 的可变引用,此映射称为转换。转换描述了逻辑数据结构如何在该内存块中排列。

这些概念对于性能至关重要,尤其是在与 TensorCore 等专用硬件单元交互或优化内存访问模式时。

我们正在开发一种模式,该模式将完全自动地处理分配布局和转换(尽管有提供提示和更多控制的方法)。下面列出的 API 可能会继续起作用,但将变为可选。

内存引用转换#

转换在首次分配内存引用时应用。对这些引用进行操作的 Pallas 原语将自动考虑与其关联的转换。

def body(..., scratch_ref):
  # Asynchronous copy will reformat the GMEM data to match the SMEM transforms
  plgpu.copy_gmem_to_smem(..., scratch_ref, barrier)
  barrier.wait()
  plgpu.wgmma(..., scratch_ref)  # wgmma only accepts properly transformed refs
  ...

引用分配有两种方式,每种方式都有一种选择所需转换的方法

1. 使用 GPUBlockSpec

transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128))
f = pl.pallas_call(
  in_specs=plgpu.GPUBlockSpec(in_block_shape, in_index_map, transforms=transforms),
  out_specs=plgpu.GPUBlockSpec(out_block_shape, out_index_map, transforms=transforms),
  ...
)

2. 在分配的 SMEM 上指定 transforms 参数

transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128))
f = pl.pallas_call(
  scratch_shapes=plgpu.SMEM((128, 128), jnp.float16, transforms=transforms),
  ...
)

可用的转换是

  • plgpu.TileTransform(tile_shape),它将数据组织成形状为 tile_shape 的连续的、非重叠的 tile。一个 tile 的数据始终完全线性化(行优先),然后再开始另一个 tile(tile 也以行优先顺序遍历)。例如,将 TileTransform((8, 64)) 应用于 (128, 128) 引用意味着对应于逻辑切片 [0:8, 0:64] 的数据将首先存储(行优先),然后是 [0:8, 64:128], [8:16, 0:64], [8:16, 64:128],依此类推。实现此目的的另一种方法是获取输入数组 x 并以行优先顺序遍历 x.reshape(128 // 8, 128 // 64, 8, 64).transpose(0, 2, 1, 3)

  • plgpu.SwizzleTransform(swizzle_in_bytes),它按照 PTX 文档CUDA 文档 中描述的方式转换数据。Swizzling 非常有用,因为它允许在寄存器和共享内存之间传输 MMA 相关布局中的数据,而不会发生 bank 冲突。在 swizzling 之后内存看起来像什么的确切细节并不那么重要,因为所有原语都会自动考虑它。请注意,swizzle 量以字节为单位指定(仅支持 128、64、32 和 16),并且通常伴随 TileTransform(在其形状中使用元素!)。

  • plgpu.TransposeTransform(permutation),它在数组线性化之前排列数组的维度。这主要有用之处在于,它允许您在 GMEM-SMEM 复制期间更改布局(请记住,硬件不支持更改最次要/最后一个维度)。

数组布局#

到目前为止,我们为您定义了一些有用的布局

  • plgpu.Layout.WGMMA,这是 Hopper 世代 TensorCore 期望 MMA 累加器或 16 位输入操作数在寄存器中拥有的布局。

  • plgpu.Layout.WGMMA_ROW,这是在沿行减少后获得的上述布局。重新广播行是免费的,并将生成具有 WGMMA 布局的值。

  • plgpu.Layout.WGMMA_COL,它是上述布局的类似物,只是沿列而不是行减少。

  • plgpu.Layout.WG_STRIDED,其中值在组成 Pallas 线程的 128 个 CUDA lane 之间平均分区。连续的元素(在向量化之后)以循环方式分配给 lane。当不需要与 TensorCore 交互时,非常简单有效。

  • plgpu.Layout.WG_SPLAT,指示该值是常数。每个 CUDA lane 将保存一个包含该值的寄存器。您通常永远不必与此布局交互,因为它在创建常量值时隐式使用,并且始终可以隐式转换为其他布局。

目前,在默认操作模式下,数组布局传播仅在向前方向发生,并且对协调布局冲突的隐式支持很少:只有 splat 布局可以隐式转换为任何其他布局。例如,如果您尝试添加两个具有不同布局的数组,则降低将抱怨并失败。有一些非常有限的工具可以让您在布局之间进行转换,我们通常建议将值存储到 SMEM 并以目标布局将其读回。

MMA (TensorCore)#

在本节中,我们重点介绍 Pallas:MGPU 内核如何利用 TensorCore 单元。TensorCore 的编程接口在不同的 NVIDIA GPU 世代之间发生显着变化,这就是为什么最低级别的接口在 Pallas:MGPU 中也不同。

每个 MMA 操作都与三个操作数相关联

  • 形状为 (M, N) 的累加器 D

  • 形状为 (M, K) 的左输入 A

  • 形状为 (K, N) 的右输入 B。所有操作数必须具有相同的元素类型。

每次使用 MMA 都涉及几个步骤

  1. 为累加器分配空间(MMA 隐式执行 D += A @ B

  2. 准备 AB 操作数

  3. 发出操作

  4. 等待操作完成

  5. 读出结果

步骤 2.-4. 通常在一个循环中对收缩维度 (K) 执行。

AB 操作数的内存空间#

AB 操作数通常最好通过 SMEM 传入,在那里可以使用 plgpu.copy_gmem_to_smem 方便地加载它们。为了使这些操作数与 MMA 操作兼容,它们需要在分配时指定适当的平铺和交织变换。对于所有当前支持的代系,TensorCore 要求数据以行优先的 2D 瓦片形式布局,形状为 (8, swizzle_elems),其中 swizzle_elems 通过将交织大小除以元素类型字节宽度得出。目前支持的交织大小为:128、64 和 32。较大的交织大小更可取,因为它们可以提高 GMEM 到 SMEM 复制的性能。

def mma_transforms(shape_dtype: jax.ShapeDtypeStruct):
  assert len(shape_dtype.shape) == 2
  if shape_dtype.shape[0] % 8:
    raise ValueError("Number of rows must be divisible by 8")
  for swizzle_bytes in (128, 64, 32):
    swizzle_elems = swizzle_bytes // shape_dtype.dtype.itemsize
    if shape_dtype.shape[-1] % swizzle_elems == 0:
      return (plgpu.TilingTransform((8, swizzle_elems)),
              plgpu.SwizzleTransform(swizzle_bytes))
  raise ValueError("Failed to find transforms for the specified window type")

如果操作数需要转换,则 A 操作数可以通过不同的内存空间传入(取决于架构,请参见下文)。B 操作数必须位于 SMEM 中。

转置操作数#

当对 16 位操作数执行 MMA 时,TensorCore 可以自动转置输入数据。例如,允许 A 引用具有形状 (K, M),但必须在将其传递到 mma 函数之前对其进行转置。例如

assert acc_ref.shape == (M, N) and a_ref.shape == (K, M) and b_ref.shape == (K, N)
a_ref_t = plgpu.transpose_ref(a_ref, (1, 0))
assert a_ref_t.shape == (M, K)  # The shape expected by plgpu.wgmma
plgpu.wgmma(acc, a_ref_t, b_ref)

在这种情况下,B 引用也允许进行类似的操作。

Hopper (wgmma)#

在本节中,我们将介绍使用 Hopper 代 TensorCore 的基础知识,该 TensorCore 在 PTX 中作为 wgmma.mma_async 指令 公开。

分配累加器#

在 Hopper 硬件架构中,累加器在寄存器中分配,但在 Pallas 中,它被建模为可变引用,因为每个 MMA 操作都进行原地累加。有两种方法可以分配累加器。

要创建零初始化的累加器,您可以使用带有 plgpu.ACC((m, n), dtype) 类型的 pl.run_scoped

def compute(acc_ref):
  ...
  return acc_ref[...]
output = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32))

取消引用累加器引用,如在 compute 函数末尾所见,将隐式地等待所有未完成的 WGMMA 操作。

如果您想使用现有数组对其进行初始化,则可以使用带有 plgpu.ACC.init(init_array)pl.run_state

def compute(acc_ref):
  ...
  return # pl.run_state only returns the final value of the accumulator
output = pl.run_state(compute)(plgpu.ACC.init(init_array))

如果 pl.run_state 具有累加器操作数,则它会在返回最终值之前隐式地等待所有未完成的 WGMMA 操作。

准备 AB 操作数#

如上所述,我们建议通过共享内存传入 AB。在这种情况下,必须指定正确的平铺和交织变换。

plgpu.wgmma 还允许通过寄存器传入 A(即,不是 SMEM 引用,而是作为常规 JAX 数组)。然而,这种模式存在许多显著的缺点,并且很难确保足够的同步来使其安全。

TODO:解释在这种情况下可以接受的条件。

发出操作#

支持的 MMA 形状是这样的:

  • M 可被 64 整除

  • N 可被 8 整除且小于 256

  • Kswizzle 除以元素类型字节宽度的倍数

目前支持的数据类型有:jnp.float32jnp.bfloat16jnp.float16。累加器 D 必须是 jnp.float32,但输入为 jnp.float16 的情况除外,在这种情况下,它也允许为 jnp.float16

等待操作完成#

每个 plgpu.wgmma 调用都与之前的所有 plgpu.wgmma 调用隐式同步,这样一旦控制从它返回,我们保证除了最后一个发出的 WGMMA 之外,没有其他 WGMMA 仍在运行。因此,先前发出的 WGMMA 指令读取的任何 SMEM 区域都可以重复使用。这对于将 WGMMA 与异步内存复制流水线化尤其重要

buffers = 3  # In reality you might want even more
assert a_smem.shape == (buffers, m, k)
assert b_smem.shape == (buffers, k, n)
assert acc_ref.shape == (m, n)

def fetch_a_b(ki, slot):
  a_slice = ... # Replace with the right M/K slice
  b_slice = ... # Replace with the right K/N slice
  plgpu.copy_gmem_to_smem(a_gmem.at[a_slice], a_smem.at[slot], a_loaded.at[slot])
  plgpu.copy_gmem_to_smem(b_gmem.at[b_slice], b_smem.at[slot], b_loaded.at[slot])

def loop_body(i, _):
  slot = jax.lax.rem(i, buffers)
  plgpu.barrier_wait(a_loaded.at[slot])
  plgpu.barrier_wait(b_loaded.at[slot])
  plgpu.wgmma(acc_ref, a_smem.at[slot], b_smem.at[slot])
  # We know that only the last issued WGMMA is running, so we can issue a async load in
  # into the other buffer
  load_i = i + buffers - 1
  load_slot = jax.lax.rem(load_i, buffers)
  @pl.when(jnp.logical_and(load_i >= buffers, load_i < num_steps))
  def _do_fetch():
    fetch_a_b(load_i, slot)
for slot in range(buffers):
  fetch_a_b(slot, slot)
jax.lax.fori_loop(0, num_steps, loop_body, None)

Blackwell (tcgen05)#

虽然 Mosaic GPU 支持 tcgen05 MMA 指令,但将此功能公开给 Pallas 仍在进行中。敬请期待!

使用 core_map#

TODO

同步结构和原语#

在本节中,我们将介绍用于线程之间同步的最重要函数和数据结构,以及一些异步操作。

commit_smem#

保证对引用的常规读取/写入操作产生与顺序程序顺序一致的值。例如,在以下程序中,保证 value 等于 value2

ref[...] = value
value2 = ref[...]

然而,这种保证不适用于异步原语,例如异步复制或 MMA 操作。为了使 SMEM 写入对这些原语可见,您需要使用 plgpu.commit_smem() 函数显式地与它们同步。

例如

smem_ref[...] = value
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(smem_ref, ...)

smem_ref[...] = value
plgpu.commit_smem()
plgpu.wgmma(smem_ref, ...)

未能调用此函数很可能导致细微的数据竞争,因为这些异步硬件单元从 SMEM 读取过时的数据。不幸的是,此函数相对昂贵,这就是为什么我们依靠您(用户)将其插入到必要的最少数量的位置。

Barrier#

这本质上是 PTX mbarrier 类型数组的薄封装,并作为引用传入。所有涉及屏障的函数都期望只获得单个屏障参数,因此如果引用包含多个,则必须使用 barriers.at[index] 显式提取其中一个。Barrier 始终在 SMEM 中分配,因此开销相对较低。每个屏障都可以配置为在固定数量的“到达”(默认值为 1)后完成。

要阻止线程直到屏障完成,请使用以下函数

plgpu.barrier_wait(barrier)

有三个操作可以完成屏障

至关重要的是要确保同步方案使得在调用 plgpu.barrier_wait 之前不可能发生两次屏障完成。例如,如果您使用 Barrier 来同步两个生产者/消费者线程,则需要执行双向屏障同步,以引入“背压”,这将阻止一个线程在另一个线程有机会等待之前到达两次。未能满足此要求将损坏数据结构,并可能导致令人惊讶的故障(包括 CUDA 运行时错误)。请参阅下文,查看包含两个线程的有效程序示例。

另一个关键限制是,在屏障的生命周期内,屏障完成的数量必须等于屏障等待的数量。不允许在屏障具有未等待完成时结束屏障的作用域分配。否则,当编译器重用它时,将其保持在这种状态可能会导致下游问题。

最后,至关重要的是要确保每个曾经等待 Barrier 的线程都参与其上的所有 wait 操作。例如,不允许从一个线程等待屏障的每隔一次完成,而从另一个线程等待所有其他完成。这样做会导致死锁。概括地说:当 Barrier 用于在某个线程中等待时,它必须观察该屏障的每次完成(通过等待它)。

请注意,Barrier 可以接收来自任何来源的到达,没有限制。

异步 GMEM 到 SMEM 复制#

当 TMA 引擎执行异步 GMEM 到 SMEM 复制时,它会将进度更新发布到提供给 plgpu.copy_gmem_to_smem 的屏障。复制完成后,屏障也将完成一次到达。

显式到达(跨线程同步)#

任何线程都可以使用以下函数显式地到达屏障

plgpu.barrier_arrive(barrier)

当同步处于生产者/消费者角色中的两个线程时,这尤其有用。在这种情况下,我们建议分配两个 Barrier 数组,其大小等于用于在两个线程之间传递数据的“队列”的大小。例如,假设一个线程继续将数组的瓦片写入 SMEM,而另一个线程读取它们。我们对 SMEM 区域进行三缓冲,以允许两个线程之间有更多的异步性

tid = jax.lax.axis_index("thread")
assert queue.shape == (buffering, *item_shape)
assert produced.shape == consumed.shape == (buffering,)

def thread0_body(i, _):
  slot = jax.lax.rem(i, buffering)
  @pl.when(i >= buffering)
  def _await_consumed():
    plgpu.barrier_wait(consumed.at[slot])  # Wait for consumption of the value before overwriting it
  # Option 1: Compute the next value
  queue[slot] = produce()
  plgpu.barrier_arrive(produced.at[slot])  # Signal the value is ready
  # Option 2: Produce the value through async_copy
  # plgpu.copy_gmem_to_smem(..., queue.at[slot], barrier=produced.at[slot])
pl.when(tid == 0)(lambda: jax.lax.fori_loop(0, steps, thread0_body, None))

def thread1_body(i, _):
  slot = jax.lax.rem(i, buffering)
  plgpu.barrier_wait(produced.at[slot])  # Wait for the value to be ready
  consume(queue[slot])  # Load and compute
  plgpu.barrier_arrive(consumed.at[slot])  # Signal that the value is consumed
pl.when(tid == 1)(lambda: jax.lax.fori_loop(0, steps, thread1_body, None))

等待 tcgen05 TensorCore 指令#

虽然 Mosaic GPU 支持 tcgen05 MMA 指令,但将此功能公开给 Pallas 仍在进行中。敬请期待!

ClusterBarrier#

TODO

Semaphore#

TODO

异步复制#

TODO

内联 Mosaic GPU#

TODO

编译器参数#

TODO