流水线#
在本指南中,我们将介绍 TPU 中内存空间的工作原理,以及如何在 Pallas 中编写流水线,使内存 I/O 与计算重叠。
#@title Imports
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
TPU 及其内存空间#
一个 TPU 及其 TensorCore 由内存空间(数组可以驻留的地方)、寄存器(临时存储标量和数组值)和计算单元(使用寄存器中的值进行计算)组成。下图是一个 TPU 的示意图,其中 x
和 y
是驻留在高带宽内存 (HBM) 中的数组
让我们更详细地讨论一下这个图的组成部分
内存空间:TPU 具有高带宽内存 (HBM),我们通常认为它是“设备内存”。还有向量内存 (VMEM),一种用于存储向量和数组值的缓存,以及标量内存 (SMEM),一种设计用于存储标量值的缓存。
寄存器:TensorCore 有两种主要类型的寄存器:向量寄存器 (VREG) 存储数组值,标量寄存器 (SREG) 存储标量值。值可以从它们各自的缓存(VREG 的 VMEM 和 SREG 的 SMEM)加载到内存中。
计算单元:TensorCore 具有标量单元、向量单元 (VPU) 和矩阵单元 (MXU),可以进行数值计算。计算单元对 SREG 和 VREG 中的值进行操作,并将输出值也存储在这些寄存器中。
为了对驻留在 HBM 中的值 x
和 y
进行向量化计算,我们需要
将值
x
和y
复制到 VMEM 中。将值从 VMEM 加载到 VREG 中。
使用 VPU 或 MXU 执行计算,并将输出存储在 VREG 中。
将输出 VREG 中的值存储到 VMEM 中。
将 VMEM 中的输出值复制回 HBM。
让我们实现一个 Pallas 函数来完成这项工作!
def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
# Load x and y from VMEM into VREGs
x_vregs = x_vmem_ref[:, :]
y_vregs = y_vmem_ref[:, :]
# Execute a vectorized add
z_vregs = x_vregs + y_vregs
# Store the output values in VREGs back into VMEM
z_vmem_ref[:, :] = z_vregs
def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
# pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.
# It will then copy `x` and `y` from HBM into VMEM.
z = pl.pallas_call(
add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
# pallas_call will also copy the output from VMEM back into HBM.
return z
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
我们编写了两个函数:add_matrices_kernel
和 add_matrices
。
add_matrices_kernel
使用驻留在 VMEM 中的 Ref
进行操作。从 VMEM Ref
加载会产生一个驻留在 VREG 中的值。VREG 中的值的行为类似于 jax.Array
,因为我们可以使用 jnp
和 jax.lax
操作对它们进行操作,以产生新的驻留在 VREG 中的值。当我们生成想要返回的值时,我们会将它们存储在输出 VMEM Ref
中。
add_matrices
函数作用于 jax.Array
并返回一个 jax.Array
。在函数内部,我们将 x
和 y
传递给 pallas_call
。pallas_call
负责将 x
和 y
复制到 VMEM 中,并负责分配内核操作的 VMEM 缓冲区(包括分配 z_vmem_ref
,输出 VMEM 缓冲区)。内核函数运行完成后,pallas_call
还会将 z_vmem_ref
中的值复制到 HBM,从而产生一个输出 jax.Array
。
使用 VMEM/SMEM 的约束#
Pallas 暴露了对较低级别内存空间(如 VMEM 和 SMEM)的访问,但编写利用它们的内核会增加一些考虑因素。
内存容量。VMEM 和 SMEM 都很**小**!v4 TPU 上的 VMEM 只有 16MiB,SMEM 的范围在几十到几百 KiB。如果我们的数组太大,我们甚至无法将它们全部放入 VMEM 中。作为参考,一个
f32[2048, 2048]
数组为 16MiB,因此我们上面的内核无法扩展到中等大小的数组之外。内存带宽。与大多数计算指令相比,在 HBM 和 VMEM 之间复制需要很长时间。
add_matrices
函数可能会花费更多时间在 HBM 和 VMEM 之间复制,而不是实际执行加法本身。
考虑到这两个约束,我们将不得不重新思考从 TPU 中获得性能的策略。
入门:流水线#
流水线计算提供了一种一次性处理内存容量和带宽约束的方法。我们所说的流水线是什么意思?
目标是:**并行**地在 HBM 和 VMEM 之间复制,**同时**利用我们的计算单元。天真地看,这很困难,因为在我们上面的程序中,我们在开始对 x
和 y
进行任何计算之前,复制了 x
和 y
的**所有**内容,从而在复制和计算之间创建了依赖关系。
但是,如果我们可以将计算分解为几个子计算(例如,当我们添加两个矩阵时,我们可以将其表示为原始矩阵的“块”的加法),我们现在可以将其中一个子计算的复制与另一个子计算的计算重叠。让我们来看一个简单的例子
假设我们将数组 x
和 y
分割为 x1、x2
和 y1、y2
(例如,沿前导轴分割,每个输入产生两个 (256, 512)
数组)。我们现在可以执行以下流水线计算。
将
x1
和y1
复制到 VMEM 中。开始将
x2
和y2
复制到 VMEM 中将
x1、y1
从 VMEM 加载到 VREG 中。使用计算单元执行
z1 = x1 + y1
。将
z1
存储到 VMEM 中。开始将
z1
从 VMEM 复制回 HBM。等待直到
x2、y2
已复制到 VMEM 中。将
x2、y2
从 VMEM 加载到 VREG 中。使用计算单元执行
z2 = x2 + y2
。将
z2
存储到 VMEM 中。等待直到
z1
复制到 HBM 中。开始将
z2
从 VMEM 复制回 HBM。等待直到
z2
复制到 HBM 中。
任何时候我们在这里进行计算,我们都在异步复制某些东西。这意味着一些复制时间不会被浪费。
用于确定流水线计算效率的两个最重要的数字是 a) 我们需要执行多少浮点运算 (FLOP) 和 b) 我们需要复制多少字节才能执行该计算。这两者之比(FLOP/内存使用量)称为操作的**算术强度**,并确定我们的流水线是受计算限制还是受内存限制。
Pallas 中的流水线#
我们如何在 Pallas 中实现像上面这样的流水线?这似乎是一个复杂的异步数据操作序列和执行内核,手动实现起来会很痛苦。别担心!Pallas 提供了一个 API 来表达流水线,而无需太多样板代码,即通过 grid
和 BlockSpec
。
请注意,在上面的流水线示例中,我们多次执行相同的逻辑:步骤 3-5 和 8-10 都执行相同的操作,只是输入不同。jax.experimental.pallas.pallas_call()
提供了一种通过使用 grid
参数多次执行内核的方法。请参阅 网格,又名循环中的内核。
我们还使用 jax.experimental.pallas.BlockSpec
来指定如何构造每个内核调用的输入。请参阅 BlockSpec,又名如何将输入分块。
在上面的流水线示例中,我们有 (512, 512)
形状的数组,并沿前导维度将它们拆分为两个 (256, 512)
形状的数组。在此流水线中,我们的 BlockSpec.block_shape
将为 (256, 512)
。在第一次迭代中,我们希望选择 x1
,在第二次迭代中,我们希望使用 x2
。这可以使用以下 index_map
表示
def x_index_map(i):
return (i, 0)
然后我们将构造 BlockSpec
block_spec = pl.BlockSpec((256, 512), x_index_map)
y
和 z
的 BlockSpec
将与 x
的相同。
整合在一起#
我们通过 grid
、in_specs
和 out_specs
将这些参数提供给 pallas_call
(in_specs
对应于位置参数的元组,out_specs
对应于输出)。
def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,)
)(x, y)
add_matrices_pipelined(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
我们只在原始函数中添加了一点代码来添加自动流水线,但 BlockSpec
和 grid
完成了很多繁重的工作!
它是如何工作的?好吧,BlockSpec
提供了足够的信息来开始从 HBM **预取**输入的块到 VMEM 中。例如,如果我们正在启动 grid
的迭代 i
,我们可以将 i + 1
传递到 index_map
函数中,以获取下一次迭代所需的块。然后我们可以启动这些块的异步复制。与输出类似,我们可以等待上一次迭代的输出被复制,然后再启动当前迭代输出的复制。
参数化流水线#
在我们的内核中参数化块形状是很常见的。块大小可能是优化 Pallas 内核性能时最重要的参数!它们让我们控制流水线(例如,选择较小的块会为我们的流水线循环添加更多迭代,其中每次迭代的工作量较少)。
此外,我们还可以沿第二个维度划分输入和输出(我们现在只沿第一个维度分割)。让我们编写一个更通用的内核来处理这两个特性。
def add_matrices_pipelined_2d(
x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
m, n = x.shape
block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(m // bm, n // bn),
)(x, y)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
)
处理归约#
你将如何使用 pallas_call
实现类似 jnp.sum
的功能?具体来说,我们希望在归约维度上进行流水线操作。
以将 (8, 512, 512)
形状的数组归约为 (512, 512)
形状的数组为例。
x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
为了使用 pallas_call
执行此操作,我们可以使用大小为 (8,)
的网格,并在每次迭代 i
中将 x[i]
加载到 VMEM 中。然后我们可以将 x[i]
添加到输出 VMEM 缓冲区。让我们先天真地实现这一点。
# Warning: this implementation is incorrect!
def naive_sum_kernel(x_ref, o_ref):
o_ref[...] += x_ref[...]
def naive_sum(x: jax.Array) -> jax.Array:
grid, *out_shape = x.shape
return pl.pallas_call(
naive_sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
)(x)
naive_sum(x)
Array([[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
...,
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.]], dtype=float32)
请注意我们是如何设置 BlockSpec
的:我们正在将 (512, 512)
维度的全部内容加载到 VMEM 中(那里没有流水线),但在 index_map
中每次迭代都选择 x
的第 i
个维度。我们在块形状中对该维度使用 None
,这表示我们正在从 x
中选择一个单例维度,我们希望在内核中将其挤压掉。因此,x_ref
在 VMEM 中也是 (512, 512)
形状的。
out_spec
使用 lambda i: (0, 0)
作为其 index_map
,表示 o_ref
在流水线的过程中保持不变。这意味着我们可以通过读取和写入来在每次迭代中更新其值。或者可以吗?实际上有一个陷阱:**o_ref
最初是垃圾**,这意味着我们将累积到垃圾中。这将导致整个函数输出不正确的值!
因此,**每当我们在内核中进行归约时,我们需要确保初始化存储归约值的 Ref
**。我们可以通过在迭代 0 时有条件地将值写入 out_ref
来完成此操作。我们可以使用辅助函数 pl.when
(jax.lax.cond
的便捷包装器)和 pl.program_id
(查询我们在网格轴中的哪个迭代中)来完成此操作。
def sum_kernel(x_ref, o_ref):
@pl.when(pl.program_id(axis=0) == 0)
def _():
o_ref[...] = jnp.zeros_like(o_ref)
o_ref[...] += x_ref[...]
def sum(x: jax.Array) -> jax.Array:
grid, *out_shape = x.shape
return pl.pallas_call(
sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
)(x)
sum(x)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
这个 sum
函数现在输出正确的值了!
关于 Pallas 中的归约,最后需要注意的一点是,**它们必须在网格的最小维度(最右侧维度)中完成**(在上面的示例中,我们的网格是一维的,因此我们正在对其最小维度进行归约)。这是因为 Pallas 使用 BlockSpec
、grid
和内核函数生成的流水线**不会从 HBM 中读取输出**。一旦你将输出值写回 HBM,你就无法再次访问它。因此,你无法跨任何具有重新访问的网格维度进行归约,因此所有归约都需要在最右侧维度中进行。
Megacore 配置中的 TPU#
一些 TPU 芯片有两个 TensorCore,但在 JAX 用户看来,它们就像一个设备。这称为 “megacore”。单独的 TensorCore 有其自己单独的 VMEM、VREG、SMEM、SREG 和计算单元,但**共享 HBM**。
从概念上讲,Megacore 中的 TPU 的行为类似于非常简单的 GPU,即它们只有两个线程。我们如何修改内核以同时利用两个 TensorCore?
基本思想是,如果我们的计算中存在易于并行化的维度,我们可以将这些维度分散到 TensorCore 上。我们可以通过向 pallas_call
提供一个名为 dimension_semantics
的注解来指示哪些维度是可并行化的。
def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,),
compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",))
)(x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
dimension_semantics
应该是一个与 grid
长度相同的元组,其中每个条目是 "parallel"
或 "arbitrary"
。"parallel"
向 Pallas 表明,对应于该维度的 for 循环的迭代可以独立执行,而不会影响程序的正确性。"arbitrary"
向 Pallas 表明,对于此网格维度不能做任何假设,因此它不能被并行化。
通过指定 dimension_semantics
,我们现在在每个 TensorCore 上同时执行内核。 Pallas 将自动处理网格的分割。
请注意,Megacore 目前仅在 TPU
v4
和 TPUv5p
上可用。 在其他平台上提供dimension_semantics
注解不起作用,但是*不*指定它将导致仅使用一个 TensorCore(即使有多个可用)。
结论#
在本指南中,我们介绍了如何使用 pallas_call
、grid
和 BlockSpec
来表达 TPU 管道。 我们介绍了如何通过多维网格来表达嵌套循环,以及如何通过在归约开始时初始化累加器来处理归约。 我们还学习了如何通过向内核添加注解来处理 Megacore。
留给读者的练习
尝试实现一个
sum
内核,该内核也对其他维度进行流水线处理向
add
内核和sum
内核添加 megacore 支持。