Pallas 快速入门#
Pallas 是 JAX 的一个扩展,支持为 GPU 和 TPU 编写自定义内核。Pallas 允许您使用相同的 JAX 函数和 API,但在一个更低的抽象级别上运行。
具体来说,Pallas 要求用户考虑内存访问以及如何在硬件加速器中的多个计算单元之间划分计算。在 GPU 上,Pallas 降级到 Triton;在 TPU 上,Pallas 降级到 Mosaic。
让我们深入了解一些示例。
注意:Pallas 仍是一个实验性 API,可能会因更改而中断!
Pallas 中的 Hello World#
from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
我们首先在 Pallas 中编写“hello world”,这是一个用于添加两个向量的内核。
def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y
Ref
类型
让我们稍微分析一下这个函数。与您可能编写过的大多数 JAX 函数不同,它不接受 jax.Array
作为输入,也不返回任何值。相反,它接受 *Ref
* 对象作为输入,这些对象表示内存中的可变缓冲区。请注意,我们也没有任何输出,但我们得到了一个 o_ref
,它对应于所需的输出。
从 Ref
读取
在函数体中,我们首先从 x_ref
和 y_ref
读取,由 [...]
表示(省略号表示我们正在读取整个 Ref
;或者我们也可以使用 x_ref[:]
)。以这种方式从 Ref
读取会返回一个 jax.Array
。
写入 Ref
然后我们将 x + y
写入 o_ref
。JAX 历来不支持变异——jax.Array
是不可变的!Ref
是新的(实验性)类型,允许在某些情况下进行变异。我们可以将写入 Ref
理解为修改其底层缓冲区。
因此,我们编写了一个所谓的“内核”,我们将其定义为一个将在加速器上作为原子执行单元运行而无需与主机进行任何交互的程序。我们如何从 JAX 计算中调用它?我们使用 pallas_call
高阶函数。
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(
add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
pallas_call
将 Pallas 内核函数提升为一个操作,可以作为更大 JAX 程序的一部分调用。但是,为此,它需要更多细节。这里我们指定 out_shape
,它是一个具有 .shape
和 .dtype
的对象(或它们的列表)。out_shape
决定了我们 add_vector_kernel
中 o_ref
的形状/数据类型。
pallas_call
返回一个接受并返回 jax.Array
的函数。
这里实际发生了什么?
到目前为止,我们已经描述了如何思考 Pallas 内核,但我们实际完成的是编写一个函数,该函数在离计算单元非常近的地方执行,因为值被加载到内存层次结构中最内部(最快)的部分。
在 GPU 上,x_ref
对应于高带宽内存 (HBM) 中的一个值,当我们执行 x_ref[...]
时,我们正在将值从 HBM 复制到静态随机存取存储器 (SRAM) 中(通常来说,这是一项代价高昂的操作!)。然后我们使用 GPU 向量计算执行加法,再将 SRAM 中得到的值复制回 HBM。
在 TPU 上,我们做了一些稍微不同的事情。在内核执行之前,我们先将值从 HBM 取入 SRAM。x_ref
因此对应于 SRAM 中的一个值,当我们执行 x_ref[...]
时,我们正在将值从 SRAM 复制到一个寄存器中。然后我们使用 TPU 向量计算执行加法,再将得到的值复制回 SRAM。内核执行后,SRAM 值被复制回 HBM。
我们正在编写针对后端特性的 Pallas 指南。敬请期待!
Pallas 编程模型#
在我们的“hello world”示例中,我们编写了一个非常简单的内核。它利用了我们 8 大小数组可以轻松适应硬件加速器 SRAM 的事实。在大多数实际应用中,情况并非如此!
编写 Pallas 内核的一部分是思考如何处理存在于高带宽内存(HBM,也称为 DRAM)中的大型数组,并表达对这些数组的“块”进行操作的计算,这些“块”可以适应 SRAM。
网格示例#
为了自动“划分”输入和输出,您需要向 pallas_call
提供一个 grid
和 BlockSpec
。
一个 grid
是一个整数元组(例如 ()
, (2, 3, 4)
或 (8,)
),它指定了一个迭代空间。例如,一个 (4, 5)
的网格将有 20 个元素:(0, 0), (0, 1), ...,
(0, 4),
(1, 0), ...,
(3, 4)
。我们为每个元素运行一次内核函数,这是一种单程序多数据 (SPMD) 编程风格。
一个二维网格
当我们向 pallas_call
提供一个 grid
时,内核会执行 prod(grid)
次。这些调用中的每一个都称为一个“程序”。要访问内核当前正在执行哪个程序(即网格的哪个元素),我们使用 program_id(axis=...)
。例如,对于调用 (1, 2)
,program_id(axis=0)
返回 1
,而 program_id(axis=1)
返回 2
。
这是一个使用 grid
和 program_id
的内核示例。
def iota_kernel(o_ref):
i = pl.program_id(0)
o_ref[i] = i
我们现在使用带有附加 grid
参数的 pallas_call
来执行它。在 GPU 上,我们可以直接调用内核,如下所示
# GPU version
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
TPU 区分向量和标量内存空间,在这种情况下,输出必须放在标量内存中 (MemorySpace.SMEM
),因为 i
是一个标量。更多细节请阅读 TPU 及其内存空间。要在 TPU 上调用上述内核,请运行
# TPU version
from jax.experimental.pallas import tpu as pltpu
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
网格语义#
在 GPU 上,每个程序都在单独的线程上并行执行。因此,我们需要考虑写入 HBM 时的竞态条件。一种合理的方法是,编写内核时,让不同的程序写入 HBM 中不相交的位置,以避免这些并行写入。另一方面,并行化计算是我们能够非常快速地执行矩阵乘法等操作的方式。
相比之下,TPU 运行起来像一个非常宽的 SIMD 机器。一些 TPU 模型包含多个核心,但在许多情况下,TPU 可以被视为一个单线程处理器。TPU 上的网格可以指定为并行和顺序维度的组合,其中顺序维度保证串行运行。
您可以在 网格,又名循环中的内核 和 值得注意的属性和限制 中阅读更多详细信息。
块规范示例#
考虑到 grid
和 program_id
,Pallas 提供了一个抽象,可以处理许多内核中常见的索引模式。为了建立直观理解,我们来尝试实现一个矩阵乘法。
在 Pallas 中实现矩阵乘法的一个简单策略是递归实现。我们知道底层硬件支持小型矩阵乘法(使用 GPU 和 TPU 张量核心),所以我们只需将大型矩阵乘法表达为较小矩阵乘法的组合。
假设我们有输入矩阵 \(X\) 和 \(Y\),并且正在计算 \(Z = XY\)。我们首先将 \(X\) 和 \(Y\) 表示为块矩阵。\(X\) 将具有“行”块,而 \(Y\) 将具有“列”块。
我们的策略是,因为 \(Z\) 也是一个块矩阵,我们可以将 Pallas 内核中的每个程序分配给一个输出块。计算每个输出块对应于在 \(X\) 的“行”块和 \(Y\) 的“列”块之间进行一次较小的矩阵乘法。
为了表达这种模式,我们使用 BlockSpec
。一个 BlockSpec
为每个输入和输出指定一个块形状,以及一个“索引映射”函数,该函数将一组程序索引映射到一个块索引。
BlockSpec
的可视化
举一个具体例子,假设我们想将两个 (1024, 1024)
矩阵 x
和 y
相乘得到 z
,并希望将计算并行化为 4 种方式。我们将 z
分成 4 个 (512, 512)
块,每个块都通过一个 (512, 1024) x (1024, 512)
的矩阵乘法进行计算。为了表达这一点,我们首先会使用一个 (2, 2)
的网格(每个程序一个块)。
对于 x
,我们使用 BlockSpec((512, 1024), lambda i, j: (i, 0))
——这会将 x
划分为“行”块。要理解这一点,请看程序实例 (1, 0)
和 (1, 1)
如何都选择 x
中的 (1, 0)
块。对于 y
,我们使用转置版本 BlockSpec((1024, 512), lambda i, j: (0, j))
。最后,对于 z
,我们使用 BlockSpec((512, 512), lambda i, j: (i, j))
。
这些 BlockSpec
通过 in_specs
和 out_specs
传递给 pallas_call
。
有关 BlockSpec
的更多细节,请参阅 BlockSpec,又名如何分块输入。
在底层,pallas_call
将自动把您的输入和输出分割成 Ref
,每个块都将作为参数传递给内核。
def matmul_kernel(x_ref, y_ref, z_ref):
z_ref[...] = x_ref[...] @ y_ref[...]
def matmul(x: jax.Array, y: jax.Array):
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
)
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y)
请注意,这只是一个非常简单的矩阵乘法实现,但请将其视为各种优化的起点。让我们为矩阵乘法添加一个额外功能:融合激活。这实际上非常容易!只需将一个高阶激活函数传递给内核即可。
def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
z_ref[...] = activation(x_ref[...] @ y_ref[...])
def matmul(x: jax.Array, y: jax.Array, *, activation):
return pl.pallas_call(
partial(matmul_kernel, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)
),
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
np.testing.assert_allclose(z, jax.nn.relu(x @ y))
最后,让我们强调 Pallas 的一个很酷的特性:它与 jax.vmap
结合使用!要将此矩阵乘法转换为批处理版本,我们只需对其进行 vmap
操作即可。
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))