使用 Pallas 编写 Mosaic GPU 内核#
此页面是 Pallas:MGPU 后端最重要功能的参考。它不是教程,因此我们不期望每个人都从头到尾阅读它。尽管如此,还是值得浏览一下,以便熟悉您可以在其他教程中找到的一些模式。
在以下示例中,我们假设以下导入在作用域内
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
什么是 GPU?#
从技术上讲,NVIDIA GPU 架构如下所示:GPU 被划分为流式多处理器 (SM)。这在 CUDA 编程模型中的体现是,每个 CUDA 线程块 (或 CTA) 都被调度到恰好一个 SM 上,但多个块可以同时调度到一个 SM 上。
每个 SM 包含一块称为共享内存 (SMEM) 的快速内存和 4 个细分,每个细分包含一个warp 调度器和计算单元(ALU、TensorCore、...)。这也反映在 CUDA 程序中:每个 warp(块中连续 32 个 CUDA 线程的组)以循环方式分配给这些细分之一。与块类似,每个 warp 被分配给恰好一个细分(它永远不会迁移),但多个 warp 可以被分配到同一个 SM 细分。在每个时钟周期,来自每个细分的 warp 调度器尝试选择其驻留 warp 之一来执行下一条指令。
更进一步,最近的 CUDA 版本还概述了 warpgroup 的概念,它是 4 个连续的 warp。了解硬件的结构后,我们可以看到它的来源:4 个连续的 warp 占据了 SM 的 4 个象限,并让我们发出利用整个 SM 的指令。
GPU 可以从许多不同的角度来看待,在这里我们想专注于一个略微简化的模型,该模型以 TensorCore 为中心。这应该有助于您驾驭编写涉及 TensorCore 的内核的复杂性,但请记住,真实情况更加复杂。
对于我们的目的而言,TensorCore 操作已经变得如此之大,以至于继续遵循 CUDA 模型不再有太大意义。因此,对我们来说,GPU 是单线程内核(SM)的集合,其中 Pallas:MGPU 的一个线程对应于一个 CUDA warpgroup。在这个模型中,您在内核中执行的每个操作都占用整个 CUDA warpgroup,并且其组成的 warp 始终以锁定步骤运行(考虑到硬件调度的抖动),并且永远不会通过控制流采取不同的路径(除了我们稍后将讨论的 core_map
之外的极小例外)。这里一个值得注意的补充是,我们仍然允许您在同一个 SM 上共同调度多个 Pallas 级别的线程,以便它们可以通过共享内存进行协作和通信(我们通过将它们放在同一个 CUDA 块中来实现)。
从现在开始,每当我们说“线程”时,我们指的是 Pallas 线程,而不是 CUDA 线程/lane。
这与 Triton 推广的编程模型非常相似,但正如您将看到的,存在一些差异。Mosaic GPU 往往更底层,这通常意味着您将不得不投入更多工作,但它也让您更多地掌控。在我们看来,这两种方法都有其优点,我们鼓励您选择最适合您需求的后端!Pallas 支持并将继续支持 Triton 作为备用 GPU 后端。
顺序执行 & 使用多个硬件单元#
与更复杂的 CPU 架构不同,GPU 仅支持顺序执行。然而,这并不意味着在任何给定时间只有一条指令在运行!每个 SM 象限都有多个独立的 functional units:TensorCore、算术逻辑单元 (ALU)、加载/存储 (LSU)、特殊功能单元 (SFU)。如果第一条指令针对其中一个单元,并且后面跟着另一条指令(不使用第一条指令的结果),那么 warp 调度器可以在第一条指令完成之前发出第二条指令。这通常被称为指令级并行性 (ILP),并且是现代 TensorCore 内核中的一个常见主题:TensorCore 操作非常大,并且需要很多周期才能完成,因此如果不在此时尝试使用其他单元,那将是一种浪费。
为了进一步扩展这一点,我们可以通过允许多个 Pallas 线程并发运行来利用这种硬件单元级别的并行性。如果一个线程主要占用 ALU,而另一个线程主要发出 TensorCore 相关指令,我们可以利用内置于 warp 调度器中的高效上下文切换来保持两个单元都处于忙碌状态。这是 FlashAttention 3 或 CUTLASS ping-pong matmul kernels 等算法背后的核心思想之一。
有关 warp 调度和指令发出如何工作的更多信息,我们建议阅读 Analyzing Modern NVIDIA GPU cores。
内存空间#
GPU 具有几个不同的内存空间,这些空间可以根据容量大小和速度(总带宽和单次访问延迟)进行完全排序。
最大的内存空间是 plgpu.GMEM
,代表全局内存。在最近的数据中心级 GPU 中,此内存空间通常以数十甚至数百 GB 为单位衡量,但它也是最慢的。
下一个内存空间,用于 L2 缓存,在某种程度上也是全局的,因为它由整个 GPU 共享,但其使用只能通过缓存提示间接影响。因此,没有办法手动将值放置在那里,因此 Pallas:MGPU 中未公开此内存空间。虽然只有大约 100MB 的大小,但此内存的带宽比 GMEM 高得多,因此在编写高性能内核时,仍然经常建议利用它。
接下来是共享内存,或 plgpu.SMEM
。此内存直接位于每个 SM 内部,因此它是分区的。除非使用块集群(请参阅下面的集群部分),否则每个块仅允许访问其自己的 SMEM 分配。
最后,最低级别的内存空间是寄存器内存。这是 Pallas 内核中每个单个值(即 JAX 数组)将要驻留的位置。如果编译器用完寄存器来存储这些数组,它将插入spills,这意味着它将定期存储和重新加载值到内存。这些 spills 通常会引入其他显著的性能下降,因此我们建议避免它们。关于 spills 的警告消息可以在内核编译期间的 ptxas
消息中清楚地看到。要使它们可见,请在您的环境中运行 MOSAIC_GPU_DUMP_PTXAS=1
。
Blackwell GPU 世代有一个额外的内存空间,称为tensor memory 或 plgpu.TMEM
。TMEM 与寄存器内存非常相似,只是它由您显式分配和管理。它用于存储 MMA 累加器、操作数元数据(用于稀疏性或缩放),以及可选的左 MMA 操作数。有关 TMEM 的更多信息,请参阅 Blackwell MMA 部分。
在特定内存空间中请求/分配内存#
内核输入或输出默认放置在 SMEM 中。如果您想将它们作为 GMEM 引用访问,请将 memory_space=plgpu.GMEM
添加到它们的 BlockSpec
中。如果您希望使用 GMEM 中的整个输入或输出数组来调用内核,则指定 BlockSpec(memory_space=plgpu.GMEM)
就足够了。
SMEM
和 TMEM
可以在 pl.pallas_call
的 scratch_shapes
参数中显式分配,或使用 pl.run_scoped
分配。要分配引用,只需使用请求的形状和 dtype 调用内存空间对象。例如:plgpu.SMEM((128, 128), jnp.float16)
将在共享内存中分配一个 128x128 的 float16 元素数组。
利用 L2 缓存#
虽然 L2 缓存无法手动管理,但与全局内存相比,其明显更高的带宽使其值得考虑。利用它的最简单方法是重新排序并行网格维度,以便在相似时间段内调度的调用也访问相同的输入数据。
虽然 CUDA 编程模型不保证关于块分配给 SM 的顺序的任何内容,但在最近的世代中,启发式方法似乎只是以列优先顺序迭代 (x, y, z)
CUDA 网格(即 x
是变化最快的维度,而 z
是变化最慢的维度)。类似地,Pallas:MGPU 不保证用户指定的网格如何映射到 CUDA 网格(Pallas 支持任意秩的网格,而不仅仅是最多 3D)。但是,您可以假设迭代将以行优先顺序进行。也就是说,如果网格具有维度 (a, b)
,则 b
将是变化最快的维度,而 a
将是较慢的维度。
为了给出一个实际的例子,考虑一个普通的矩阵乘法内核。在那里,通常使用两个并行网格维度 但是,如果我们简单地将网格重新排列为 请注意,即使活动块的数量没有改变,它们访问的数据的总 footprint 也减少了一半!我们现在获得 L2 命中的机会更高。(m, n)
,对应于平铺两个非收缩维度。如果我们使用这个简单的方案,在 Pallas:MGPU 中,所有程序 id 为 (0, (1, m=0
的程序必须读取所有的 B
操作数!如果 n
或 k
维度非常大,那么我们就不可能从 (0, (1,
Your browser does not support SVGs or scripting is disabled.
This would be an image showing the access pattern of first 16 blocks without grid tiling.
(m // mt, n, mt)
(然后在内核中用 pl.program_id(0) * mt + pl.program_id(2)
替换 pl.program_id(0)
),很容易看出沿两个维度的程序带将被并发调度(而不是调度单行)。这大大增加了并发程序的数量,这些程序加载相似的数据切片,通常显着提高了 L2 利用率,从而提高了内核的整体性能(如果它是内存受限的)。继续我们有 16 个块的示例并使用 mt=4
,我们得到以下访问模式
数组布局和内存引用转换#
在 Pallas 中,您使用的数据结构(数组和引用)具有逻辑形状(例如,128x128 矩阵)。此逻辑形状必须映射到物理表示(数据在 GPU 内存中的实际表示方式)。具体的映射取决于数据驻留的位置
数组布局: 数组存储在寄存器内存中,我们将此映射称为布局。布局定义了数组的元素如何在构成 Pallas 线程的 CUDA lane 可用的寄存器之间分布。
内存引用转换: 对于指向
SMEM
的可变引用,此映射称为转换。转换描述了逻辑数据结构如何在该内存块中排列。
这些概念对于性能至关重要,尤其是在与 TensorCore 等专用硬件单元交互或优化内存访问模式时。
我们正在开发一种模式,该模式将完全自动地处理分配布局和转换(尽管有提供提示和更多控制的方法)。下面列出的 API 可能会继续起作用,但将变为可选。
内存引用转换#
转换在首次分配内存引用时应用。对这些引用进行操作的 Pallas 原语将自动考虑与其关联的转换。
def body(..., scratch_ref):
# Asynchronous copy will reformat the GMEM data to match the SMEM transforms
plgpu.copy_gmem_to_smem(..., scratch_ref, barrier)
barrier.wait()
plgpu.wgmma(..., scratch_ref) # wgmma only accepts properly transformed refs
...
引用分配有两种方式,每种方式都有一种选择所需转换的方法
1. 使用 GPUBlockSpec
transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128))
f = pl.pallas_call(
in_specs=plgpu.GPUBlockSpec(in_block_shape, in_index_map, transforms=transforms),
out_specs=plgpu.GPUBlockSpec(out_block_shape, out_index_map, transforms=transforms),
...
)
2. 在分配的 SMEM
上指定 transforms
参数
transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128))
f = pl.pallas_call(
scratch_shapes=plgpu.SMEM((128, 128), jnp.float16, transforms=transforms),
...
)
可用的转换是
plgpu.TileTransform(tile_shape)
,它将数据组织成形状为tile_shape
的连续的、非重叠的 tile。一个 tile 的数据始终完全线性化(行优先),然后再开始另一个 tile(tile 也以行优先顺序遍历)。例如,将TileTransform((8, 64))
应用于(128, 128)
引用意味着对应于逻辑切片[0:8, 0:64]
的数据将首先存储(行优先),然后是[0:8, 64:128], [8:16, 0:64], [8:16, 64:128]
,依此类推。实现此目的的另一种方法是获取输入数组x
并以行优先顺序遍历x.reshape(128 // 8, 128 // 64, 8, 64).transpose(0, 2, 1, 3)
。plgpu.SwizzleTransform(swizzle_in_bytes)
,它按照 PTX 文档 和 CUDA 文档 中描述的方式转换数据。Swizzling 非常有用,因为它允许在寄存器和共享内存之间传输 MMA 相关布局中的数据,而不会发生 bank 冲突。在 swizzling 之后内存看起来像什么的确切细节并不那么重要,因为所有原语都会自动考虑它。请注意,swizzle 量以字节为单位指定(仅支持 128、64、32 和 16),并且通常伴随TileTransform
(在其形状中使用元素!)。plgpu.TransposeTransform(permutation)
,它在数组线性化之前排列数组的维度。这主要有用之处在于,它允许您在 GMEM-SMEM 复制期间更改布局(请记住,硬件不支持更改最次要/最后一个维度)。
数组布局#
到目前为止,我们为您定义了一些有用的布局
plgpu.Layout.WGMMA
,这是 Hopper 世代 TensorCore 期望 MMA 累加器或 16 位输入操作数在寄存器中拥有的布局。plgpu.Layout.WGMMA_ROW
,这是在沿行减少后获得的上述布局。重新广播行是免费的,并将生成具有WGMMA
布局的值。plgpu.Layout.WGMMA_COL
,它是上述布局的类似物,只是沿列而不是行减少。plgpu.Layout.WG_STRIDED
,其中值在组成 Pallas 线程的 128 个 CUDA lane 之间平均分区。连续的元素(在向量化之后)以循环方式分配给 lane。当不需要与 TensorCore 交互时,非常简单有效。plgpu.Layout.WG_SPLAT
,指示该值是常数。每个 CUDA lane 将保存一个包含该值的寄存器。您通常永远不必与此布局交互,因为它在创建常量值时隐式使用,并且始终可以隐式转换为其他布局。
目前,在默认操作模式下,数组布局传播仅在向前方向发生,并且对协调布局冲突的隐式支持很少:只有 splat 布局可以隐式转换为任何其他布局。例如,如果您尝试添加两个具有不同布局的数组,则降低将抱怨并失败。有一些非常有限的工具可以让您在布局之间进行转换,我们通常建议将值存储到 SMEM 并以目标布局将其读回。
MMA (TensorCore)#
在本节中,我们重点介绍 Pallas:MGPU 内核如何利用 TensorCore 单元。TensorCore 的编程接口在不同的 NVIDIA GPU 世代之间发生显着变化,这就是为什么最低级别的接口在 Pallas:MGPU 中也不同。
每个 MMA 操作都与三个操作数相关联
形状为
(M, N)
的累加器D
,形状为
(M, K)
的左输入A
,形状为
(K, N)
的右输入B
。所有操作数必须具有相同的元素类型。
每次使用 MMA 都涉及几个步骤
为累加器分配空间(MMA 隐式执行
D += A @ B
)准备
A
和B
操作数发出操作
等待操作完成
读出结果
步骤 2.-4. 通常在一个循环中对收缩维度 (K
) 执行。
A
和 B
操作数的内存空间#
A
和 B
操作数通常最好通过 SMEM 传入,在那里可以使用 plgpu.copy_gmem_to_smem
方便地加载它们。为了使这些操作数与 MMA 操作兼容,它们需要在分配时指定适当的平铺和交织变换。对于所有当前支持的代系,TensorCore 要求数据以行优先的 2D 瓦片形式布局,形状为 (8, swizzle_elems)
,其中 swizzle_elems
通过将交织大小除以元素类型字节宽度得出。目前支持的交织大小为:128、64 和 32。较大的交织大小更可取,因为它们可以提高 GMEM 到 SMEM 复制的性能。
def mma_transforms(shape_dtype: jax.ShapeDtypeStruct):
assert len(shape_dtype.shape) == 2
if shape_dtype.shape[0] % 8:
raise ValueError("Number of rows must be divisible by 8")
for swizzle_bytes in (128, 64, 32):
swizzle_elems = swizzle_bytes // shape_dtype.dtype.itemsize
if shape_dtype.shape[-1] % swizzle_elems == 0:
return (plgpu.TilingTransform((8, swizzle_elems)),
plgpu.SwizzleTransform(swizzle_bytes))
raise ValueError("Failed to find transforms for the specified window type")
如果操作数需要转换,则 A
操作数可以通过不同的内存空间传入(取决于架构,请参见下文)。B
操作数必须位于 SMEM 中。
转置操作数#
当对 16 位操作数执行 MMA 时,TensorCore 可以自动转置输入数据。例如,允许 A
引用具有形状 (K, M)
,但必须在将其传递到 mma 函数之前对其进行转置。例如
assert acc_ref.shape == (M, N) and a_ref.shape == (K, M) and b_ref.shape == (K, N)
a_ref_t = plgpu.transpose_ref(a_ref, (1, 0))
assert a_ref_t.shape == (M, K) # The shape expected by plgpu.wgmma
plgpu.wgmma(acc, a_ref_t, b_ref)
在这种情况下,B
引用也允许进行类似的操作。
Hopper (wgmma
)#
在本节中,我们将介绍使用 Hopper 代 TensorCore 的基础知识,该 TensorCore 在 PTX 中作为 wgmma.mma_async
指令 公开。
分配累加器#
在 Hopper 硬件架构中,累加器在寄存器中分配,但在 Pallas 中,它被建模为可变引用,因为每个 MMA 操作都进行原地累加。有两种方法可以分配累加器。
要创建零初始化的累加器,您可以使用带有 plgpu.ACC((m, n), dtype)
类型的 pl.run_scoped
。
def compute(acc_ref):
...
return acc_ref[...]
output = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32))
取消引用累加器引用,如在 compute
函数末尾所见,将隐式地等待所有未完成的 WGMMA 操作。
如果您想使用现有数组对其进行初始化,则可以使用带有 plgpu.ACC.init(init_array)
的 pl.run_state
def compute(acc_ref):
...
return # pl.run_state only returns the final value of the accumulator
output = pl.run_state(compute)(plgpu.ACC.init(init_array))
如果 pl.run_state
具有累加器操作数,则它会在返回最终值之前隐式地等待所有未完成的 WGMMA 操作。
准备 A
和 B
操作数#
如上所述,我们建议通过共享内存传入 A
和 B
。在这种情况下,必须指定正确的平铺和交织变换。
plgpu.wgmma
还允许通过寄存器传入 A
(即,不是 SMEM 引用,而是作为常规 JAX 数组)。然而,这种模式存在许多显著的缺点,并且很难确保足够的同步来使其安全。
TODO:解释在这种情况下可以接受的条件。
发出操作#
支持的 MMA 形状是这样的:
M
可被 64 整除N
可被 8 整除且小于 256K
是swizzle
除以元素类型字节宽度的倍数
目前支持的数据类型有:jnp.float32
、jnp.bfloat16
和 jnp.float16
。累加器 D
必须是 jnp.float32
,但输入为 jnp.float16
的情况除外,在这种情况下,它也允许为 jnp.float16
。
等待操作完成#
每个 plgpu.wgmma
调用都与之前的所有 plgpu.wgmma
调用隐式同步,这样一旦控制从它返回,我们保证除了最后一个发出的 WGMMA 之外,没有其他 WGMMA 仍在运行。因此,先前发出的 WGMMA 指令读取的任何 SMEM 区域都可以重复使用。这对于将 WGMMA 与异步内存复制流水线化尤其重要
buffers = 3 # In reality you might want even more
assert a_smem.shape == (buffers, m, k)
assert b_smem.shape == (buffers, k, n)
assert acc_ref.shape == (m, n)
def fetch_a_b(ki, slot):
a_slice = ... # Replace with the right M/K slice
b_slice = ... # Replace with the right K/N slice
plgpu.copy_gmem_to_smem(a_gmem.at[a_slice], a_smem.at[slot], a_loaded.at[slot])
plgpu.copy_gmem_to_smem(b_gmem.at[b_slice], b_smem.at[slot], b_loaded.at[slot])
def loop_body(i, _):
slot = jax.lax.rem(i, buffers)
plgpu.barrier_wait(a_loaded.at[slot])
plgpu.barrier_wait(b_loaded.at[slot])
plgpu.wgmma(acc_ref, a_smem.at[slot], b_smem.at[slot])
# We know that only the last issued WGMMA is running, so we can issue a async load in
# into the other buffer
load_i = i + buffers - 1
load_slot = jax.lax.rem(load_i, buffers)
@pl.when(jnp.logical_and(load_i >= buffers, load_i < num_steps))
def _do_fetch():
fetch_a_b(load_i, slot)
for slot in range(buffers):
fetch_a_b(slot, slot)
jax.lax.fori_loop(0, num_steps, loop_body, None)
Blackwell (tcgen05
)#
虽然 Mosaic GPU 支持 tcgen05
MMA 指令,但将此功能公开给 Pallas 仍在进行中。敬请期待!
使用 core_map
#
TODO
同步结构和原语#
在本节中,我们将介绍用于线程之间同步的最重要函数和数据结构,以及一些异步操作。
commit_smem
#
保证对引用的常规读取/写入操作产生与顺序程序顺序一致的值。例如,在以下程序中,保证 value
等于 value2
。
ref[...] = value
value2 = ref[...]
然而,这种保证不适用于异步原语,例如异步复制或 MMA 操作。为了使 SMEM 写入对这些原语可见,您需要使用 plgpu.commit_smem()
函数显式地与它们同步。
例如
smem_ref[...] = value
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(smem_ref, ...)
或
smem_ref[...] = value
plgpu.commit_smem()
plgpu.wgmma(smem_ref, ...)
未能调用此函数很可能导致细微的数据竞争,因为这些异步硬件单元从 SMEM 读取过时的数据。不幸的是,此函数相对昂贵,这就是为什么我们依靠您(用户)将其插入到必要的最少数量的位置。
Barrier
#
这本质上是 PTX mbarrier
类型数组的薄封装,并作为引用传入。所有涉及屏障的函数都期望只获得单个屏障参数,因此如果引用包含多个,则必须使用 barriers.at[index]
显式提取其中一个。Barrier
始终在 SMEM 中分配,因此开销相对较低。每个屏障都可以配置为在固定数量的“到达”(默认值为 1)后完成。
要阻止线程直到屏障完成,请使用以下函数
plgpu.barrier_wait(barrier)
有三个操作可以完成屏障
至关重要的是要确保同步方案使得在调用
plgpu.barrier_wait
之前不可能发生两次屏障完成。例如,如果您使用Barrier
来同步两个生产者/消费者线程,则需要执行双向屏障同步,以引入“背压”,这将阻止一个线程在另一个线程有机会等待之前到达两次。未能满足此要求将损坏数据结构,并可能导致令人惊讶的故障(包括 CUDA 运行时错误)。请参阅下文,查看包含两个线程的有效程序示例。
另一个关键限制是,在屏障的生命周期内,屏障完成的数量必须等于屏障等待的数量。不允许在屏障具有未等待完成时结束屏障的作用域分配。否则,当编译器重用它时,将其保持在这种状态可能会导致下游问题。
最后,至关重要的是要确保每个曾经等待
Barrier
的线程都参与其上的所有wait
操作。例如,不允许从一个线程等待屏障的每隔一次完成,而从另一个线程等待所有其他完成。这样做会导致死锁。概括地说:当Barrier
用于在某个线程中等待时,它必须观察该屏障的每次完成(通过等待它)。
请注意,Barrier
可以接收来自任何来源的到达,没有限制。
异步 GMEM 到 SMEM 复制#
当 TMA 引擎执行异步 GMEM 到 SMEM 复制时,它会将进度更新发布到提供给 plgpu.copy_gmem_to_smem
的屏障。复制完成后,屏障也将完成一次到达。
显式到达(跨线程同步)#
任何线程都可以使用以下函数显式地到达屏障
plgpu.barrier_arrive(barrier)
当同步处于生产者/消费者角色中的两个线程时,这尤其有用。在这种情况下,我们建议分配两个 Barrier
数组,其大小等于用于在两个线程之间传递数据的“队列”的大小。例如,假设一个线程继续将数组的瓦片写入 SMEM,而另一个线程读取它们。我们对 SMEM 区域进行三缓冲,以允许两个线程之间有更多的异步性
tid = jax.lax.axis_index("thread")
assert queue.shape == (buffering, *item_shape)
assert produced.shape == consumed.shape == (buffering,)
def thread0_body(i, _):
slot = jax.lax.rem(i, buffering)
@pl.when(i >= buffering)
def _await_consumed():
plgpu.barrier_wait(consumed.at[slot]) # Wait for consumption of the value before overwriting it
# Option 1: Compute the next value
queue[slot] = produce()
plgpu.barrier_arrive(produced.at[slot]) # Signal the value is ready
# Option 2: Produce the value through async_copy
# plgpu.copy_gmem_to_smem(..., queue.at[slot], barrier=produced.at[slot])
pl.when(tid == 0)(lambda: jax.lax.fori_loop(0, steps, thread0_body, None))
def thread1_body(i, _):
slot = jax.lax.rem(i, buffering)
plgpu.barrier_wait(produced.at[slot]) # Wait for the value to be ready
consume(queue[slot]) # Load and compute
plgpu.barrier_arrive(consumed.at[slot]) # Signal that the value is consumed
pl.when(tid == 1)(lambda: jax.lax.fori_loop(0, steps, thread1_body, None))
等待 tcgen05
TensorCore 指令#
虽然 Mosaic GPU 支持 tcgen05
MMA 指令,但将此功能公开给 Pallas 仍在进行中。敬请期待!
ClusterBarrier
#
TODO
Semaphore
#
TODO
异步复制#
TODO
内联 Mosaic GPU#
TODO
编译器参数#
TODO