Pallas 快速入门#
Pallas 是 JAX 的一个扩展,它支持为 GPU 和 TPU 编写自定义内核。Pallas 允许您使用相同的 JAX 函数和 API,但操作的抽象级别更低。
具体来说,Pallas 要求用户考虑内存访问以及如何将计算划分为硬件加速器上的多个计算单元。在 GPU 上,Pallas 会降低到 Triton;在 TPU 上,Pallas 会降低到 Mosaic。
让我们通过一些例子来深入了解。
注意:Pallas 仍是一个实验性 API,更改可能会破坏您的代码!
Pallas 中的 Hello World#
from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
我们将首先用 Pallas 编写“Hello World”,这是一个将两个向量相加的内核。
def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y
Ref
类型
让我们稍微剖析一下这个函数。与您可能编写过的大多数 JAX 函数不同,它不接受 jax.Array
作为输入,也不返回任何值。相反,它接受Ref
对象作为输入,这些对象表示内存中的可变缓冲区。请注意,我们也没有任何输出,但我们得到了一个 o_ref
,它对应于期望的输出。
从 Ref
读取
在函数体中,我们首先从 x_ref
和 y_ref
读取,这由 [...]
表示(省略号表示我们正在读取整个 Ref
;或者我们也可以使用 x_ref[:]
)。像这样从 Ref
读取会返回一个 jax.Array
。
写入 Ref
然后我们将 x + y
写入 o_ref
。突变在历史上并不受 JAX 支持——jax.Array
是不可变的!Ref
是新的(实验性)类型,在某些情况下允许突变。我们可以将写入 Ref
理解为修改其底层缓冲区。
所以我们已经编写了我们称之为“内核”的代码,我们将其定义为在加速器上作为原子执行单元运行的程序,而无需与主机进行任何交互。我们如何从 JAX 计算中调用它?我们使用 pallas_call
高阶函数。
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(
add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
pallas_call
将 Pallas 内核函数提升为一个可以作为更大 JAX 程序的一部分调用的操作。但是,为了做到这一点,它需要一些额外的细节。这里我们指定 out_shape
,这是一个具有 .shape
和 .dtype
(或其列表)的对象。out_shape
决定了我们 add_vector_kernel
中 o_ref
的形状/数据类型。
pallas_call
返回一个接受和返回 jax.Array
的函数。
这里实际发生了什么?
到目前为止,我们描述了如何考虑 Pallas 内核,但我们实际完成的是编写一个非常接近计算单元执行的函数,因为值被加载到内存层次结构中最内层(最快)的部分。
在 GPU 上,x_ref
对应于高带宽内存 (HBM) 中的值,当我们执行 x_ref[...]
时,我们将值从 HBM 复制到静态 RAM (SRAM)(一般来说,这是一个代价高昂的操作!)。然后我们使用 GPU 向量计算执行加法,然后将 SRAM 中的结果值复制回 HBM。
在 TPU 上,我们做了一些不同的事情。在内核执行之前,我们将值从 HBM 提取到 SRAM。x_ref
因此对应于 SRAM 中的值,当我们执行 x_ref[...]
时,我们正在将值从 SRAM 复制到寄存器。然后我们使用 TPU 向量计算执行加法,然后将结果值复制回 SRAM。内核执行后,SRAM 中的值将被复制回 HBM。
我们正在编写特定于后端的 Pallas 指南。即将推出!
Pallas 编程模型#
在我们的“Hello World”示例中,我们编写了一个非常简单的内核。它利用了我们的 8 个大小的数组可以舒适地放入硬件加速器的 SRAM 中这一事实。在大多数实际应用中,情况并非如此!
编写 Pallas 内核的一部分是考虑如何处理存在于高带宽内存 (HBM,也称为 DRAM) 中的大数组,并表达在可以放入 SRAM 的数组“块”上运行的计算。
Grid 示例#
为了自动“分割”输入和输出,您需要向 pallas_call
提供一个 grid
和 BlockSpec
s。
Grid 是一个整数元组(例如 ()
、(2, 3, 4)
或 (8,)
),它指定了一个迭代空间。例如,一个 grid (4, 5)
将有 20 个元素:(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)
。我们为每个元素运行一次内核函数,这是一种单程序多数据 (SPMD) 编程风格。
一个二维 Grid
当我们将 grid
提供给 pallas_call
时,内核将执行 prod(grid)
次。每次调用都被称为一个“程序”。要访问内核当前正在执行的程序(即 grid 的哪个元素),我们使用 program_id(axis=...)
。例如,对于调用 (1, 2)
,program_id(axis=0)
返回 1
,program_id(axis=1)
返回 2
。
这是一个使用 grid
和 program_id
的示例内核。
def iota_kernel(o_ref):
i = pl.program_id(0)
o_ref[i] = i
现在我们使用带有附加 grid
参数的 pallas_call
来执行它。在 GPU 上,我们可以直接这样调用内核:
# GPU version
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
TPU 会区分向量和标量内存空间,在这种情况下,输出必须放置在标量内存 (MemorySpace.SMEM
) 中,因为 i
是一个标量。有关更多详细信息,请阅读 TPU 及其内存空间。要在 TPU 上调用上述内核,请运行:
# TPU version
from jax.experimental.pallas import tpu as pltpu
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
Grid 语义#
在 GPU 上,每个程序在不同的线程上并行执行。因此,我们需要考虑对 HBM 写入的竞争条件。一个合理的方法是这样编写我们的内核,即不同的程序写入 HBM 中不重叠的位置,以避免这些并行写入。另一方面,并行化计算是我们能够快速执行矩阵乘法等操作的方式。
相反,TPU 的工作方式类似于一个非常宽的 SIMD 机器。某些 TPU 型号包含多个核心,但在许多情况下,TPU 可以被视为一个单线程处理器。TPU 上的 grid 可以结合并行和顺序维度进行指定,其中顺序维度保证按顺序运行。
您可以在 Grid,即循环中的内核 和 值得注意的属性和限制 中阅读更多详细信息。
BlockSpec 示例#
考虑到 grid
和 program_id
,Pallas 提供了一个抽象,它处理了许多内核中常见的某些索引模式。为了建立直觉,让我们尝试实现一个矩阵乘法。
在 Pallas 中实现矩阵乘法的一个简单策略是递归实现。我们知道底层硬件支持小型矩阵乘法(使用 GPU 和 TPU 的张量核心),因此我们只需将大型矩阵乘法表示为小型矩阵乘法的组合。
假设我们有输入矩阵 \(X\) 和 \(Y\),并且正在计算 \(Z = XY\)。我们首先将 \(X\) 和 \(Y\) 表示为分块矩阵。\(X\) 将有“行”块,\(Y\) 将有“列”块。
我们的策略是,因为 \(Z\) 也是一个分块矩阵,我们可以将 Pallas 内核中的每个程序分配给一个输出块。计算每个输出块对应于执行“行”块 \(X\) 和“列”块 \(Y\) 之间的较小矩阵乘法。
为了表达这种模式,我们使用 BlockSpec
s。一个 BlockSpec
为每个输入和输出指定一个块形状,以及一个“索引映射”函数,该函数将一组程序索引映射到块索引。
BlockSpec 的可视化
为了举一个具体的例子,假设我们要将两个 (1024, 1024)
矩阵 x
和 y
相乘得到 z
,并且我们希望计算并行化 4 次。我们将 z
分割成 4 个 (512, 512)
的块,每个块通过 (512, 1024) x (1024, 512)
的矩阵乘法计算。为了表达这一点,我们首先使用 (2, 2)
的 grid(每个程序一个块)。
对于 x
,我们使用 BlockSpec((512, 1024), lambda i, j: (i, 0))
– 这会将 x
分割成“行”块。要理解这一点,请看程序实例 (1, 0)
和 (1, 1)
如何在 x
中选择 (1, 0)
块。对于 y
,我们使用转置版本 BlockSpec((1024, 512), lambda i, j: (0, j))
。最后,对于 z
,我们使用 BlockSpec((512, 512), lambda i, j: (i, j))
。
这些 BlockSpec
s 通过 in_specs
和 out_specs
传递到 pallas_call
。
有关 BlockSpec
s 的更多详细信息,请参阅 BlockSpec,即如何分割输入。
在底层,pallas_call
会自动将您的输入和输出分割成传递给内核的每个块的 Ref
s。
def matmul_kernel(x_ref, y_ref, z_ref):
z_ref[...] = x_ref[...] @ y_ref[...]
def matmul(x: jax.Array, y: jax.Array):
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
)
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y)
请注意,这是一个非常朴素的矩阵乘法实现,但可以将其视为各种优化类型的起点。让我们为矩阵乘法添加一个附加功能:融合激活。实际上非常简单!只需将一个高阶激活函数传递到内核中。
def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
z_ref[...] = activation(x_ref[...] @ y_ref[...])
def matmul(x: jax.Array, y: jax.Array, *, activation):
return pl.pallas_call(
partial(matmul_kernel, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)
),
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
np.testing.assert_allclose(z, jax.nn.relu(x @ y))
最后,让我们强调一个 Pallas 的炫酷功能:它可以与 jax.vmap
组合!要将此矩阵乘法转换为批处理版本,我们只需要对其进行 vmap
即可。
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))