TPU 流水线#
本指南旨在为 TPU 特有的流水线问题提供参考。我们将回顾 TPU 上的内存层次结构和计算单元,以及流水线 API 中 TPU 特有的功能。有关流水线操作的更通用概述,请参阅软件流水线。
#@title Imports
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
TPU 及其内存空间#
TPU 及其 TensorCore 包含内存空间(数组驻留的地方)、寄存器(临时存储标量和数组值)和计算单元(对寄存器中的值进行计算)。以下是 TPU 的示意图,其中 x
和 y
是位于高带宽内存(HBM)中的数组
让我们更详细地讨论此图的组成部分
内存空间:TPU 具有高带宽内存(HBM),通常我们将其视为“设备内存”。还有向量内存(VMEM),一个用于存储向量和数组值的缓存,以及标量内存(SMEM),一个用于存储标量值的缓存。
寄存器:TensorCore 具有两种主要类型的寄存器:向量寄存器(VREGs)存储数组值,标量寄存器(SREGs)存储标量值。值可以从各自的缓存(VMEM 对应 VREGs,SMEM 对应 SREGs)加载到内存中。
计算单元:TensorCore 具有标量单元、向量单元(VPU)和矩阵单元(MXU),可以进行数值计算。每个计算单元都可以异步操作,但这由 TPU 编译器管理,因此从程序员的角度来看,TPU 程序是单线程的。计算单元对 SREGs 和 VREGs 中的值进行操作,并将结果值也输出到这些寄存器中。
TPU 特有的流水线功能#
Pallas TPU 支持以下平台特定功能。
TPU 内存空间#
Pallas 向用户公开了 TPU 内存层次结构的所有级别。下表将 Pallas TPU 内存空间映射到其标准内存类型(DRAM/SRAM)
Pallas 枚举 |
TPU 内存空间 |
类型 (DRAM/SRAM) |
---|---|---|
|
HBM (通常) 或 VMEM |
DRAM |
|
VMEM |
SRAM |
|
SMEM |
SRAM |
|
信号量 |
SRAM |
MemorySpace.VMEM
表示向量 SRAM。如果没有指定,它是默认的内存空间。MemorySpace.SMEM
表示标量 SRAM。只能对 SMEM 执行标量加载和存储操作。MemorySpace.ANY
是对编译器的一个提示,表示内存空间不受限制。在大多数情况下,XLA 会将此缓冲区放置在 HBM 中。分配给ANY
内存空间的缓冲区不能使用数组索引语法(例如x[...]
)正常解引用。相反,我们必须首先使用pltpu.sync_copy
或pltpu.async_copy
将值复制到 VMEM 或 SMEM 缓冲区中。MemorySpace.SEMAPHORE
用于分配信号量以构建屏障或跟踪异步操作。也可以从内核返回信号量以构建异步内核——这是一个实验性功能;有关更多详细信息,请参阅Pallas 异步操作。
TPU 上的流水线操作通常在 HBM(DRAM)和 VMEM(向量 SRAM)之间进行。pallas_call
在 TPU 上的默认行为是:假设 pallas_call
的参数位于 HBM 中,并且用户内核主体的输入存储在 VMEM 中。
虽然不是流水线特有的功能,但可以手动控制输入和输出缓冲区的内存空间,您可以在 BlockSpec
上指定 memory_space
参数。请注意,除非 memory_space
标记为 VMEM
,否则不允许进行流水线操作。内存空间还可以通过 pallas_call
上的 scratch_shapes
参数用于指定内核的临时参数。临时缓冲区在内核迭代中是持久的,对于存储中间结果(如部分累积和归约)很有用。临时缓冲区必须驻留在 VMEM
、SMEM
或 SEMAPHORE
中。
作为在内核中使用多个手动内存空间分配的示例,以下程序将 HBM 缓冲区 x_hbm_ref
的一个切片复制到临时 VMEM 缓冲区 scratch_vmem_ref
中,然后将其用于算术运算并将结果存储到输出 VMEM 缓冲区中
def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref):
pltpu.sync_copy(x_hbm_ref.at[0:1], scratch_vmem_ref)
out_vmem_ref[...] = scratch_vmem_ref[...] + 1
x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)
out = pl.pallas_call(hbm_vmem_kernel,
in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],
out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),
scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)
)(x)
np.testing.assert_allclose(out, x[0:1] + 1)
Megacore 配置中的 TPU#
某些 TPU 芯片具有两个 TensorCore,但对 JAX 用户来说它们显示为一个设备。这被称为“megacore”。独立的 TensorCore 拥有各自独立的 VMEM、VREGs、SMEM、SREGs 和计算单元,但**共享 HBM**。
从概念上讲,Megacore 中的 TPU 表现得像非常简单的 GPU,即它们只有两个线程。我们如何修改我们的内核以同时利用两个 TensorCore?
基本思想是,如果我们的计算中有易于并行化的维度,我们可以将这些维度拆分到不同的 TensorCore 上。我们可以通过向 pallas_call
提供一个名为 dimension_semantics
的注解来指示哪些维度是可并行化的。
def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
# Load x and y from VMEM into VREGs
x_vregs = x_vmem_ref[:, :]
y_vregs = y_vmem_ref[:, :]
# Execute a vectorized add
z_vregs = x_vregs + y_vregs
# Store the output values in VREGs back into VMEM
z_vmem_ref[:, :] = z_vregs
def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,),
compiler_params=pltpu.CompilerParams(
dimension_semantics=("parallel",))
)(x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
dimension_semantics
应该是一个与 grid
长度相同的元组,其中每个条目是 "parallel"
或 "arbitrary"
。"parallel"
指示 Pallas 对应于该维度的 for 循环的迭代可以独立执行而不会影响程序的正确性。"arbitrary"
指示 Pallas 对此网格维度不能做任何假设,因此它无法并行化。
通过指定 dimension_semantics
,我们现在可以在每个 TensorCore 上同时执行内核。Pallas 将自动处理网格的拆分。
请注意,Megacore 目前仅在 TPU
v4
和 TPUv5p
上可用。在其他平台上提供dimension_semantics
注解是空操作,但**不**指定它将导致只使用一个 TensorCore(即使有多个可用)。