流水线#

在本指南中,我们将介绍TPU中内存空间的工作方式,以及如何在Pallas中编写流水线,以重叠内存I/O和计算。

#@title Imports

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np

TPU及其内存空间#

一个TPU及其TensorCore由内存空间(数组可以驻留的地方)、寄存器(临时存储标量和数组值)和计算单元(使用寄存器中的值进行计算)组成。下图是一个TPU的示意图,其中xy是驻留在高带宽内存(HBM)中的数组

TPU Memory Space Cartoon.png

让我们更详细地讨论这个图中的组件

  • 内存空间:TPU具有高带宽内存(HBM),我们通常将其视为“设备内存”。还有向量内存(VMEM),一种用于存储向量和数组值的缓存,以及标量内存(SMEM),一种用于存储标量值的缓存。

  • 寄存器:TensorCore有两种主要类型的寄存器:向量寄存器(VREG)存储数组值,标量寄存器(SREG)存储标量值。值可以从其各自的缓存(VREG的VMEM和SREG的SMEM)加载到内存中。

  • 计算单元:TensorCore具有一个标量单元、向量单元(VPU)和矩阵单元(MXU),可以进行数值计算。计算单元对SREG和VREG中的值进行操作,并将值输出到这些寄存器中。

为了对我们驻留在HBM中的值xy进行向量化计算,我们需要:

  1. 将值xy复制到VMEM中。

  2. 将值从VMEM加载到VREG中。

  3. 使用VPU或MXU执行计算,将输出存储在VREG中。

  4. 将输出VREG中的值存储到VMEM中。

  5. 将VMEM中的输出值复制回HBM。

让我们实现一个执行此操作的Pallas函数!

def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
  # Load x and y from VMEM into VREGs
  x_vregs = x_vmem_ref[:, :]
  y_vregs = y_vmem_ref[:, :]
  # Execute a vectorized add
  z_vregs = x_vregs + y_vregs
  # Store the output values in VREGs back into VMEM
  z_vmem_ref[:, :] = z_vregs


def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.
  # It will then copy `x` and `y` from HBM into VMEM.
  z = pl.pallas_call(
      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
  # pallas_call will also copy the output from VMEM back into HBM.
  return z


x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

我们编写了两个函数:add_matrices_kerneladd_matrices

add_matrices_kernel使用驻留在VMEM中的Refs进行操作。从VMEM Ref加载会产生一个驻留在VREG中的值。VREG中的值就像jax.Array一样,我们可以使用jnpjax.lax操作来产生驻留在VREG中的新值。当我们产生我们想要返回的值时,我们会将它们存储在输出VMEM Ref中。

add_matrices函数对jax.Arrays进行操作,并返回一个jax.Array。在其中,我们将xy传递给pallas_callpallas_call负责将xy复制到VMEM中,并负责分配内核操作的VMEM缓冲区(包括分配输出VMEM缓冲区z_vmem_ref)。内核函数运行完成后,pallas_call还将z_vmem_ref中的值复制到HBM,从而产生输出jax.Array

使用VMEM/SMEM的限制#

Pallas允许访问较低级别的内存空间,如VMEM和SMEM,但编写利用它们的内核会增加一些考虑因素。

  1. 内存容量。VMEM和SMEM都很小!v4 TPU上的VMEM只有16MiB,而SMEM的范围在几十到几百KiB之间。如果我们的数组太大,我们甚至无法将它们全部放入VMEM中。作为参考,一个f32[2048, 2048]数组是16MiB,因此我们上面的内核无法扩展到中等大小的数组之外。

  2. 内存带宽。与大多数计算指令相比,在HBM和VMEM之间复制需要很长时间。上面的add_matrices函数可能花费更多的时间在HBM和VMEM之间复制,而不是实际执行加法本身。

考虑到这两个限制,我们将不得不重新考虑如何从我们的TPU中获得性能。

入门:流水线#

流水线化我们的计算提供了一种同时处理内存容量和带宽限制的方法。我们所说的流水线是什么意思?

目标是:并行地在HBM和VMEM之间复制同时利用我们的计算单元。从表面上看,这很困难,因为在我们上面的程序中,我们在开始对它们进行任何计算之前复制所有xy,从而在复制和计算之间创建了依赖关系。

但是,如果我们可以将我们的计算分解为几个子计算(例如,当我们添加两个矩阵时,我们可以将其表示为将原始矩阵的“块”加在一起),我们现在可以将其中一个子计算的复制与另一个子计算的计算重叠。让我们看一个简单的例子

假设我们将数组xy拆分为x1, x2y1, y2(例如,沿前导轴拆分,每个输入得到两个(256, 512)数组)。我们现在可以执行以下流水线计算。

  1. x1y1复制到VMEM中。

  2. 开始将x2y2复制到VMEM中

  3. x1, y1从VMEM加载到VREG中。

  4. 使用计算单元执行z1 = x1 + y1

  5. z1存储到VMEM中。

  6. 开始将z1从VMEM复制回HBM。

  7. 等待直到x2, y2复制到VMEM中。

  8. x2, y2从VMEM加载到VREG中。

  9. 使用计算单元执行z2 = x2 + y2

  10. z2存储到VMEM中。

  11. 等待直到z1复制到HBM中。

  12. 开始将z2从VMEM复制回HBM。

  13. 等待直到z2复制到HBM中。

在我们进行计算的任何时候,我们都在异步地复制某些内容。这意味着一些复制的时间没有被浪费。

用于确定流水线计算效率的两个最重要数字是a)我们需要执行多少浮点运算(FLOP)和b)我们需要复制多少字节来执行该计算。这两个数字的比率(FLOP/内存使用量)称为运算的算术强度,它决定了我们的流水线是受计算限制还是受内存限制。

Pallas中的流水线#

我们如何在Pallas中实现如上所示的流水线?这似乎是一个复杂的异步数据操作和执行内核的序列,手动实现会很麻烦。不要害怕!Pallas提供了一个用于表达流水线的API,而无需太多样板代码,即通过gridBlockSpec

请注意,在上面的流水线示例中,我们多次执行相同的逻辑:步骤3-5和8-10都执行相同的操作,只是在不同的输入上。 jax.experimental.pallas.pallas_call()提供了一种通过使用grid参数多次执行内核的方法。请参阅网格,又名循环中的内核

我们还使用jax.experimental.pallas.BlockSpec来指定如何构造每个内核调用的输入。请参阅BlockSpec,又名如何对输入进行分块

在上面的流水线示例中,我们有(512, 512)形状的数组,并沿着前导维度将其拆分为两个(256, 512)形状的数组。在此流水线中,我们的BlockSpec.block_shape将是(256, 512)。在第一次迭代中,我们希望选择x1,而在第二次迭代中,我们希望使用x2。这可以用以下index_map表示

def x_index_map(i):
  return (i, 0)

然后,我们将构造BlockSpec

block_spec = pl.BlockSpec((256, 512), x_index_map)

yzBlockSpecx的相同。

组合起来#

我们通过 gridin_specsout_specs 将这些参数传递给 pallas_call ( in_specs 对应于位置参数的元组,而 out_specs 对应于输出)。

def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,)
  )(x, y)

add_matrices_pipelined(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

我们只在原始函数中添加了少量代码来添加自动流水线,但是 BlockSpecgrid 完成了很多繁重的工作!

它是如何工作的? 好吧,BlockSpec 提供了足够的信息来开始从 HBM 预取输入块到 VMEM。例如,如果我们正在开始 grid 的第 i 次迭代,我们可以将 i + 1 传递给 index_map 函数,以获取下一次迭代所需的块。然后,我们可以启动这些块的异步复制。类似地,对于输出,我们可以等待前一次迭代的输出被复制,然后再开始复制当前迭代的输出。

参数化流水线#

在我们的内核中参数化块形状是很常见的。块大小可能是优化 Pallas 内核性能时最重要的参数!它们使我们能够控制流水线(例如,选择较小的块会向我们的流水线循环添加更多迭代,其中每次迭代的工作量较少)。

此外,我们还可以沿着第二个维度分割输入和输出(我们现在只沿着第一个维度进行分割)。让我们编写一个更通用的内核来处理这两个功能。

def add_matrices_pipelined_2d(
    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
  m, n = x.shape
  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(m // bm, n // bn),
  )(x, y)

np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
)

处理归约#

如何使用 pallas_call 实现类似 jnp.sum 的操作?具体来说,我们希望在归约维度上进行流水线处理。

以将 (8, 512, 512) 形的数组归约为 (512, 512) 形的数组为例。

x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

为了使用 pallas_call 实现这一点,我们可以使用大小为 (8,) 的网格,并在每次迭代 i 中将 x[i] 加载到 VMEM 中。然后,我们可以将 x[i] 添加到输出 VMEM 缓冲区。让我们先天真地实现这一点。

# Warning: this implementation is incorrect!

def naive_sum_kernel(x_ref, o_ref):
  o_ref[...] += x_ref[...]

def naive_sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      naive_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)
naive_sum(x)
Array([[9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       ...,
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.]], dtype=float32)

请注意我们如何设置 BlockSpec:我们将整个 (512, 512) 维度加载到 VMEM 中(那里没有流水线),而是在 index_map 中每次迭代选择 x 的第 i 个维度。我们在块形状中使用 None 来表示我们正在从 x 中选择一个单例维度,我们希望在内核中将其挤压掉。因此,x_ref 在 VMEM 中也是 (512, 512) 形的。

out_spec 使用 lambda i: (0, 0) 作为其 index_map,表示 o_ref 在整个流水线过程中保持不变。这意味着我们可以通过读取和写入它的值来更新每次迭代的值。或者它可以吗?实际上,有一个问题:o_ref 最初是垃圾数据,这意味着我们将累积到垃圾中。这将导致整体函数输出错误的值!

因此,每当我们在内核中进行归约时,我们需要确保初始化存储归约值的 Ref。我们可以通过在迭代 0 时有条件地向 out_ref 写入值来完成此操作。我们可以使用辅助函数 pl.when,它是 jax.lax.cond 的一个方便的包装器,以及 pl.program_id,它查询我们在网格轴中的哪个迭代。

def sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(axis=0) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)

  o_ref[...] += x_ref[...]

def sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
  )(x)

sum(x)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

现在,此 sum 函数会输出正确的值!

关于 Pallas 中的归约,最后需要注意的一点是,它们必须在我们网格的最次要(最右边)维度中完成(在上面的示例中,我们的网格是一维的,因此我们正在归约其最次要维度)。这是因为 Pallas 使用 BlockSpecgrid 和内核函数生成的流水线不会从 HBM 读回输出。一旦将输出值写回 HBM,就无法再次访问它。因此,不能跨具有任何重新访问的网格维度进行归约,因此所有归约都需要发生在最右边的维度中。

Megacore 配置中的 TPU#

一些 TPU 芯片具有两个 TensorCore,但在 JAX 用户看来像一个设备。这称为“megacore”。单独的 TensorCore 具有自己单独的 VMEM、VREG、SMEM、SREG 和计算单元,但共享 HBM

TPU Memory Space Cartoon (Megacore).png

从概念上讲,Megacore 中的 TPU 的行为类似于非常简单的 GPU,即它们只有两个线程。我们如何修改内核以同时利用两个 TensorCore?

基本思想是,如果我们的计算中有可并行处理的维度,我们可以将这些维度拆分到 TensorCore 中。我们可以通过向 pallas_call 提供一个名为 dimension_semantics 的注释来指示哪些维度是可并行化的。

def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
      compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",))
  )(x, y)

x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

dimension_semantics 应该是一个与 grid 长度相同的元组,其中每个条目是 "parallel""arbitrary""parallel" 向 Pallas 表示,对应于该维度的 for 循环的迭代可以独立执行,而不会影响程序的正确性。"arbitrary" 向 Pallas 表示,不能对该网格维度做出任何假设,因此不能并行化。

通过指定 dimension_semantics,我们现在在每个 TensorCore 上同时执行内核。Pallas 将自动处理网格的拆分。

请注意,Megacore 目前仅在 TPU v4 和 TPU v5p 上可用。提供 dimension_semantics 注释在其他平台上是无操作的,但是指定它将导致仅使用一个 TensorCore(即使有多个可用)。

结论#

在本指南中,我们介绍了如何使用 pallas_callgridBlockSpec 来表达 TPU 流水线。我们介绍了如何通过多维网格表达嵌套循环,以及如何通过在归约开始时初始化累加器来处理归约。我们还学习了如何通过向内核添加注释来处理 Megacore。

留给读者的练习

  • 尝试实现一个也对其他维度进行流水线处理的 sum 内核

  • add 内核和 sum 内核添加 megacore 支持。