TPU 流水线#

本指南旨在为 TPU 特有的流水线问题提供参考。我们将回顾 TPU 上的内存层次结构和计算单元,以及流水线 API 中 TPU 特有的功能。有关流水线操作的更通用概述,请参阅软件流水线

#@title Imports

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np

TPU 及其内存空间#

TPU 及其 TensorCore 包含内存空间(数组驻留的地方)、寄存器(临时存储标量和数组值)和计算单元(对寄存器中的值进行计算)。以下是 TPU 的示意图,其中 xy 是位于高带宽内存(HBM)中的数组

TPU Memory Space Cartoon.png

让我们更详细地讨论此图的组成部分

  • 内存空间:TPU 具有高带宽内存(HBM),通常我们将其视为“设备内存”。还有向量内存(VMEM),一个用于存储向量和数组值的缓存,以及标量内存(SMEM),一个用于存储标量值的缓存。

  • 寄存器:TensorCore 具有两种主要类型的寄存器:向量寄存器(VREGs)存储数组值,标量寄存器(SREGs)存储标量值。值可以从各自的缓存(VMEM 对应 VREGs,SMEM 对应 SREGs)加载到内存中。

  • 计算单元:TensorCore 具有标量单元、向量单元(VPU)和矩阵单元(MXU),可以进行数值计算。每个计算单元都可以异步操作,但这由 TPU 编译器管理,因此从程序员的角度来看,TPU 程序是单线程的。计算单元对 SREGs 和 VREGs 中的值进行操作,并将结果值也输出到这些寄存器中。

TPU 特有的流水线功能#

Pallas TPU 支持以下平台特定功能。

TPU 内存空间#

Pallas 向用户公开了 TPU 内存层次结构的所有级别。下表将 Pallas TPU 内存空间映射到其标准内存类型(DRAM/SRAM)

Pallas 枚举

TPU 内存空间

类型 (DRAM/SRAM)

pltpu.MemorySpace.ANY

HBM (通常) 或 VMEM

DRAM

pltpu.MemorySpace.VMEM

VMEM

SRAM

pltpu.MemorySpace.SMEM

SMEM

SRAM

pltpu.MemorySpace.SEMAPHORE

信号量

SRAM

  • MemorySpace.VMEM 表示向量 SRAM。如果没有指定,它是默认的内存空间。

  • MemorySpace.SMEM 表示标量 SRAM。只能对 SMEM 执行标量加载和存储操作。

  • MemorySpace.ANY 是对编译器的一个提示,表示内存空间不受限制。在大多数情况下,XLA 会将此缓冲区放置在 HBM 中。分配给 ANY 内存空间的缓冲区不能使用数组索引语法(例如 x[...])正常解引用。相反,我们必须首先使用 pltpu.sync_copypltpu.async_copy 将值复制到 VMEM 或 SMEM 缓冲区中。

  • MemorySpace.SEMAPHORE 用于分配信号量以构建屏障或跟踪异步操作。也可以从内核返回信号量以构建异步内核——这是一个实验性功能;有关更多详细信息,请参阅Pallas 异步操作

TPU 上的流水线操作通常在 HBM(DRAM)和 VMEM(向量 SRAM)之间进行。pallas_call 在 TPU 上的默认行为是:假设 pallas_call 的参数位于 HBM 中,并且用户内核主体的输入存储在 VMEM 中。

虽然不是流水线特有的功能,但可以手动控制输入和输出缓冲区的内存空间,您可以在 BlockSpec 上指定 memory_space 参数。请注意,除非 memory_space 标记为 VMEM,否则不允许进行流水线操作。内存空间还可以通过 pallas_call 上的 scratch_shapes 参数用于指定内核的临时参数。临时缓冲区在内核迭代中是持久的,对于存储中间结果(如部分累积和归约)很有用。临时缓冲区必须驻留在 VMEMSMEMSEMAPHORE 中。

作为在内核中使用多个手动内存空间分配的示例,以下程序将 HBM 缓冲区 x_hbm_ref 的一个切片复制到临时 VMEM 缓冲区 scratch_vmem_ref 中,然后将其用于算术运算并将结果存储到输出 VMEM 缓冲区中

def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref):
  pltpu.sync_copy(x_hbm_ref.at[0:1], scratch_vmem_ref)
  out_vmem_ref[...] = scratch_vmem_ref[...] + 1

x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)
out = pl.pallas_call(hbm_vmem_kernel,
  in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],
  out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),
  scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)
)(x)

np.testing.assert_allclose(out, x[0:1] + 1)

Megacore 配置中的 TPU#

某些 TPU 芯片具有两个 TensorCore,但对 JAX 用户来说它们显示为一个设备。这被称为“megacore”。独立的 TensorCore 拥有各自独立的 VMEM、VREGs、SMEM、SREGs 和计算单元,但**共享 HBM**。

TPU Memory Space Cartoon (Megacore).png

从概念上讲,Megacore 中的 TPU 表现得像非常简单的 GPU,即它们只有两个线程。我们如何修改我们的内核以同时利用两个 TensorCore?

基本思想是,如果我们的计算中有易于并行化的维度,我们可以将这些维度拆分到不同的 TensorCore 上。我们可以通过向 pallas_call 提供一个名为 dimension_semantics 的注解来指示哪些维度是可并行化的。

def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
  # Load x and y from VMEM into VREGs
  x_vregs = x_vmem_ref[:, :]
  y_vregs = y_vmem_ref[:, :]
  # Execute a vectorized add
  z_vregs = x_vregs + y_vregs
  # Store the output values in VREGs back into VMEM
  z_vmem_ref[:, :] = z_vregs

def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel",))
  )(x, y)

x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

dimension_semantics 应该是一个与 grid 长度相同的元组,其中每个条目是 "parallel""arbitrary""parallel" 指示 Pallas 对应于该维度的 for 循环的迭代可以独立执行而不会影响程序的正确性。"arbitrary" 指示 Pallas 对此网格维度不能做任何假设,因此它无法并行化。

通过指定 dimension_semantics,我们现在可以在每个 TensorCore 上同时执行内核。Pallas 将自动处理网格的拆分。

请注意,Megacore 目前仅在 TPU v4 和 TPU v5p 上可用。在其他平台上提供 dimension_semantics 注解是空操作,但**不**指定它将导致只使用一个 TensorCore(即使有多个可用)。