Pallas 设计#
在本文档中,我们解释了 Pallas 的初始设计。这是早期设计决策的快照,Pallas 的具体 API 可能自那时以来已更改。
介绍#
JAX 被用于各种工作负载,从大规模机器学习到科学计算。JAX 的成功故事在很大程度上也是 XLA 的成功故事,XLA 是 JAX 主要的目标编译器——XLA 为加速器编译 JAX 程序,并使 JAX 能够扩展到最大的 ML 模型。JAX 在 XLA 的表示 HLO 中描述逻辑计算。HLO 描述了计算在逻辑上如何发生,而不是物理上如何发生。给定一个逻辑 HLO 计算,XLA 决定如何物理执行该计算。对于各种 ML 应用程序,XLA 在编译用户程序方面做得很好,但不可避免地,一些用户会遇到 XLA 的限制。在这些情况下,我们需要提供一个“逃生舱口”,允许专家编写手调内核,以便在特定时间点超越 XLA 的性能。此外,ML 系统研究的进展需要一些时间才能纳入 XLA,而用户通常希望提前使用它们。随着时间的推移,编译器可以将通过手动调整内核实验证明的优化纳入其中。
XLA 确实提供了 CustomCall
机制作为逃生舱口,但这需要用户编写 C++ 代码,并且在 GPU 上需要用户学习 CUDA 编程模型。CUDA 编程模型对于许多机器学习 GPU 内核(如矩阵乘法)来说,可以说是太底层了,即使是专家用户也会难以使用 CUDA 来实现高效的矩阵乘法或多头注意力。不仅如此,JAX 用户通常熟悉 Python 和 NumPy 风格的数组编程,这不涉及编写任何 C++ 代码或考虑 GPU 并行性。所有流行的机器学习框架都秉持这个理念:使用高级操作(如 matmul
或 convolution
)来操作(通常是)数组。不幸的是,这意味着通过 CustomCall
实现自定义操作是一项巨大的投入,可能需要学习 C++ 和/或 GPU 编程。
Triton 是一个由 OpenAI 构建和维护的 GPU 编译器,已经在 ML 编译器领域引起了轰动。Triton 提供了两全其美的方案:用于 GPU 内核的基于数组的编程模型。Triton 是 PyTorch 2.0 中 torch.compile
的主要代码生成途径,通过 Torch Inductor 库。Triton 积极地隐藏了 GPU 编程的某些方面,目的是为了提供一种更易于访问的编程模型,该模型可以从 Python 中使用,并从更高级别的表示生成优化的代码。虽然 GPU 比 Triton 提供的更灵活,但在 ML 领域,Triton 似乎对于许多应用来说已经足够表达了。
在本文档中,我们描述了 Pallas,它是 JAX 的一个扩展,可以使用类似 Triton 的模型为 GPU 和 TPU 启用内核编程。基于 JAX 的内核语言具有以下几个优势
虽然 Triton 向用户公开了类似 TPU 的编程模型,即为 L1 缓存中数组的切片编写程序,但它对 GPU 的专业化程度很高,以至于我们无法直接为 TPU 编译 Triton。例如,Triton 提供了专门用于处理并行写入的原子操作,这些操作在 TPU 上不一定有意义。更高级别的前端可以抽象出平台的细节,同时仅呈现基于切片的编程模型。因此,内核将可以在不同的硬件平台上移植。
JAX 作为数值计算的基于跟踪的前端,既成熟又被广泛使用。通过将内核编程语言嵌入到 JAX 本身中,我们可以重用 JAX 的跟踪基础设施,并提供一个用户已经熟悉的类似 NumPy 的前端。
JAX 转换是其成功的关键,它允许用户表达简单的程序,但可以转换它们以实现复杂的功能。我们可以利用相同的转换(vmap、jvp 等)来转换用户编写的内核。
开放性问题是:JAX 真的适合作为内核语言吗?我们认为是。Triton 证明,数组编程语言对于编写 GPU 内核是切实可行的,而 JAX 正是如此。JAX 也已被证明是编译器和程序转换的灵活前端。
我们如下描述 Pallas:我们首先描述扩展 JAX 以支持编写自定义内核的方式。然后,我们展示如何将 Pallas 降低到 Triton 和 Mosaic。最后,我们描述通过 JAX 转换转换 Pallas 内核的现有和潜在方法。
Pallas 降低路径的可视化
Pallas:扩展 JAX 以用于内核#
我们想要强调的关键点是 Pallas 只是 JAX,带有一些扩展
用户现在在其 JAX 代码中使用名为
Ref
s 的引用类型。这让用户可以更精确地控制内存访问,并且 JAX 中的布局将更接近物理布局。用户使用 JAX 原语的子集以及一组 Pallas 特定的原语来编写他们的 JAX 程序。
用户通过一个特殊的
pallas_call
高阶函数将他们的 Pallas 内核嵌入到外部 JAX 程序中,该函数在映射中执行内核。它类似于pmap
或shard_map
,但使用了对共享内存的引用。
我们将通过示例逐一介绍这三个扩展。
请注意,这些 API 仍处于实验阶段,可能会发生变化。
引用类型#
让我们看一个用于添加两个向量的 Pallas 程序示例
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
def add_kernel(x_ref, y_ref, o_ref):
# In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
x = x_ref[:]
y = y_ref[:]
o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)
与常规 JAX 程序不同,add_kernel
不接收不可变的数组参数。相反,它提供了可以使用类似 NumPy 语法进行读取和就地更新的引用。Ref
s 不是 Pallas 特有的概念——它们被引入 JAX 以表示有状态的计算。但是,在编写操作可变内存的内核时,我们也可以利用它们。
Pallas 内核不仅接收与内核输入相对应的 Ref
s,还接收输出的 Ref
s(在 pallas_call
中通过 out_shape
指定)。Ref
s 是特殊类型,在先读取之前,不能传递到通常的 JAX 原语集中。当您从 Ref
中读取时,您会得到一个 JAX Array
类型,并且您必须将一个 Array
写入 Ref
。
从 Refs 读取/写入 Refs#
从 Ref
读取对应于将数组加载到内存层次结构的最底层(GPU 上的 L1 缓存和 TPU 上的向量寄存器)。写入 Ref
是类似的。
def f(x_ref, o_ref):
# Using vanilla Python indexing
x = x_ref[0, 2:5, :]
# Or via Numpy advanced int indexing
o_ref[jnp.arange(3), :] = x
# Note that in order to use NumPy advanced int indexing, you need to broadcast the indices against each other into the desired multidimensional shape:
def f(x_ref):
# Assume x_ref is (8, 4) and we want to read out a (2, 3) slice
x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]]
写入 Ref
s 可以通过类似的 __setitem__
风格的索引来完成。
其他形式的索引(例如,动态切片)可以通过 pallas.load
和 pallas.store
完成,它们是新的 JAX 原语,旨在使从内存加载/存储到内存更容易。我们稍后将讨论这些新的原语。
使用新的 Pallas 原语扩展 JAX#
因为 JAX 的设计考虑了 HLO,所以 JAX 原语集与 HLO 操作集密切相关。针对一个新的编译器(例如 Triton 或 Mosaic)意味着我们可能需要用特定于新编译器的原语来补充 JAX 的原语。同时,我们可能无法降低所有 JAX 原语,因此我们需要将其限制为子集。
因为 Pallas 最初的设计考虑了 Triton,所以我们提供了一组针对 Triton 编程模型的新原语。正如我们稍后将展示的,我们也可以将这些原语降低到 Mosaic。
pallas.load
和 pallas.store
#
pallas.load
和 pallas.store
是允许从内存加载和存储到内存的原语。与 __getitem__
和 __setitem__
不同,它们更灵活,但代价是更冗长。具体来说,您可以使用 pallas.dynamic_slice
(简称 pallas.ds
)构造(这也许应该上游到 JAX 中,以便与 Ref __getitem__
和 __setitem__
一起使用)。
def f(x_ref, o_ref):
# Reading from memory via pallas.load
x = pl.load(x_ref, (0, slice(2, 5), slice(None)))
# Using integer indexing automatically broadcasts
x = pl.load(x_ref, (0, 2 + jnp.arange(3), slice(None)))
# You can also use `pl.dynamic_slice` (`pl.ds` for short) objects as well
pl.store(o_ref, (0, pl.ds(start=2, size=3), slice(None)), x)
pallas.load
和 pallas.store
也通过 mask 参数支持掩码。
def f(x_ref, o_ref):
# Reading from memory via pallas.load
idx = jnp.arange(8)
mask = idx < 5
x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf'))
在进行越界加载/存储时,掩码非常重要。掩码的操作语义可以由编译器决定(如果我们正确理解文档,Triton 会避免在被掩码时从/向内存读取/写入)。
pallas.program_id
和 pallas.num_programs
#
正如我们很快将看到的,我们将多次执行相同的 Pallas 内核(并行或在流水线中,具体取决于后端)。这些新的原语告诉我们内核执行的“位置”。
pallas.program_id
接受一个 axis 参数,该参数告诉我们内核当前正在多维网格的轴的哪个索引中执行(类似于 CUDA 编程中的 threadId 或 jax.pmap 中的 lax.axis_index)。请注意,我们目前正在借用 Triton 的“program”术语,将来我们可能希望将其更改为 JAX 用户更熟悉的术语。
def f(x_ref, o_ref):
i = pl.program_id(axis=0) # execution index in the first axis of the grid
o_ref[i] = jnp.exp(x_ref[i])
pallas.num_programs
也接受一个 axis,并返回该轴的网格大小。
请注意,虽然 program_id 和 num_programs 是 Triton 特有的术语,但它们很容易推广到在 TPU 上也有意义。
在 Pallas 中使用 JAX 原语的子集#
因为我们正在编写内核,而不是高级 HLO 程序,所以某些 JAX 原语可能无法在我们底层基板中有效地表示。但是,我们知道我们可以支持大多数元素级操作、简单的点积和 JAX 控制流。
虽然我们尚未完全映射出我们可以在 Pallas 内核中支持的所有 JAX 原语,但我们当然可以识别出一些不容易降低或不太有用的原语
conv_general
- 卷积通常不作为底层硬件中的原语提供。gather/scatter
- 底层编译器可能不支持非连续内存的读取和写入
使用 pallas_call
执行 Pallas 内核#
现在我们已经编写了我们的 Pallas 内核(又名带有 Refs 和额外 Pallas 原语的 JAX),我们如何在 GPU 或 TPU 上执行它们?我们使用 pallas_call
,这是一个高阶函数(类似于 jax.jit
和 jax.pmap
),用于执行内核。
pallas_call
的签名如下
def pallas_call(
kernel: Callable,
out_shape: Sequence[jax.ShapeDtypeStruct],
*,
in_specs: Sequence[Spec],
out_specs: Sequence[Spec],
grid: Optional[Tuple[int, ...]] = None) -> Callable:
...
当我们向 pallas_call
提供内核时,我们会提供额外的信息。第一个是 out_shape
,它告诉内核输出是什么样的(pallas_call
将传递与这些输出相对应的 Ref
到内核中进行写入)。其余信息(in_specs
、out_specs
和 grid
)是关于内核如何在加速器上调度的信息。
pallas_call
的(粗略)语义如下
def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
def execute(*args):
outputs = map(empty_ref, out_shape)
grid_indices = map(range, grid)
for indices in itertools.product(*grid_indices): # Could run in parallel!
local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
zip(args, in_specs)]
local_outputs = [out_spec.transform(arg, indices) for arg, out_spec in
zip(outputs, out_specs)]
kernel(*local_inputs, *local_outputs) # writes to outputs
return execute
具体来说,pallas_call
将在网格迭代空间上“循环”,对通过 in_specs
和 out_specs
指定的输入和输出应用转换。在每次迭代中,将在转换后的输入和输出上调用内核。请注意,迭代空间上的“循环”可以并行执行(例如在 GPU 上)。pallas_call
也不保证迭代空间上的循环迭代顺序,仅保证将循环访问迭代空间的每个成员。像 Triton 和 Mosaic 这样的编译器将具有与网格相关的更具体的操作语义。
转换函数#
pallas_call
的 in_specs
和 out_specs
参数允许以某种方式转换输入和输出。Pallas 现在提供的两个选项是恒等变换(输入和输出保持不变)和 BlockSpec
,它采用由循环索引确定的 Ref
的固定大小切片。
BlockSpec
接受一个 index_map
函数和一个 block_shape
。从逻辑上讲,它接受一个数组,并沿着每个轴将其切片成 block_shape
大小的块。index_map
函数接受循环索引(来自网格索引集)并将它们映射到块索引。转换函数将 Ref
转换为对应块处 Ref
的逻辑视图。当我们在 block_shape 中的条目中指定 None
时,这对应于在该维度上“映射”,从而将其从内核中的块中移除。
class BlockSpec:
index_map: Callable[[Tuple[Int, ...]], Tuple[Int, ...]]
block_shape: Tuple[Optional[int], ...]
def transform(self, ref, *loop_indices):
block_indices = self.transform_function(loop_indices)
# Returns a view of `ref` starting at `block_indices` of shape self.block_shape
...
我们还可以设想与 pallas_call
一起使用的其他 Spec
,例如,对应于重叠窗口的 Spec
,例如,用于实现卷积。
Pallas 作为前端的直接优势#
通过为内核编写提供 JAX 前端,我们可以立即获得一些好处。
更灵活的前端#
首先,JAX 用户已经习惯于使用 JAX 及其基于跟踪的转换进行编程的优势(和局限性)。这意味着用户在编写 Pallas 内核时可以使用闭包和其他熟悉的 Python 构造。这与现有的基于 AST 解析的 Triton 前端或 Mosaic 的 MLIR 构建器不同。例如,这使得 Pallas 比 Triton 更适合模板化。
请参阅此示例,了解我们如何在 Python 中使用高阶函数来模板化内核。
def make_kernel(eltwise_kernel):
def add(x_ref, y_ref, o_ref):
x = pl.load(x_ref, ())
y = pl.load(y_ref, ())
pl.store(o_ref, (), eltwise_kernel(x + y))
return add
kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)
pl.pallas_call(kernel1, out_shape=x, grid=1)(1., 1.)
pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.)
仿真模式#
通过将内核表示为带有 JAX 原语和一些新的 Pallas 原语的程序,我们还可以将 Pallas 程序直接降级到 StableHLO,并使用 XLA 编译/执行它们。具体来说,pallas_call
可以实现为在网格上进行的 lax.scan
。这使我们能够在任何 XLA 支持的平台(甚至 CPU!)上开发 GPU 或 TPU 内核,并使用 JAX/XLA 调试工具(如 jax.debug.print
)调试它们。我们还可以使用更可靠且经过更好测试的 XLA 数值来验证 Triton 和 Mosaic 编译器的正确性。人们还可以想象扰乱 scan
排序以模拟 GPU 上发生的并行读取和写入。
GPU 示例#
请注意,以下所有示例仅适用于 GPU。它们将需要调整块大小才能在 TPU 上工作。
add
#
我们修改了 add_kernel
示例,以使用 BlockSpec
在 (2,) 大小的块上运行。
def add_kernel(x_ref, y_ref, o_ref):
# In this code, `x_ref`, `y_ref` and `o_ref` are (2,)-shaped `Ref`s
x = x_ref[:]
y = y_ref[:]
o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(
add_kernel,
out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
in_specs=[
pl.BlockSpec((2,), lambda i: i),
pl.BlockSpec((2,), lambda i: i)
],
out_specs=pl.BlockSpec((2,), lambda i: i),
grid=(4,))
add(x, y)
模板化 matmul#
在此示例中,我们通过对来自输入数组的行和列块进行展开累加来计算输出的切片。我们使用高阶函数将激活函数内联到内核主体中,以便我们可以发出融合内核。
def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
for k in range(x_ref.shape[1] // block_k):
x = x_ref[:, k*block_k:(k+1)*block_k]
y = y_ref[k*block_k:(k+1)*block_k, :]
acc += x @ y
o_ref[:, :] = activation(acc).astype(o_ref.dtype)
x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
block_shape = 128, 256, 128
@partial(jax.jit, static_argnames=["block_shape", "activation"])
def matmul(x, y, *, block_shape, activation):
block_m, block_n, block_k = block_shape
fused_matmul = pl.pallas_call(
partial(matmul_kernel, block_k=block_k, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
in_specs=[
pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
grid=(4, 4),
)
return fused_matmul(x, y)
z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)
降低 Pallas#
在用户表达他们的 Pallas 内核之后,我们根据目标后端将它们降低为不同的表示形式。在 GPU 上,我们将 Pallas 降低为 Triton IR,在 TPU 上,我们将 Pallas 降低为 Mosaic。
将 Pallas 降低到 Triton 以用于 GPU#
将 Pallas 降低到 Triton 很简单,因为 Pallas 在设计时就考虑了 Triton 作为目标语言。Pallas 和 Triton 之间的主要区别在于 Triton 没有 BlockSpec
的概念,并且在进行内存加载和存储时使用指针而不是索引。
Triton 支持指针作为其语言中的数组元素类型,并且在 Triton 中,您可以从指针数组加载和存储到指针数组。在 Pallas 中,当给定一个 (4, 5)
形状的 Ref
,x_ref
,然后像 x_ref[3, 2]
这样操作时,我们需要将其降低为计算 Triton 指针到 x_ref
中的相应行优先位置(即,执行 5 * 3 + 2 * 1)。类似地,当我们降低切片到 Triton 时,例如 x_ref[4, :]
,我们需要生成指针数组 5 * 4 + jnp.arange(3)
。
除此之外,降低到 Triton 非常简单。JAX 点积可以降低为 Triton 点积,JAX 一元原语降低为它们的 Triton 等效项。Triton 的原子操作通过新的 Pallas 原子原语降低。
将 Pallas 降低到 Mosaic 以用于 TPU#
Mosaic 消耗(主要)标准方言 MLIR 并发出 LLO 以编译为 TPU。Pallas 可以通过将 JAX 原语转换为 MLIR(主要是 vector
和 arith
方言)来降低到 Mosaic。BlockSpec
可以转换为流水线调度(即 Mosaic 中的 transform_func
)。
转换 Pallas#
一个自然的问题是 JAX 转换如何与 Pallas 内核交互?主要有两种方式:Pallas 内核内部的转换和 Pallas 内核外部的转换。
Pallas 内核内部的转换实际上应该“正常工作”,只要我们能够降低转换后的代码即可。例如,我们可以在 JAX 内核内部使用 jax.grad(jnp.sin)(...)
,因为我们可以将 cos
降低到 Triton 和 Mosaic。但是,我们可能无法降低 jax.vmap(lax.dynamic_slice)
,因为它可能会变成我们无法降低的 gather。
来自外部 JAX 程序的 Pallas 内核的转换可能更有趣。我们如何处理诸如 vmap(pallas_call)
和 grad(pallas_call)
之类的事情?
vmap-of-pallas_call
#
vmap 自动向量化 JAX 程序。虽然内核编写者可能希望精确控制批处理内核的行为与其非批处理变体的行为有何不同,但我们可以为 pallas_call
提供合理的默认 vmap
规则,同时提供 jax.custom_vmap
自定义机制。当 pallas_call
被 vmap
化时,我们增强 pallas_call
以具有与新批处理维度相对应的额外网格维度,并转换 BlockSpec
以处理沿该维度的索引。
grad-of-pallas_call
#
pallas_call
的 grad
实现了内核的自动微分。jax.grad
分解为三个不同转换的应用:jvp
、partial_eval
和 transpose
。原则上,在为 pallas_call
实现这些规则时,我们可以重用 JAX 的大部分基础设施(因为它与现有的 JAX 高阶原语非常相似)。
然而,内核的自动微分可能会由于内存访问的转置方式而导致性能下降。如果我们编写一个具有重叠且并行读取和不相交但并行写入的 GPU 内核,我们会自动将其转置为具有重叠但并行写入(原子完成时速度很慢)和不相交且并行读取的内核。为了发出一个更好地利用共享内存并行性的内核,我们将需要重新排序循环并更改内核的向量化方式。不幸的是,我们在 Pallas 中没有适合此目的的程序表示。有效自动微分内核的一个潜在方向是探索不同的表示形式,可能类似于 Dex 中的表示形式。我们还可以研究 Enzyme 如何解决这个问题。但是,Pallas 内核的 AD 对于一类有效执行转置的内核(例如,逐元素内核)可能仍然有用。
但总的来说,jax.custom_vjp
是表达与 jax.grad
一起使用的 Pallas 内核的可行方法。
其他转换#
我们可以想象其他 JAX 转换应用于我们尚未明确探索的 Pallas 内核。例如,checkify
是一种执行功能性错误处理的 JAX 转换。我们可以想象将 checkify
与 pallas_call 一起使用,以允许从 GPU 内核中输出错误代码,这些错误代码指示是否产生了 OOB 访问或 NaN。
另一个可能集成的转换是 custom_partitioning,以使自动可分区内核能够与 pjit 一起使用。