Pallas 快速入门#
Pallas 是 JAX 的一个扩展,支持为 GPU 和 TPU 编写自定义内核。Pallas 允许您使用相同的 JAX 函数和 API,但在更低的抽象级别上运行。
具体来说,Pallas 要求用户考虑内存访问以及如何在硬件加速器中的多个计算单元之间分配计算。在 GPU 上,Pallas 降低到 Triton,而在 TPU 上,Pallas 降低到 Mosaic。
让我们深入研究一些示例。
注意:Pallas 仍然是一个实验性 API,您可能会因更改而中断!
Pallas 中的 Hello world#
from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
我们将首先用 Pallas 编写 “hello world”,这是一个添加两个向量的内核。
def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y
Ref
类型
让我们剖析一下这个函数。与您可能编写的大多数 JAX 函数不同,它不接受 jax.Array
作为输入,也不返回任何值。相反,它接受 Ref
对象作为输入,这些对象表示内存中的可变缓冲区。请注意,我们也没有任何输出,但我们得到了一个 o_ref
,它对应于所需的输出。
从 Ref
读取
在主体中,我们首先从 x_ref
和 y_ref
读取,用 [...]
表示(省略号表示我们正在读取整个 Ref
;或者我们也可以使用 x_ref[:]
)。像这样从 Ref
读取会返回一个 jax.Array
。
写入 Ref
然后我们将 x + y
写入 o_ref
。jax.Array
是不可变的!Ref
是新的(实验性)类型,允许在某些情况下进行变异。我们可以将写入 Ref
解释为变异其底层缓冲区。
因此,我们编写了一个所谓的 “内核”,我们将其定义为一个程序,它将作为加速器上的原子执行单元运行,而无需与主机进行任何交互。我们如何从 JAX 计算中调用它?我们使用 pallas_call
高阶函数。
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(
add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
pallas_call
将 Pallas 内核函数提升为一个可以作为更大 JAX 程序一部分调用的操作。但是,要做到这一点,它需要更多细节。在这里,我们指定 out_shape
,一个具有 .shape
和 .dtype
的对象(或其列表)。out_shape
决定了我们 add_vector_kernel
中 o_ref
的形状/数据类型。
pallas_call
返回一个函数,该函数接受和返回 jax.Array
。
这里实际发生了什么?
到目前为止,我们已经描述了如何思考 Pallas 内核,但我们实际完成的是我们正在编写一个非常接近计算单元执行的函数,因为值被加载到内存层次结构的最内层(最快)部分。
在 GPU 上,x_ref
对应于高带宽内存 (HBM) 中的一个值,当我们执行 x_ref[...]
时,我们将该值从 HBM 复制到静态 RAM (SRAM)(一般来说,这是一个代价高昂的操作!)。然后我们使用 GPU 向量计算来执行加法,然后将 SRAM 中的结果值复制回 HBM。
在 TPU 上,我们做一些稍微不同的事情。在内核执行之前,我们将值从 HBM 获取到 SRAM。x_ref
因此对应于 SRAM 中的一个值,当我们执行 x_ref[...]
时,我们将该值从 SRAM 复制到寄存器。然后我们使用 TPU 向量计算来执行加法,然后将结果值复制回 SRAM。内核执行后,SRAM 值被复制回 HBM。
我们正在编写特定于后端的 Pallas 指南。即将推出!
Pallas 编程模型#
在我们的 “hello world” 示例中,我们编写了一个非常简单的内核。它利用了我们的 8 大小的数组可以舒适地容纳在硬件加速器的 SRAM 中的事实。在大多数实际应用中,情况并非如此!
编写 Pallas 内核的一部分是思考如何获取驻留在高带宽内存(HBM,也称为 DRAM)中的大型数组,并表达对可以容纳在 SRAM 中的这些数组的 “块” 进行操作的计算。
网格示例#
为了自动 “划分” 输入和输出,您需要为 pallas_call
提供 grid
和 BlockSpec
。
grid
是一个整数元组(例如 ()
、(2, 3, 4)
或 (8,)
),用于指定迭代空间。例如,网格 (4, 5)
将有 20 个元素:(0, 0), (0, 1), ..., (0, 4), (1, 0), ..., (3, 4)
。我们为每个元素运行一次内核函数,这是一种单程序多数据 (SPMD) 编程风格。
二维网格
当我们为 pallas_call
提供 grid
时,内核将执行 prod(grid)
次。这些调用中的每一个都称为 “程序”。要访问内核当前正在执行哪个程序(即网格的哪个元素),我们使用 program_id(axis=...)
。例如,对于调用 (1, 2)
,program_id(axis=0)
返回 1
,而 program_id(axis=1)
返回 2
。
这是一个使用 grid
和 program_id
的示例内核。
def iota_kernel(o_ref):
i = pl.program_id(0)
o_ref[i] = i
我们现在使用带有附加 grid
参数的 pallas_call
执行它。在 GPU 上,我们可以像这样直接调用内核
# GPU version
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
TPU 区分向量和标量内存空间,在这种情况下,输出必须放置在标量内存 (TPUMemorySpace.SMEM
) 中,因为 i
是标量。有关更多详细信息,请阅读 TPU 及其内存空间。要在 TPU 上调用上述内核,请运行
# TPU version
from jax.experimental.pallas import tpu as pltpu
def iota(size: int):
return pl.pallas_call(iota_kernel,
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
网格语义#
在 GPU 上,每个程序都在单独的线程上并行执行。因此,我们需要考虑写入 HBM 的竞争条件。一种合理的方法是以这样一种方式编写内核,即不同的程序写入 HBM 中的不相交位置,以避免这些并行写入。另一方面,并行化计算是我们能够非常快速地执行矩阵乘法等操作的方式。
相比之下,TPU 的运行方式类似于非常宽的 SIMD 机器。一些 TPU 模型包含多个内核,但在许多情况下,TPU 可以被视为单线程处理器。TPU 上的网格可以在并行和顺序维度的组合中指定,其中顺序维度保证串行运行。
您可以在 网格,又名循环中的内核 和 值得注意的属性和限制 中阅读更多详细信息。
块规范示例#
考虑到 grid
和 program_id
,Pallas 提供了一个抽象,可以处理在许多内核中看到的一些常见索引模式。为了建立直觉,让我们尝试实现矩阵乘法。
在 Pallas 中实现矩阵乘法的一个简单策略是递归地实现它。我们知道我们的底层硬件支持小型矩阵乘法(使用 GPU 和 TPU 张量核心),因此我们只需用较小的矩阵乘法来表示大型矩阵乘法。
假设我们有输入矩阵 \(X\) 和 \(Y\),并且正在计算 \(Z = XY\)。我们首先将 \(X\) 和 \(Y\) 表示为块矩阵。\(X\) 将具有 “行” 块,而 \(Y\) 将具有 “列” 块。
我们的策略是,因为 \(Z\) 也是一个块矩阵,我们可以为 Pallas 内核中的每个程序分配一个输出块。计算每个输出块对应于在 \(X\) 的 “行” 块和 \(Y\) 的 “列” 块之间进行较小的矩阵乘法。
为了表达这种模式,我们使用 BlockSpec
。BlockSpec
为每个输入和输出指定一个块形状,以及一个 “索引映射” 函数,该函数将一组程序索引映射到块索引。
BlockSpec
的可视化
对于一个具体的例子,假设我们要将两个 (1024, 1024)
矩阵 x
和 y
相乘得到 z
,并希望以 4 种方式并行化计算。我们将 z
分成 4 个 (512, 512)
块,其中每个块都使用 (512, 1024) x (1024, 512)
矩阵乘法计算。为了表达这一点,我们首先使用一个 (2, 2)
网格(每个程序一个块)。
对于 x
,我们使用 BlockSpec((512, 1024), lambda i, j: (i, 0))
– 这将 x
划分成 “行” 块。要看到这一点,请查看程序实例 (1, 0)
和 (1, 1)
如何在 x
中选择 (1, 0)
块。对于 y
,我们使用转置版本 BlockSpec((1024, 512), lambda i, j: (0, j))
。最后,对于 z
,我们使用 BlockSpec((512, 512), lambda i, j: (i, j))
。
这些 BlockSpec
通过 in_specs
和 out_specs
传递到 pallas_call
中。
有关 BlockSpec
的更多详细信息,请参阅 BlockSpec,又名如何将输入分块。
在底层,pallas_call
将自动将您的输入和输出划分为 Ref
,用于将传递到内核的每个块。
def matmul_kernel(x_ref, y_ref, z_ref):
z_ref[...] = x_ref[...] @ y_ref[...]
def matmul(x: jax.Array, y: jax.Array):
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
)
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y)
请注意,这是矩阵乘法的一个非常幼稚的实现,但请将其视为各种优化的起点。让我们为矩阵乘法添加一个额外的功能:融合激活。实际上非常容易!只需将高阶激活函数传递到内核中即可。
def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
z_ref[...] = activation(x_ref[...] @ y_ref[...])
def matmul(x: jax.Array, y: jax.Array, *, activation):
return pl.pallas_call(
partial(matmul_kernel, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
],
out_specs=pl.BlockSpec(
(x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)
),
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
np.testing.assert_allclose(z, jax.nn.relu(x @ y))
总而言之,让我们强调 Pallas 的一个很酷的功能:它与 jax.vmap
组合!要将此矩阵乘法转换为批处理版本,我们只需要 vmap
它。
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))