Pallas 设计#
本文档将解释 Pallas 的初始设计。这是早期设计决策的一个快照,Pallas 的具体 API 可能自那时起已经发生变化。
引言#
JAX 被用于各种各样的工作负载,从大规模机器学习到科学计算。JAX 的成功故事在很大程度上也是 XLA 的成功故事,XLA 是 JAX 主要的目标编译器 — XLA 编译 JAX 程序以在加速器上运行,并使 JAX 能够扩展到最大的 ML 模型。JAX 在 XLA 的表示 HLO 中描述逻辑计算。HLO 描述了计算在逻辑上是如何发生的,但没有描述物理上的发生方式。给定一个逻辑 HLO 计算,XLA 决定该计算的物理执行方式。对于各种 ML 应用,XLA 在编译用户程序方面做得很好,但不可避免地会遇到 XLA 的限制。在这些情况下,我们需要提供一个“逃生舱口”,允许专家编写手动调优的内核,以在当时超越 XLA 的性能。此外,ML 系统研究的进展需要一些时间才能被纳入 XLA,而用户通常希望能够利用这些进展。随着时间的推移,编译器可以整合通过手动调优内核在实验中被证明有效的优化。
XLA 确实提供了 CustomCall 机制作为逃生舱口,但它要求用户编写 C++,并且在 GPU 上需要用户学习 CUDA 编程模型。CUDA 编程模型对于许多机器学习 GPU 内核(如矩阵乘法)来说可能过于底层,即使是专家用户也很难使用 CUDA 来实现高效的矩阵乘法或多头注意力。不仅如此,JAX 用户通常熟悉 Python 和 NumPy 风格的数组编程,这不需要编写任何 C++ 或考虑 GPU 并行性。所有流行的机器学习框架都遵循这个理念:通过 matmul 或 convolution 等高级操作来操作(通常是)数组。不幸的是,这意味着通过 CustomCall 实现自定义操作是一项重大的投资,可能需要学习 C++ 和/或 GPU 编程。
Triton 是 OpenAI 构建和维护的一个 GPU 编译器,它在 ML 编译器领域掀起了一场风暴。Triton 提供了两全其美:一个用于 GPU 内核的基于数组的编程模型。Triton 是 PyTorch 2.0 中 torch.compile 的主要代码生成路径,通过 Torch Inductor 库实现。Triton 积极地隐藏了 GPU 编程的某些方面,以提供一个更易于访问的编程模型,该模型可以从 Python 中使用,并从更高级别的表示生成优化代码。虽然 GPU 比 Triton 提供的更灵活,但在 ML 领域,Triton 对于许多应用来说似乎足够表达。
本文档将介绍 Pallas,它是 JAX 的一个扩展,支持使用类似 Triton 的模型进行 GPU 和 TPU 的内核编程。基于 JAX 的内核语言提供了几个优势:
尽管 Triton 向用户暴露了类似 TPU 的编程模型,即为 L1 缓存中的数组块编写程序,但它足够专注于 GPU,我们无法直接为 TPU 编译 Triton。例如,Triton 提供了专门用于处理并行写入的原子操作,而这在 TPU 上不一定有意义。一个更高级的前端可以抽象出平台的细节,同时仅暴露基于块的编程模型。因此,内核将能够跨不同的硬件平台移植。
JAX 作为一种基于追踪的数值计算前端,既成熟又被广泛使用。通过将内核编程语言嵌入 JAX 本身,我们可以重用 JAX 的追踪基础设施,并提供用户已经熟悉的类似 NumPy 的前端。
JAX 的转换是其成功的关键,它允许用户表达简单的程序,然后将它们转换为实现复杂的功能。我们可以利用相同的转换(vmap、jvp 等)来转换用户编写的内核。
悬而未决的问题是:JAX 是否适合作为内核语言?我们认为是。Triton 表明基于数组的编程语言可以用于编写 GPU 内核,而 JAX 正是如此。JAX 也已被证明是编译器和程序转换的灵活前端。
我们将 Pallas 描述如下:首先,我们将描述我们扩展 JAX 以支持编写自定义内核的方法。然后,我们将展示如何将 Pallas 降低到 Triton 和 Mosaic。最后,我们将描述通过 JAX 转换来转换 Pallas 内核的现有和潜在方法。
Pallas 降低路径可视化
Pallas:扩展 JAX 以支持内核#
我们想强调的关键点是,Pallas 只是 JAX,带有一些扩展。
用户现在在他们的 JAX 代码中使用称为
Ref的引用类型。这为用户提供了更精确的内存访问控制,并且 JAX 中的内存布局将更接近物理布局。用户使用 JAX 原语的子集以及一套 Pallas 特定的原语来编写他们的 JAX 程序。
用户通过一个特殊的
pallas_call高阶函数将 Pallas 内核嵌入到外部 JAX 程序中,该函数在一个映射中执行内核。它类似于pmap或shard_map,但带有对共享内存的引用。
我们将一一通过示例来介绍这三个扩展。
请注意,这些 API 仍处于实验阶段,可能会发生变化。
引用类型#
让我们看一个用于相加两个向量的 Pallas 程序示例。
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
def add_kernel(x_ref, y_ref, o_ref):
# In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
x = x_ref[:]
y = y_ref[:]
o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)
与常规 JAX 程序不同,add_kernel 不接收不可变数组参数。相反,它接收引用,可以通过类似 NumPy 的语法就地读写。 Refs 并不是 Pallas 特有的概念 — 它们被引入 JAX 来表示有状态的计算。但是,我们可以在编写操作可变内存的内核时利用它们。
Pallas 内核不仅接收对应于内核输入的 Refs,还接收输出的 Refs(在 pallas_call 中通过 out_shape 指定)。Refs 是特殊类型,如果不先读取,就无法将其传递给常规 JAX 原语。当你从 Ref 读取时,你会得到一个 JAX Array 类型,并且你必须将一个 Array 写入 Ref。
读写 Ref#
从 Ref 读取相当于将数组加载到内存层次结构的最低级别(GPU 上的 L1 缓存和 TPU 上的向量寄存器)。写入 Refs 也是类似的。
def f(x_ref, o_ref):
# Using vanilla Python indexing
x = x_ref[0, 2:5, :]
# Or via Numpy advanced int indexing
o_ref[jnp.arange(3), :] = x
# Note that in order to use NumPy advanced int indexing, you need to broadcast the indices against each other into the desired multidimensional shape:
def f(x_ref):
# Assume x_ref is (8, 4) and we want to read out a (2, 3) slice
x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]]
可以通过类似的 __setitem__ 风格的索引来写入 Refs。
其他形式的索引(例如,动态切片)可以通过 pallas.load 和 pallas.store 来完成,这是新设计的 JAX 原语,用于简化从内存加载/存储到内存的操作。我们稍后将讨论这些新原语。
使用新的 Pallas 原语扩展 JAX#
由于 JAX 是围绕 HLO 设计的,JAX 原语集紧密地映射了 HLO 操作集。针对新的编译器(例如 Triton 或 Mosaic)意味着我们可能需要用针对新编译器的新原语来补充 JAX 的原语。同时,我们也可能无法降低所有 JAX 原语,因此需要将其限制在一个子集中。
由于 Pallas 最初是为 Triton 设计的,因此我们提供了一套针对 Triton 编程模型的新原语。正如我们稍后将展示的,我们也可以将这些原语降低到 Mosaic。
pallas.load 和 pallas.store#
pallas.load 和 pallas.store 是允许从内存加载和向内存存储的原语。与 __getitem__ 和 __setitem__ 不同,它们更灵活,但代价是更冗长。具体来说,你可以使用 pallas.dynamic_slice(简称 pallas.ds)构造(这也许应该被上游到 JAX 中,以便与 Ref 的 __getitem__ 和 __setitem__ 一起使用)。
def f(x_ref, o_ref):
# Reading from memory via pallas.load
x = pl.load(x_ref, (0, slice(2, 5), slice(None)))
# Using integer indexing automatically broadcasts
x = pl.load(x_ref, (0, 2 + jnp.arange(3), slice(None)))
# You can also use `pl.dynamic_slice` (`pl.ds` for short) objects as well
pl.store(o_ref, (0, pl.ds(start=2, size=3), slice(None)), x)
pallas.load 和 pallas.store 也通过 mask 参数支持掩码。
def f(x_ref, o_ref):
# Reading from memory via pallas.load
idx = jnp.arange(8)
mask = idx < 5
x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf'))
当进行越界加载/存储时,掩码很重要。掩码的操作语义可以由编译器确定(如果我们正确理解文档,Triton 会避免读取/写入被掩码的内存)。
pallas.program_id 和 pallas.num_programs#
正如我们很快就会看到的,我们将多次执行相同的 Pallas 内核(根据后端是并行执行还是流水线执行)。这些新原语告诉我们内核执行的“位置”。
pallas.program_id 接受一个轴参数,该参数告诉我们内核当前正在多维网格的哪个轴上执行(类似于 CUDA 编程中的 threadId 或 jax.pmap 中的 lax.axis_index)。请注意,我们目前借用了 Triton 的“program”术语,将来我们可能会想将其更改为 JAX 用户更熟悉的术语。
def f(x_ref, o_ref):
i = pl.program_id(axis=0) # execution index in the first axis of the grid
o_ref[i] = jnp.exp(x_ref[i])
pallas.num_programs 也接受一个轴并返回该轴的网格大小。
请注意,虽然 program_id 和 num_programs 是 Triton 特定的术语,但它们很容易泛化,在 TPU 上也同样有意义。
在 Pallas 中使用 JAX 原语的子集#
因为我们编写的是内核而不是高级 HLO 程序,所以某些 JAX 原语可能无法高效地表示在底层基元中。然而,我们知道我们可以支持大多数元素级操作、简单的点积和 JAX 控制流。
虽然我们还没有完全确定 Pallas 内核支持的所有 JAX 原语,但我们当然可以确定一些不易降低或不太可能有用的原语。
conv_general- 卷积通常不作为底层硬件中的原语提供。gather/scatter- 底层编译器可能不支持非连续内存的读写。
使用 pallas_call 执行 Pallas 内核#
既然我们已经编写了 Pallas 内核(又名带有 Refs 和 Pallas 额外原语的 JAX),我们如何在 GPU 或 TPU 上执行它们?我们使用 pallas_call,这是一个高阶函数(类似于 jax.jit 和 jax.pmap),用于执行内核。
pallas_call 的签名如下:
def pallas_call(
kernel: Callable,
out_shape: Sequence[jax.ShapeDtypeStruct],
*,
in_specs: Sequence[Spec],
out_specs: Sequence[Spec],
grid: Optional[Tuple[int, ...]] = None) -> Callable:
...
当我们将内核提供给 pallas_call 时,我们会提供额外的信息。第一个是 out_shape,它告诉内核输出是什么样子(pallas_call 将一个对应的 Ref 传递给内核以供写入)。其余信息(in_specs、out_specs 和 grid)是关于内核如何在加速器上调度的信息。
pallas_call 的(大致)语义如下:
def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
def execute(*args):
outputs = map(empty_ref, out_shape)
grid_indices = map(range, grid)
for indices in itertools.product(*grid_indices): # Could run in parallel!
local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
zip(args, in_specs)]
local_outputs = [out_spec.transform(arg, indices) for arg, out_spec in
zip(outputs, out_specs)]
kernel(*local_inputs, *local_outputs) # writes to outputs
return execute
具体来说,pallas_call 将“循环”遍历网格迭代空间,应用一个通过 in_specs 和 out_specs 指定的输入输出转换。在每次迭代中,内核将根据转换后的输入和输出来调用。请注意,遍历迭代空间的“循环”可以并行执行(例如,在 GPU 上)。pallas_call 也没有对循环迭代顺序的任何保证,只是保证迭代空间中的每个成员都会被遍历。像 Triton 和 Mosaic 这样的编译器将具有与网格相关的更具体的操作语义。
转换函数#
pallas_call 的 in_specs 和 out_specs 参数允许以某种方式转换输入和输出。Pallas 目前提供的两个选项是身份转换(输入和输出保持不变)和 BlockSpecs,它们根据循环索引从 Refs 中获取固定大小的块。
BlockSpec 接受一个 index_map 函数和一个 block_shape。逻辑上,它将一个数组沿着每个轴切分成 block_shape 大小的块。 index_map 函数接受循环索引(来自网格索引集)并将其映射到块索引。转换函数将 Refs 转换为对应块的 Ref 的逻辑视图。当我们指定 block_shape 中的条目为 None 时,这对应于在该维度上“映射”,将其从内核内的块中移除。
class BlockSpec:
index_map: Callable[[Tuple[Int, ...]], Tuple[Int, ...]]
block_shape: Tuple[Optional[int], ...]
def transform(self, ref, *loop_indices):
block_indices = self.transform_function(loop_indices)
# Returns a view of `ref` starting at `block_indices` of shape self.block_shape
...
我们也可以设想其他与 pallas_call 一起使用的 Specs,例如一个 Spec,它对应于重叠窗口,例如,用于实现卷积。
Pallas 作为前端的直接优势#
通过提供一个用于内核编写的 JAX 前端,我们可以立即获得一些好处。
更灵活的前端#
第一个是 JAX 用户已经习惯了使用 JAX 及其基于追踪的转换进行编程的好处(和限制)。这意味着用户可以在编写 Pallas 内核时使用闭包和其他熟悉的 Python 构造。这与现有的基于 AST 解析的 Triton 前端或 Mosaic 的 MLIR 构建器不同。例如,这使得 Pallas 比 Triton 更容易进行模板化。
请参阅此示例,了解如何使用 Python 中的高阶函数来模板化内核。
def make_kernel(eltwise_kernel):
def add(x_ref, y_ref, o_ref):
x = pl.load(x_ref, ())
y = pl.load(y_ref, ())
pl.store(o_ref, (), eltwise_kernel(x + y))
return add
kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)
pl.pallas_call(kernel1, out_shape=x, grid=1)(1., 1.)
pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.)
模拟模式#
通过将内核表示为带有 JAX 原语和一些新 Pallas 原语的程序,我们还可以将 Pallas 程序降低到 StableHLO,并使用 XLA 进行编译/执行。具体来说,pallas_call 可以实现为网格上的 lax.scan。这使得我们可以在任何 XLA 支持的平台上(甚至 CPU!)开发 GPU 或 TPU 内核,并使用 JAX/XLA 调试工具(如 jax.debug.print)进行调试。我们还可以使用更可靠、经过更好测试的 XLA 数值来验证 Triton 和 Mosaic 编译器的正确性。人们还可以设想扰乱 scan 顺序来模拟 GPU 上发生的并行读写。
GPU 示例#
请注意,以下所有示例仅适用于 GPU。它们需要调整块大小才能在 TPU 上运行。
add#
我们修改了 add_kernel 示例,使其使用 BlockSpecs 操作(2,)大小的块。
def add_kernel(x_ref, y_ref, o_ref):
# In this code, `x_ref`, `y_ref` and `o_ref` are (2,)-shaped `Ref`s
x = x_ref[:]
y = y_ref[:]
o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(
add_kernel,
out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
in_specs=[
pl.BlockSpec((2,), lambda i: i),
pl.BlockSpec((2,), lambda i: i)
],
out_specs=pl.BlockSpec((2,), lambda i: i),
grid=(4,))
add(x, y)
模板化的矩阵乘法#
在此示例中,我们通过对输入数组的行和列块进行展开累加来计算输出块。我们将一个激活函数内联到内核主体中,使用高阶函数,以便发出一个融合内核。
def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
for k in range(x_ref.shape[1] // block_k):
x = x_ref[:, k*block_k:(k+1)*block_k]
y = y_ref[k*block_k:(k+1)*block_k, :]
acc += x @ y
o_ref[:, :] = activation(acc).astype(o_ref.dtype)
x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
block_shape = 128, 256, 128
@partial(jax.jit, static_argnames=["block_shape", "activation"])
def matmul(x, y, *, block_shape, activation):
block_m, block_n, block_k = block_shape
fused_matmul = pl.pallas_call(
partial(matmul_kernel, block_k=block_k, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
in_specs=[
pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
grid=(4, 4),
)
return fused_matmul(x, y)
z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)
Pallas 的降低(Lowering)#
在用户表达了他们的 Pallas 内核之后,我们将根据目标后端将其降低到不同的表示。在 GPU 上,我们将 Pallas 降低到 Triton IR,而在 TPU 上,我们将 Pallas 降低到 Mosaic。
将 Pallas 降低到 Triton 以用于 GPU#
将 Pallas 降低到 Triton 是容易的,因为 Pallas 最初是以 Triton 为目标语言设计的。Pallas 和 Triton 的主要区别在于,Triton 没有 BlockSpecs 的概念,并且在内存读写时使用指针而不是索引。
Triton 在其语言中支持指针作为数组元素类型,在 Triton 中你可以从指针数组加载并存储到指针数组。在 Pallas 中,当给定一个形状为 (4, 5) 的 Ref x_ref,然后进行 x_ref[3, 2] 操作时,我们需要将其降低为计算指向 x_ref 中相应行主序位置的 Triton 指针(即,进行 5 * 3 + 2 * 1)。类似地,当我们把切片降低到 Triton 时,例如 x_ref[4, :],我们需要生成一个指针数组 5 * 4 + jnp.arange(3)。
除此之外,降低到 Triton 相当直接。JAX 点积可以降低为 Triton 点积,JAX 一元原语被降低为它们的 Triton 等价物。Triton 的原子操作通过新的 Pallas 原子原语进行降低。
将 Pallas 降低到 Mosaic 以用于 TPU#
Mosaic (主要)消费标准的 MLIR 方言并发出 LLO 以供 TPU 编译。Pallas 可以通过将 JAX 原语翻译成 MLIR(主要是 vector 和 arith 方言)来降低到 Mosaic。BlockSpecs 可以转换为流水线调度(即 Mosaic 中的 transform_funcs)。
转换 Pallas#
一个自然的问题是 JAX 转换如何与 Pallas 内核交互?有两种主要方式:Pallas 内核内的转换和 Pallas 内核外的转换。
Pallas 内核内的转换实际上应该“就能工作”,只要我们能够降低转换后的代码。例如,我们可以在 JAX 内核中使用 jax.grad(jnp.sin)(...),因为我们可以将 cos 降低到 Triton 和 Mosaic。但是,我们可能无法降低 jax.vmap(lax.dynamic_slice),因为它可能会变成一个我们无法降低的 gather 操作。
从外部 JAX 程序转换 Pallas 内核可能是更有趣的情况。我们如何处理 vmap(pallas_call) 和 grad(pallas_call) 这样的事情?
vmap 后的 pallas_call#
vmap 自动向量化 JAX 程序。虽然内核编写者可能希望精确控制批量内核与其非批量变体之间的行为差异,但我们可以为 pallas_call 提供一个合理的默认 vmap 规则,同时提供 jax.custom_vmap 定制机制。当 pallas_call 被 vmap 时,我们为 pallas_call 增加一个额外的网格维度来对应新的批次维度,并转换 BlockSpecs 以处理沿着该维度的索引。
grad 后的 pallas_call#
pallas_call 的 grad 支持内核的自动微分。jax.grad 分解为三个独立转换的应用:jvp、partial_eval 和 transpose。原则上,我们在为 pallas_call 实现这些规则时,可以重用 JAX 的大部分基础设施(因为它与现有的 JAX 高阶原语非常相似)。
然而,内核的自动微分可能会由于内存访问的转置方式而导致性能下降。如果我们用重叠并行读取和不重叠并行写入的 GPU 内核,我们会自动将其转置为一个具有重叠并行写入(原子操作时很慢)和不重叠并行读取的内核。为了发出一个能更好地利用共享内存并行性的内核,我们需要重新排序循环并改变内核的向量化方式。不幸的是,Pallas 中没有一个合适的程序表示来进行这样的操作。自动高效微分内核的一个潜在方向是探索不同的表示,也许是像 Dex 中的那种。我们也可以看看 Enzyme 如何解决这个问题。然而,Pallas 内核的 AD 对于一类能够高效转置的内核(例如元素级内核)仍然可能有用。
但总的来说,jax.custom_vjp 是一个可行的逃生舱口,用于表达与 jax.grad 协同工作的 Pallas 内核。
其他转换#
我们可以设想其他 JAX 转换应用于我们尚未明确探索过的 Pallas 内核。例如,checkify 是一个进行函数式错误处理的 JAX 转换。我们可以设想使用 checkify 和 pallas_call 来允许从 GPU 内核中导出指示 OOB 访问或 NaN 产生的错误代码。
另一个潜在的集成转换是 custom_partitioning,以支持自动分区内核与 pjit 一起使用。