流水线#

在本指南中,我们将介绍 TPU 中内存空间的工作原理,以及如何在 Pallas 中编写流水线,使内存 I/O 与计算重叠。

#@title Imports

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np

TPU 及其内存空间#

一个 TPU 及其 TensorCore 由内存空间(数组可以驻留的地方)、寄存器(临时存储标量和数组值)和计算单元(使用寄存器中的值进行计算)组成。下图是一个 TPU 的示意图,其中 xy 是驻留在高带宽内存 (HBM) 中的数组

TPU Memory Space Cartoon.png

让我们更详细地讨论一下这个图的组成部分

  • 内存空间:TPU 具有高带宽内存 (HBM),我们通常认为它是“设备内存”。还有向量内存 (VMEM),一种用于存储向量和数组值的缓存,以及标量内存 (SMEM),一种设计用于存储标量值的缓存。

  • 寄存器:TensorCore 有两种主要类型的寄存器:向量寄存器 (VREG) 存储数组值,标量寄存器 (SREG) 存储标量值。值可以从它们各自的缓存(VREG 的 VMEM 和 SREG 的 SMEM)加载到内存中。

  • 计算单元:TensorCore 具有标量单元、向量单元 (VPU) 和矩阵单元 (MXU),可以进行数值计算。计算单元对 SREG 和 VREG 中的值进行操作,并将输出值也存储在这些寄存器中。

为了对驻留在 HBM 中的值 xy 进行向量化计算,我们需要

  1. 将值 xy 复制到 VMEM 中。

  2. 将值从 VMEM 加载到 VREG 中。

  3. 使用 VPU 或 MXU 执行计算,并将输出存储在 VREG 中。

  4. 将输出 VREG 中的值存储到 VMEM 中。

  5. 将 VMEM 中的输出值复制回 HBM。

让我们实现一个 Pallas 函数来完成这项工作!

def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
  # Load x and y from VMEM into VREGs
  x_vregs = x_vmem_ref[:, :]
  y_vregs = y_vmem_ref[:, :]
  # Execute a vectorized add
  z_vregs = x_vregs + y_vregs
  # Store the output values in VREGs back into VMEM
  z_vmem_ref[:, :] = z_vregs


def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.
  # It will then copy `x` and `y` from HBM into VMEM.
  z = pl.pallas_call(
      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
  # pallas_call will also copy the output from VMEM back into HBM.
  return z


x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

我们编写了两个函数:add_matrices_kerneladd_matrices

add_matrices_kernel 使用驻留在 VMEM 中的 Ref 进行操作。从 VMEM Ref 加载会产生一个驻留在 VREG 中的值。VREG 中的值的行为类似于 jax.Array,因为我们可以使用 jnpjax.lax 操作对它们进行操作,以产生新的驻留在 VREG 中的值。当我们生成想要返回的值时,我们会将它们存储在输出 VMEM Ref 中。

add_matrices 函数作用于 jax.Array 并返回一个 jax.Array。在函数内部,我们将 xy 传递给 pallas_callpallas_call 负责将 xy 复制到 VMEM 中,并负责分配内核操作的 VMEM 缓冲区(包括分配 z_vmem_ref,输出 VMEM 缓冲区)。内核函数运行完成后,pallas_call 还会将 z_vmem_ref 中的值复制到 HBM,从而产生一个输出 jax.Array

使用 VMEM/SMEM 的约束#

Pallas 暴露了对较低级别内存空间(如 VMEM 和 SMEM)的访问,但编写利用它们的内核会增加一些考虑因素。

  1. 内存容量。VMEM 和 SMEM 都很**小**!v4 TPU 上的 VMEM 只有 16MiB,SMEM 的范围在几十到几百 KiB。如果我们的数组太大,我们甚至无法将它们全部放入 VMEM 中。作为参考,一个 f32[2048, 2048] 数组为 16MiB,因此我们上面的内核无法扩展到中等大小的数组之外。

  2. 内存带宽。与大多数计算指令相比,在 HBM 和 VMEM 之间复制需要很长时间。add_matrices 函数可能会花费更多时间在 HBM 和 VMEM 之间复制,而不是实际执行加法本身。

考虑到这两个约束,我们将不得不重新思考从 TPU 中获得性能的策略。

入门:流水线#

流水线计算提供了一种一次性处理内存容量和带宽约束的方法。我们所说的流水线是什么意思?

目标是:**并行**地在 HBM 和 VMEM 之间复制,**同时**利用我们的计算单元。天真地看,这很困难,因为在我们上面的程序中,我们在开始对 xy 进行任何计算之前,复制了 xy 的**所有**内容,从而在复制和计算之间创建了依赖关系。

但是,如果我们可以将计算分解为几个子计算(例如,当我们添加两个矩阵时,我们可以将其表示为原始矩阵的“块”的加法),我们现在可以将其中一个子计算的复制与另一个子计算的计算重叠。让我们来看一个简单的例子

假设我们将数组 xy 分割为 x1、x2y1、y2(例如,沿前导轴分割,每个输入产生两个 (256, 512) 数组)。我们现在可以执行以下流水线计算。

  1. x1y1 复制到 VMEM 中。

  2. 开始将 x2y2 复制到 VMEM 中

  3. x1、y1 从 VMEM 加载到 VREG 中。

  4. 使用计算单元执行 z1 = x1 + y1

  5. z1 存储到 VMEM 中。

  6. 开始将 z1 从 VMEM 复制回 HBM。

  7. 等待直到 x2、y2 已复制到 VMEM 中。

  8. x2、y2 从 VMEM 加载到 VREG 中。

  9. 使用计算单元执行 z2 = x2 + y2

  10. z2 存储到 VMEM 中。

  11. 等待直到 z1 复制到 HBM 中。

  12. 开始将 z2 从 VMEM 复制回 HBM。

  13. 等待直到 z2 复制到 HBM 中。

任何时候我们在这里进行计算,我们都在异步复制某些东西。这意味着一些复制时间不会被浪费。

用于确定流水线计算效率的两个最重要的数字是 a) 我们需要执行多少浮点运算 (FLOP) 和 b) 我们需要复制多少字节才能执行该计算。这两者之比(FLOP/内存使用量)称为操作的**算术强度**,并确定我们的流水线是受计算限制还是受内存限制。

Pallas 中的流水线#

我们如何在 Pallas 中实现像上面这样的流水线?这似乎是一个复杂的异步数据操作序列和执行内核,手动实现起来会很痛苦。别担心!Pallas 提供了一个 API 来表达流水线,而无需太多样板代码,即通过 gridBlockSpec

请注意,在上面的流水线示例中,我们多次执行相同的逻辑:步骤 3-5 和 8-10 都执行相同的操作,只是输入不同。jax.experimental.pallas.pallas_call() 提供了一种通过使用 grid 参数多次执行内核的方法。请参阅 网格,又名循环中的内核

我们还使用 jax.experimental.pallas.BlockSpec 来指定如何构造每个内核调用的输入。请参阅 BlockSpec,又名如何将输入分块

在上面的流水线示例中,我们有 (512, 512) 形状的数组,并沿前导维度将它们拆分为两个 (256, 512) 形状的数组。在此流水线中,我们的 BlockSpec.block_shape 将为 (256, 512)。在第一次迭代中,我们希望选择 x1,在第二次迭代中,我们希望使用 x2。这可以使用以下 index_map 表示

def x_index_map(i):
  return (i, 0)

然后我们将构造 BlockSpec

block_spec = pl.BlockSpec((256, 512), x_index_map)

yzBlockSpec 将与 x 的相同。

整合在一起#

我们通过 gridin_specsout_specs 将这些参数提供给 pallas_callin_specs 对应于位置参数的元组,out_specs 对应于输出)。

def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,)
  )(x, y)

add_matrices_pipelined(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

我们只在原始函数中添加了一点代码来添加自动流水线,但 BlockSpecgrid 完成了很多繁重的工作!

它是如何工作的?好吧,BlockSpec 提供了足够的信息来开始从 HBM **预取**输入的块到 VMEM 中。例如,如果我们正在启动 grid 的迭代 i,我们可以将 i + 1 传递到 index_map 函数中,以获取下一次迭代所需的块。然后我们可以启动这些块的异步复制。与输出类似,我们可以等待上一次迭代的输出被复制,然后再启动当前迭代输出的复制。

参数化流水线#

在我们的内核中参数化块形状是很常见的。块大小可能是优化 Pallas 内核性能时最重要的参数!它们让我们控制流水线(例如,选择较小的块会为我们的流水线循环添加更多迭代,其中每次迭代的工作量较少)。

此外,我们还可以沿第二个维度划分输入和输出(我们现在只沿第一个维度分割)。让我们编写一个更通用的内核来处理这两个特性。

def add_matrices_pipelined_2d(
    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
  m, n = x.shape
  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(m // bm, n // bn),
  )(x, y)

np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
)

处理归约#

你将如何使用 pallas_call 实现类似 jnp.sum 的功能?具体来说,我们希望在归约维度上进行流水线操作。

以将 (8, 512, 512) 形状的数组归约为 (512, 512) 形状的数组为例。

x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

为了使用 pallas_call 执行此操作,我们可以使用大小为 (8,) 的网格,并在每次迭代 i 中将 x[i] 加载到 VMEM 中。然后我们可以将 x[i] 添加到输出 VMEM 缓冲区。让我们先天真地实现这一点。

# Warning: this implementation is incorrect!

def naive_sum_kernel(x_ref, o_ref):
  o_ref[...] += x_ref[...]

def naive_sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      naive_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)
naive_sum(x)
Array([[9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       ...,
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.]], dtype=float32)

请注意我们是如何设置 BlockSpec 的:我们正在将 (512, 512) 维度的全部内容加载到 VMEM 中(那里没有流水线),但在 index_map 中每次迭代都选择 x 的第 i 个维度。我们在块形状中对该维度使用 None,这表示我们正在从 x 中选择一个单例维度,我们希望在内核中将其挤压掉。因此,x_ref 在 VMEM 中也是 (512, 512) 形状的。

out_spec 使用 lambda i: (0, 0) 作为其 index_map,表示 o_ref 在流水线的过程中保持不变。这意味着我们可以通过读取和写入来在每次迭代中更新其值。或者可以吗?实际上有一个陷阱:**o_ref 最初是垃圾**,这意味着我们将累积到垃圾中。这将导致整个函数输出不正确的值!

因此,**每当我们在内核中进行归约时,我们需要确保初始化存储归约值的 Ref**。我们可以通过在迭代 0 时有条件地将值写入 out_ref 来完成此操作。我们可以使用辅助函数 pl.whenjax.lax.cond 的便捷包装器)和 pl.program_id(查询我们在网格轴中的哪个迭代中)来完成此操作。

def sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(axis=0) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)

  o_ref[...] += x_ref[...]

def sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
  )(x)

sum(x)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

这个 sum 函数现在输出正确的值了!

关于 Pallas 中的归约,最后需要注意的一点是,**它们必须在网格的最小维度(最右侧维度)中完成**(在上面的示例中,我们的网格是一维的,因此我们正在对其最小维度进行归约)。这是因为 Pallas 使用 BlockSpecgrid 和内核函数生成的流水线**不会从 HBM 中读取输出**。一旦你将输出值写回 HBM,你就无法再次访问它。因此,你无法跨任何具有重新访问的网格维度进行归约,因此所有归约都需要在最右侧维度中进行。

Megacore 配置中的 TPU#

一些 TPU 芯片有两个 TensorCore,但在 JAX 用户看来,它们就像一个设备。这称为 “megacore”。单独的 TensorCore 有其自己单独的 VMEM、VREG、SMEM、SREG 和计算单元,但**共享 HBM**。

TPU Memory Space Cartoon (Megacore).png

从概念上讲,Megacore 中的 TPU 的行为类似于非常简单的 GPU,即它们只有两个线程。我们如何修改内核以同时利用两个 TensorCore?

基本思想是,如果我们的计算中存在易于并行化的维度,我们可以将这些维度分散到 TensorCore 上。我们可以通过向 pallas_call 提供一个名为 dimension_semantics 的注解来指示哪些维度是可并行化的。

def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
      compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",))
  )(x, y)

x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

dimension_semantics 应该是一个与 grid 长度相同的元组,其中每个条目是 "parallel""arbitrary""parallel" 向 Pallas 表明,对应于该维度的 for 循环的迭代可以独立执行,而不会影响程序的正确性。"arbitrary" 向 Pallas 表明,对于此网格维度不能做任何假设,因此它不能被并行化。

通过指定 dimension_semantics,我们现在在每个 TensorCore 上同时执行内核。 Pallas 将自动处理网格的分割。

请注意,Megacore 目前仅在 TPU v4 和 TPU v5p 上可用。 在其他平台上提供 dimension_semantics 注解不起作用,但是*不*指定它将导致仅使用一个 TensorCore(即使有多个可用)。

结论#

在本指南中,我们介绍了如何使用 pallas_callgridBlockSpec 来表达 TPU 管道。 我们介绍了如何通过多维网格来表达嵌套循环,以及如何通过在归约开始时初始化累加器来处理归约。 我们还学习了如何通过向内核添加注解来处理 Megacore。

留给读者的练习

  • 尝试实现一个 sum 内核,该内核也对其他维度进行流水线处理

  • add 内核和 sum 内核添加 megacore 支持。