Pallas 设计#

本文档阐述了最初的 Pallas 设计。这是早期某些设计决策的快照,Pallas 的特定 API 此后可能已发生变化。

简介#

JAX 被广泛用于各类工作负载,从大规模机器学习到科学计算。JAX 的成功同样也是 XLA(JAX 所针对的主要编译器)的成功——XLA 为加速器编译 JAX 程序,并使 JAX 能够扩展到最大的机器学习模型。JAX 以 XLA 的表示形式 HLO 来描述逻辑计算。HLO 描述了计算在逻辑上如何发生,而非物理上如何发生。给定一个逻辑 HLO 计算,XLA 会决定该计算应如何在物理上执行。对于各种机器学习应用,XLA 在编译用户程序方面做得很好,但用户不可避免地会遇到 XLA 的局限性。在这种情况下,我们需要提供一个“逃生舱”(escape hatch),允许专家编写手工调优的内核,使其在特定场景下优于 XLA。此外,机器学习系统的研究进展往往需要一段时间才能被并入 XLA,而用户通常希望提前使用它们。随着时间的推移,编译器可以吸收那些通过手工内核实验证明有效的优化。

XLA 确实提供了 CustomCall 机制作为逃生舱,但这需要用户编写 C++,而在 GPU 上还要求用户学习 CUDA 编程模型。CUDA 编程模型对于许多机器学习 GPU 内核(如矩阵乘法)来说可能过于底层,即使是专家用户也很难使用 CUDA 实现高效的矩阵乘法或多头注意力机制。不仅如此,JAX 用户通常熟悉 Python 和 NumPy 风格的数组编程,这并不涉及编写 C++ 或考虑 GPU 并行性。所有流行的机器学习框架都共享这一理念:通过像 matmulconvolution 这样的高级操作来操纵(通常是)数组。不幸的是,这意味着通过 CustomCall 实现自定义操作是一项巨大的投入,可能涉及学习 C++ 和/或 GPU 编程。

Triton 是由 OpenAI 构建和维护的 GPU 编译器,它席卷了机器学习编译器界。Triton 兼具两全之美:一种用于 GPU 内核的基于数组的编程模型。Triton 是 PyTorch 2.0 中通过 Torch Inductor 库实现 torch.compile 的主要代码生成路径。Triton 主动隐藏了 GPU 编程的某些方面,以提供一种更易用的编程模型,既可以从 Python 调用,又能从更高级别的表示生成优化代码。虽然 GPU 的灵活性远超 Triton 所能提供的范围,但在机器学习领域,Triton 对许多应用而言已足够具有表达力。

在本文档中,我们将介绍 Pallas。它是 JAX 的一个扩展,使用户能够使用类似于 Triton 的模型为 GPU 和 TPU 编写内核编程。基于 JAX 的内核语言具有以下几点优势:

  • 尽管 Triton 向用户公开了一种类似 TPU 的编程模型(即为 L1 缓存中的数组分块编写程序),但它专门针对 GPU,因此我们无法直接将 Triton 编译为 TPU。例如,Triton 提供的原子操作是专门为了处理并行写入而设计的,这在 TPU 上并不一定适用。更高级的前端可以抽象化平台的细节,同时仅呈现这种基于分块(tile-based)的编程模型。这样,内核就可以在不同的硬件平台之间移植。

  • 作为一种基于追踪(tracing)的数值计算前端,JAX 既成熟又被广泛使用。通过将内核编程语言嵌入 JAX 自身,我们可以重用 JAX 的追踪基础设施,并提供用户已经熟悉的类似 NumPy 的前端。

  • JAX 转换是其成功的关键,允许用户编写简单的程序,通过转换来实现复杂的功能。我们可以利用相同的转换(vmap、jvp 等)来转换用户编写的内核。

悬而未决的问题是:JAX 是否真的适合作为内核语言?我们认为是。Triton 证明了数组编程语言对于编写 GPU 内核是实用的,而 JAX 正是这样一种语言。同时,JAX 已被证明是编译器和程序转换的灵活前端。

我们将按照以下方式介绍 Pallas:首先介绍我们扩展 JAX 以支持编写自定义内核的方法;然后展示如何将 Pallas Lowering 至 Triton 和 Mosaic;最后总结通过 JAX 转换来转换 Pallas 内核的现有和潜在方式。

Pallas lowering path Pallas Lowering 路径可视化

Pallas:用于内核的 JAX 扩展#

我们想要强调的关键点是:Pallas 就是 JAX,只不过带有一些扩展。

  1. 用户现在可以在 JAX 代码中使用名为 Ref 的引用类型。这让用户能够更精确地控制内存访问,且 JAX 中的布局将更接近物理布局。

  2. 用户使用 JAX 原语的子集以及一组 Pallas 特定的原语来编写 JAX 程序。

  3. 用户通过一个特殊的 pallas_call 高阶函数将 Pallas 内核嵌入到外部 JAX 程序中,该函数在映射(map)中执行内核。它类似于 pmapshard_map,只是带有指向共享内存的引用。

我们将通过示例逐一介绍这三个扩展。

请注意,这些 API 仍处于实验阶段,可能会有变动。

引用类型#

让我们来看一个用于向量相加的 Pallas 程序示例。

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)

与常规 JAX 程序不同,add_kernel 不接收不可变数组参数。相反,它接收可以使用 NumPy 风格语法进行读取和原地(in-place)更新的引用。Ref 并非 Pallas 独有的概念——它们被引入 JAX 是为了表示有状态的计算。但是,我们可以在编写操作可变内存的内核时利用它们。

Pallas 内核不仅接收对应于内核输入的 Ref,还接收对应于输出的 Ref(在 pallas_call 中通过 out_shape 指定)。Ref 是一种特殊类型,在读取之前不能传入通常的 JAX 原语。当您从 Ref 读取时,会得到一个 JAX Array 类型,而您必须将一个 Array 写入到 Ref 中。

Ref 的读取与写入#

Ref 读取对应于将数组加载到内存层级结构的最低层(GPU 上的 L1 缓存,TPU 上的向量寄存器)。写入 Ref 的逻辑也是类似的。

def f(x_ref, o_ref):
  # Using vanilla Python indexing
  x = x_ref[0, 2:5, :]
  # Or via Numpy advanced int indexing
  o_ref[jnp.arange(3), :] = x

# Note that in order to use NumPy advanced int indexing, you need to broadcast the indices against each other into the desired multidimensional shape:
def f(x_ref):
  # Assume x_ref is (8, 4) and we want to read out a (2, 3) slice
  x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]]

Ref 的写入可以通过类似的 __setitem__ 风格索引来完成。

其他形式的索引(例如动态切片)可以通过 pallas.loadpallas.store 完成,这些是旨在简化内存加载/存储的新 JAX 原语。稍后我们将讨论这些新原语。

使用新的 Pallas 原语扩展 JAX#

由于 JAX 在设计时考虑了 HLO,JAX 原语集紧密映射了 HLO 操作集。针对新的编译器(例如 Triton 或 Mosaic)意味着我们可能需要用特定于新编译器的新原语来补充 JAX 原语。同时,我们可能无法 Lowering 所有 JAX 原语,因此需要将其限制为一个子集。

由于 Pallas 最初设计时考虑到了 Triton,我们提供了一组针对 Triton 编程模型的新原语。正如稍后将展示的,我们也可以将这些原语 Lowering 至 Mosaic。

pallas.loadpallas.store#

pallas.loadpallas.store 是允许从内存加载和向内存存储的原语。与 __getitem____setitem__ 不同,它们更灵活,代价是更冗长。具体而言,您可以使用 pallas.dynamic_slice(简称 pallas.ds)构造(或许应该将其上游集成进 JAX,以供 Ref 的 __getitem____setitem__ 使用)。

def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  x = pl.load(x_ref, (0, slice(2, 5), slice(None)))
  # Using integer indexing automatically broadcasts
  x = pl.load(x_ref, (0, 2 + jnp.arange(3), slice(None)))
  # You can also use `pl.dynamic_slice` (`pl.ds` for short) objects as well
  pl.store(o_ref, (0, pl.ds(start=2, size=3), slice(None)), x)

pallas.loadpallas.store 还支持通过 mask 参数进行遮罩处理。

def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  idx = jnp.arange(8)
  mask = idx < 5
  x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf'))

在进行越界加载/存储时,遮罩非常重要。遮罩的操作语义可以由编译器决定(如果我们正确理解了文档,Triton 会在有遮罩的情况下避免对内存进行读写)。

pallas.program_idpallas.num_programs#

正如我们将要看到的,我们将多次执行相同的 Pallas 内核(根据后端不同,可能是并行执行或流水线执行)。这些新原语告诉我们内核在执行过程中的“位置”。

pallas.program_id 接受一个 axis 参数,它告诉我们该内核当前正在多维网格的哪个轴索引中执行(类似于 CUDA 编程中的 threadIdjax.pmap 中的 lax.axis_index)。请注意,我们目前借用了 Triton 的“program”术语,未来可能希望将其更改为 JAX 用户更熟悉的名称。

def f(x_ref, o_ref):
  i = pl.program_id(axis=0)  # execution index in the first axis of the grid
  o_ref[i] = jnp.exp(x_ref[i])

pallas.num_programs 同样接受一个轴并返回该轴的网格大小。

请注意,虽然 program_idnum_programs 是 Triton 特有的术语,但它们很容易被推广,使其在 TPU 上也具有意义。

在 Pallas 中使用 JAX 原语子集#

由于我们编写的是内核,而不是高级 HLO 程序,某些 JAX 原语可能无法在底层基质中高效表示。不过,我们知道我们可以支持大多数逐元素(elementwise)操作、简单的点积以及 JAX 控制流。

虽然我们还没有完全列出所有可以在 Pallas 内核中支持的 JAX 原语,但我们肯定能识别出一些难以 Lowering 或不太有用的原语:

  • conv_general - 卷积通常不作为底层硬件的原语提供。

  • gather/scatter - 底层编译器可能不支持非连续内存的读写。

使用 pallas_call 执行 Pallas 内核#

既然我们已经编写了 Pallas 内核(即带有 Ref 和额外 Pallas 原语的 JAX 代码),我们该如何在 GPU 或 TPU 上执行它们呢?我们使用 pallas_call,这是一个高阶函数(类似于 jax.jitjax.pmap),用于执行内核。

pallas_call 的签名如下:

def pallas_call(
    kernel: Callable,
    out_shape: Sequence[jax.ShapeDtypeStruct],
    *,
    in_specs: Sequence[Spec],
    out_specs: Sequence[Spec],
    grid: Optional[Tuple[int, ...]] = None) -> Callable:
  ...

当我们向 pallas_call 提供内核时,会提供额外的信息。首先是 out_shape,它告诉内核输出的样子(pallas_call 会将对应的 Ref 传入内核以供写入)。其余信息(in_specsout_specsgrid)是关于内核如何在加速器上调度的信息。

pallas_call 的(大致)语义如下:

def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
  def execute(*args):
    outputs = map(empty_ref, out_shape)
    grid_indices = map(range, grid)
    for indices in itertools.product(*grid_indices): # Could run in parallel!
      local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
                      zip(args, in_specs)]
      local_outputs = [out_spec.transform(arg, indices) for arg, out_spec  in
                       zip(outputs, out_specs)]
      kernel(*local_inputs, *local_outputs) # writes to outputs
  return execute

具体来说,pallas_call 会在网格迭代空间上进行“循环”,应用通过 in_specsout_specs 指定的对输入和输出的转换。在每次迭代中,内核将在转换后的输入和输出上被调用。请注意,网格空间上的“循环”可以并行执行(例如在 GPU 上)。pallas_call 也不保证网格空间迭代的顺序,仅保证迭代空间中的每个成员都会被遍历到。像 Triton 和 Mosaic 这样的编译器会关联更具体的网格操作语义。

转换函数#

pallas_callin_specsout_specs 参数允许以某种方式转换输入和输出。Pallas 目前提供的两种选择是恒等转换(输入和输出保持不变)以及 BlockSpec,它根据循环索引获取 Ref 的固定大小切片。

一个 BlockSpec 接受一个 index_map 函数和一个 block_shape。从逻辑上讲,它获取一个数组,并沿着每个轴将其切片为 block_shape 大小的块。index_map 函数获取循环索引(来自网格索引集)并将它们映射到块索引。转换函数将 Ref 转换为对应块中 Ref 的逻辑视图。当我们在 block_shape 中指定 None 时,对应于在那个维度上进行“映射”,并将其从内核内部的块中移除。

class BlockSpec:
  index_map: Callable[[Tuple[Int, ...]], Tuple[Int, ...]]
  block_shape: Tuple[Optional[int], ...]

  def transform(self, ref, *loop_indices):
    block_indices = self.transform_function(loop_indices)
    # Returns a view of `ref` starting at `block_indices` of shape self.block_shape
    ...

我们也可以想象其他与 pallas_call 一起使用的 Spec,例如对应于重叠窗口的 Spec,以便实现卷积等功能。

作为前端,Pallas 的即时收益#

通过提供用于编写内核的 JAX 前端,我们可以立即获得一些收益。

更灵活的前端#

首先,JAX 用户已经习惯了使用 JAX 及其基于追踪的转换进行编程所带来的好处(和局限性)。这意味着用户在编写 Pallas 内核时可以使用闭包和其他熟悉的 Python 构造。这不同于现有的基于 AST 解析的 Triton 前端或 Mosaic 的 MLIR 构建器。例如,这使得 Pallas 比 Triton 更适合模板化。

参考此示例,了解我们如何使用 Python 中的高阶函数来模板化内核。

def make_kernel(eltwise_kernel):
  def add(x_ref, y_ref, o_ref):
    x = pl.load(x_ref, ())
    y = pl.load(y_ref, ())
    pl.store(o_ref, (), eltwise_kernel(x + y))
  return add

kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)

pl.pallas_call(kernel1, out_shape=x, grid=1)(1., 1.)
pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.)

仿真模式#

通过将内核表示为带有 JAX 原语和一些新 Pallas 原语的程序,我们还可以直接将 Pallas 程序 Lowering 至 StableHLO,并使用 XLA 进行编译/执行。具体来说,pallas_call 可以实现为网格上的 lax.scan。这使我们能够在任何支持 XLA 的平台上(甚至 CPU 上!)开发 GPU 或 TPU 内核,并使用 JAX/XLA 调试工具(如 jax.debug.print)进行调试。我们还可以使用更可靠且经过更好测试的 XLA 数值来验证 Triton 和 Mosaic 编译器的正确性。人们甚至可以想象通过扰动 scan 的顺序来模拟 GPU 上发生的并行读写。

GPU 示例#

请注意,以下所有示例仅针对 GPU。若要在 TPU 上工作,需要对块大小进行调整。

add#

我们修改了 add_kernel 示例,使其使用 BlockSpec 在 (2,) 大小的块上进行操作。

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (2,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    in_specs=[
        pl.BlockSpec((2,), lambda i: i),
        pl.BlockSpec((2,), lambda i: i)
    ],
    out_specs=pl.BlockSpec((2,), lambda i: i),
    grid=(4,))
add(x, y)

模板化矩阵乘法#

在此示例中,我们通过对输入数组的行块和列块执行展开累加(unrolled accumulation)来计算输出的瓦片(tiles)。我们使用高阶函数将激活函数内联到内核体中,从而可以发出融合内核(fused kernel)。

def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
  acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
  for k in range(x_ref.shape[1] // block_k):
    x = x_ref[:, k*block_k:(k+1)*block_k]
    y = y_ref[k*block_k:(k+1)*block_k, :]
    acc += x @ y
  o_ref[:, :] = activation(acc).astype(o_ref.dtype)

x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
block_shape = 128, 256, 128

@partial(jax.jit, static_argnames=["block_shape", "activation"])
def matmul(x, y, *, block_shape, activation):
  block_m, block_n, block_k = block_shape
  fused_matmul = pl.pallas_call(
      partial(matmul_kernel, block_k=block_k, activation=activation),
      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
      in_specs=[
          pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)),
          pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j))
      ],
      out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
      grid=(4, 4),
  )
  return fused_matmul(x, y)

z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)

Pallas 的 Lowering#

在用户表达 Pallas 内核后,我们会根据目标后端将其 Lowering 为不同的表示。在 GPU 上,我们将 Pallas Lowering 至 Triton IR;在 TPU 上,我们将 Pallas Lowering 至 Mosaic。

将 Pallas Lowering 至 GPU 的 Triton#

将 Pallas Lowering 至 Triton 很简单,因为 Pallas 的设计初衷就是将 Triton 作为目标语言。Pallas 和 Triton 的主要区别在于,Triton 没有 BlockSpec 的概念,并且在进行内存加载和存储时使用指针,而不是索引。

Triton 在其语言中支持指针作为数组元素类型,并且在 Triton 中,您可以从指针数组加载和存储。在 Pallas 中,当给定一个形状为 (4, 5)Ref (x_ref) 并执行类似 x_ref[3, 2] 的操作时,我们需要将其 Lowering 为计算一个指向 x_ref 中相应行主序位置的 Triton 指针(即执行 5 * 3 + 2 * 1)。同样,当我们对切片进行 Lowering 至 Triton 时,例如 x_ref[4, :],我们需要生成一个指针数组 5 * 4 + jnp.arange(3)

除此之外,Lowering 至 Triton 相当直接。JAX 点积可以 Lowering 为 Triton 点积,JAX 一元原语可以 Lowering 为其对应的 Triton 等价物。Triton 的原子操作通过新的 Pallas 原子原语进行 Lowering。

将 Pallas Lowering 至 TPU 的 Mosaic#

Mosaic(主要)消耗标准的 MLIR 方言并发出 LLO,以编译为 TPU。Pallas 可以通过将 JAX 原语转换成 MLIR(主要是 vectorarith 方言)从而 Lowering 至 Mosaic。BlockSpec 可以转换为流水线调度(即 Mosaic 中的 transform_func)。

转换 Pallas#

一个自然的问题是:JAX 转换如何与 Pallas 内核交互?主要有两种方式:Pallas 内核内部的转换和 Pallas 内核外部的转换。

只要我们能够 Lowering 转换后的代码,Pallas 内核内部的转换应该“直接工作”。例如,我们可以在 JAX 内核中使用 jax.grad(jnp.sin)(...),因为我们可以将 cos Lowering 至 Triton 和 Mosaic。然而,我们可能无法 Lowering jax.vmap(lax.dynamic_slice),因为它可能变成我们无法处理的 gather 操作。

从外部 JAX 程序对 Pallas 内核进行转换可能是一个更有趣的情况。我们如何处理 vmap(pallas_call)grad(pallas_call) 之类的东西?

vmap-of-pallas_call#

vmap 会自动向量化 JAX 程序。虽然内核编写者可能希望精确控制分批内核(batched kernel)如何与其未分批变体表现不同,但我们可以为 pallas_call 提供一个合理的默认 vmap 规则,同时提供 jax.custom_vmap 自定义机制。当 pallas_callvmap 时,我们会扩充 pallas_call 以使其拥有一个对应于新批处理维度的额外网格维度,并转换 BlockSpec 以处理沿该维度的索引。

grad-of-pallas_call#

pallas_callgrad 实现了内核的自动微分。jax.grad 分解为三种不同转换的应用:jvppartial_evaltranspose。原则上,当为 pallas_call 实现这些规则时,我们可以重用大部分 JAX 基础设施(因为它在很大程度上类似于现有的 JAX 高阶原语)。

然而,由于内存访问转置的方式,内核的自动微分可能会导致性能下降。如果我们编写了一个具有重叠并行读取和不相交并行写入的 GPU 内核,我们会自动将其转置为一个具有重叠并行写入(原子方式完成时会很慢)和不相交并行读取的内核。为了发出一个能更好地利用共享内存并行性的内核,我们需要重排循环并改变内核向量化的方式。不幸的是,在 Pallas 中我们没有适合这样做的程序表示。有效自动微分内核的一个潜在方向是探索不同的表示,也许类似于 Dex 中的那种。我们也可以看看 Enzyme 是如何处理这个问题的。不过,Pallas 内核的 AD 对于某些能够高效转置的内核类别(例如逐元素内核)可能仍然有用。

总的来说,jax.custom_vjp 是一个可行的逃生舱,用于表达与 jax.grad 协同工作的 Pallas 内核。

其他转换#

我们可以设想应用到 Pallas 内核上的其他 JAX 转换,我们尚未明确探索。例如,checkify 是一种进行函数式错误处理的 JAX 转换。我们可以设想将 checkifypallas_call 一起使用,以允许从 GPU 内核中输出指示是否产生 OOB(越界)访问或 NaN 的错误代码。

另一个可以集成的潜在转换是 custom_partitioning,以使可自动分区的内核能够与 pjit 一起使用。