软件流水线#

软件流水线是性能优化中的一项重要技术,它通过重叠多个异步操作来提高效率,即使这些操作之间存在数据依赖。在编写内核的背景下,最常见的流水线形式是让通信和内存传输与计算重叠进行,以确保硬件加速器在等待数据到达时不会停滞。因此,本教程将只关注通信-计算流水线的问题。我们将首先从概念上介绍该问题,概述用于编写流水线的 Pallas API,并通过使用该 API 的一些实际示例进行讲解。

本教程仅涵盖流水线的概念基础。有关特定于平台的参考资料,请参阅TPU 流水线,或Mosaic GPU 流水线

import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
import numpy as np

内存层次结构#

从概念上理解流水线的第一步是了解可用的不同形式的内存以及它们之间的权衡。大多数硬件架构(包括 CPU、GPU 和 TPU)都利用各种内存空间,这些空间在容量与延迟/带宽之间进行权衡。对于 Pallas 而言,我们通常关注寄存器、SRAM、DRAM 以及可能的网络通信。

  • 寄存器是物理上最接近处理器​​的内存,通常在进行任何计算之前,值必须直接加载到寄存器中。

  • SRAM(在 GPU 上也称为共享内存/L1 和 L2 缓存,在 TPU 上称为 VMEM)也相当靠近处理器,但容量比寄存器更大。现代机器学习加速器上的 SRAM 通常在 10-100MB 范围内(TPU v5p 包含 96MB 的 VMEM,H100 GPU 包含约 30MB 的 L1 缓存和 50MB 的 L2)。可以合理地预期访问 SRAM 的延迟是访问寄存器的 10 倍左右。

  • DRAM(也称为 HBM)的容量比 SRAM 大得多,对于现代机器学习加速器来说,通常在 10-100GB 范围内。然而,与访问 SRAM 相比,访问它的延迟大约是 10 倍。

  • 当单个设备上的 DRAM 大小不足或我们希望利用并行计算时,网络通信变得至关重要。本教程不涵盖分布式流水线,但请参阅分布式 TPU 内核指南,了解如何在多个设备上编写流水线。

memory_hierarchy

为了对存储在 HBM 中的值 X 和 Y 执行计算,我们需要

  1. 将值 x 和 y 复制到 SRAM 中。

  2. 将值从 SRAM 加载到寄存器中。

  3. 执行计算并将结果存储到寄存器中。

  4. 将输出寄存器中的值存储到 SRAM 中。

  5. 将 SRAM 中的输出值复制回 HBM。

我们来实现在 Pallas 中完成此操作的函数!

# Note: This is a TPU example.

def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):
  # Load x and y from SRAM into registers
  x_regs = x_sram_ref[:, :]
  y_regs = y_sram_ref[:, :]
  # Execute a vectorized add
  z_regs = x_regs + y_regs
  # Store the output values in registers back into SRAM
  z_sram_ref[:, :] = z_regs


def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.
  # It will then copy `x` and `y` from HBM into SRAM.
  z = pl.pallas_call(
      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
  # pallas_call will also copy the output from SRAM back into HBM.
  return z


x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

我们编写了两个函数:add_matrices_kerneladd_matrices

add_matrices_kernel 使用存储在 SRAM 中的 Refs 进行操作。从 SRAM Ref 加载会产生存储在寄存器中的值。寄存器中的值表现得像 jax.Arrays,我们可以对它们使用 jnpjax.lax 操作来生成存储在寄存器中的新值。当我们生成要返回的值时,我们将其存储到输出 SRAM Ref 中。

函数 add_matricesjax.Array 进行操作并返回一个 jax.Array。在其中,我们将 xy 传递给 pallas_callpallas_call 负责将 xy 复制到 SRAM 中,并分配内核操作所需的 SRAM 缓冲区(包括分配输出 SRAM 缓冲区 z_vmem_ref)。内核函数运行结束后,pallas_call 还会将 z_vmem_ref 中的值复制到 HBM,从而得到一个输出 jax.Array

Pallas 暴露了对 SRAM 等低级内存空间的访问,但编写高性能内核需要更仔细地利用各种内存空间。例如,我们需要同时考虑

  • 内存容量。SRAM 很小!如果我们的数组太大,上述内核将无法工作,因为我们无法将输入放入 SRAM。作为参考,一个 f32[2048, 2048] 数组是 16MiB,因此我们的上述内核无法处理中等大小以上的数组。

  • 内存带宽。从 HBM 复制到 SRAM 或从 SRAM 复制到 HBM 需要很长时间,至少与大多数计算指令相比是这样。上面的 add_matrices 函数很可能在 HBM 和 SRAM 之间复制数据所花费的时间比实际执行加法本身还要多。

考虑到这两个限制,我们必须重新思考从加速器中获得性能的策略。

流水线基础#

我们如何利用层次结构中每种类型内存的优势,并能够在 HBM 中存储的大型数组上进行操作,同时仍然利用快速 SRAM 进行计算?流水线是一种非常通用的编程模式,它将使我们能够做到这一点,但它需要将您的问题转换为可以并行重叠的更小的子问题。

流水线的第一步是将我们的问题划分为可以适应 SRAM 的更小的子问题。例如,逐元素操作可以通过一次操作源数组的一个切片轻松转换,这会产生以下 3 个步骤(也称为阶段):

  1. copy_in:将切片 A[i] 从 HBM 复制到 SRAM X

  2. compute:将 X 加载到寄存器中,计算结果,并存储在 SRAM Y 中。

  3. copy_out:将结果 Y 复制回 HBM A[i]

请注意,步骤 1-3 之间存在数据依赖性,我们无法轻易地重叠它们,因为我们需要步骤 (1) 完成后才能开始步骤 (2),依此类推。然而,子问题的多次调用之间没有数据依赖性——也就是说,我们可以在执行块 A[i] 的步骤 (2) 和块 A[i-1] 的步骤 (3) 的同时,执行块 A[i+1] 的步骤 (1)。

pipelining_example

上图描绘了一个理想化的流水线程序如何跨时间调度。关键在于,在内核的大部分时间里,复制操作与计算操作并行执行,这意味着我们理想情况下可以利用计算来“隐藏”HBM/SRAM 之间传输的成本,并尽可能保持处理器繁忙。

初始启动时间和最终拆卸时间被称为“气泡”(bubbles),在此期间只有部分阶段在执行,而流水线正在被“填充”或“排空”。大部分时间花费在流水线的“稳态”阶段,每个流水线阶段在子问题的不同迭代中并行执行。虽然更通用的流水线方法的目标是实现 N 路并行(N 是阶段数),但在内核流水线中,我们通常受限于内存带宽或处理速度。因此,内核流水线的典型目标是充分利用处理器的 FLOPs/s,这意味着在任何时间点,总有一个 compute 块处于活动状态。在上图中,计算块在 8 个时间槽中有 6 个处于活动状态,假设我们在每个计算时间槽中都充分利用了处理器,那么我们将实现 75% 的处理器利用率。

推导双缓冲流水线#

现在我们来看看如何在伪代码中实现流水线。考虑以下逐元素程序,我们使用 copy_in 指令从 HBM 加载值 (A[i]),将结果加 1,然后使用 copy_out 将结果存储回 HBM。

for i in range(N):
  copy_in(A[i], X)
  Y = X + 1
  copy_out(Y, A[i])

这种方法的问题在于 copy_incopy_out 通常是阻塞操作。因此,我们被迫等待复制完成,而 GPU/TPU 处于空闲状态;然后执行计算,而内存处于空闲状态。我们希望做的是在执行当前循环的计算时,异步地“预取”下一个循环迭代所需的输入值,以便计算和内存通信同时发生。

为了理解我们将要进行的的代码转换,我们展开 N=4 的循环,并将复制指令分解为单独的 copy_startcopy_wait 操作,以便能够表达异步性。

  # Itr 1
  copy_in_start(A[0], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[0])
  copy_out_wait(Y)

  # Itr 2
  copy_in_start(A[1], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[1])
  copy_out_wait(Y)

  # Itr 3
  copy_in_start(A[2], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[2])
  copy_out_wait(Y)

  # Itr 4
  copy_in_start(A[3], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[3])
  copy_out_wait(Y)

一旦循环被展开,流水线转换就只是尽可能早地发出 copy_start 指令,并尽可能晚地(就在我们需要值之前)等待 copy_wait 值。然而,在循环的当前状态下,通过 X 存在一个假的 数据依赖——我们不能在异步复制到 X 的同时使用它进行计算,否则可能会出现竞态条件。因此,我们可以使用一种多缓冲技术,为每个输入 X 和每个输出 Y 保留 2 个缓冲区。使用 2 个缓冲区,我们可以将 copy_in_start 提前一个迭代(3 个缓冲区可以提前 2 个迭代,以此类推),我们将循环改写如下:

  # Prologue
  copy_in_start(A[0], X[0])
  
  # Itr 1
  copy_in_start(A[1], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[0])
  copy_out_wait(Y[0])

  # Itr 2 - Steady state
  copy_in_start(A[2], X[0])
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[1])
  copy_out_wait(Y[1])

  # Itr 3 - Steady state
  copy_in_start(A[3], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[2])
  copy_out_wait(Y[0])

  # Itr 4 - No copy-in
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[3])
  copy_out_wait(Y[1])

接下来,我们可以尽可能晚地将 copy_out_wait 推迟,就在后续循环迭代中需要写入 Y 之前。

  # Prologue
  copy_in_start(A[0], X[0])
  
  # Itr 1
  copy_in_start(A[1], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[0])

  # Itr 2 - Steady state
  copy_in_start(A[2], X[0])
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[1])
  copy_out_wait(Y[0])

  # Itr 3 - Steady state
  copy_in_start(A[3], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[2])
  copy_out_wait(Y[1])

  # Itr 4 - No copy-in
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[3])
  copy_out_wait(Y[0])

  # Epilogue
  copy_out_wait(Y[1])

最后,将我们的循环重新卷回 for 循环,我们得到以下流水线循环:

# Prologue
copy_in_start(A[0], X[0])

# Main loop
for i in range(N):
  cur_slot = i % 2
  next_slot = (i + 1) % 2

  if i < N:
    copy_in_start(A[i+1], X[next_slot])
  
  copy_in_wait(X[cur_slot])
  Y[cur_slot] = X[cur_slot] + 1
  copy_out_start(Y[cur_slot], A[i])

  if i > 0:
    copy_out_wait(Y[next_slot])

# Epilogue
copy_out_wait(Y[1])

如果我们要将此循环推广到处理更广泛的计算集,请注意我们本质上需要向流水线指定 3 条信息:

  • 网格 (grid),即 for 循环的边界,它指定要计算的子问题数量。在我们的示例中,我们有一个大小为 (N,) 的一维网格。

  • 内核 (kernel),即输入加载到 SRAM 后实际发生的计算。在我们的示例中,我们执行了逐元素加法 Y = X + 1

  • 数据切片 (data_slices),它将子问题映射到 HBM 缓冲区中的相应切片。在我们的示例中,数据切片是恒等函数 lambda i: i

通过允许用户指定这些信息,我们可以按照这种模式编写各种各样的程序。

def double_buffered_pipeline(
    grid: tuple[int, ...],
    kernel: Callable,
    in_slices: Callable,
    out_slices: Callable):
  # Prologue
  copy_in_start(in_hbm[in_slices(0)], in_sram[0])

  # Main loop
  grid_size = prod(grid)
  for i in range(grid_size):
    cur_slot = i % 2
    next_slot = (i + 1) % 2
    if (i + 1) < grid_size:
      copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot])
    copy_in_wait(in_sram[cur_slot])

    kernel(in_sram[cur_slot], out_ram[cur_slot])

    copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])
    if i > 0:
      copy_out_wait(out_sram[next_slot])

  # Epilogue
  last_slot = (grid_size - 1) % 2
  copy_out_wait(out_sram[last_slot])

现在我们已经了解了如何手动实现流水线循环,接下来我们来看看如何使用 Pallas API。

Pallas 流水线 API#

Pallas 提供了一个流水线 API,它抽象掉了维护多个缓冲区以及重叠异步通信与计算的繁琐工作。此 API 的基础知识在Pallas 快速入门中介绍,因此我们在此简要回顾一下 API 以便完整性,并讨论使用流水线时出现的一些潜在问题。

网格#

程序网格 (grid) 是一个整数元组,它以数组形式指定子问题的数量。流水线的结构可以被解释为嵌套的 for 循环,其中包含每个循环的边界。

# For grid (N, M, K)
for n in range (N):
  for m in range(M):
    for k in range(K):
      kernel()

内核将总共被调用 prod(grid) 次。有关更多详细信息,请参阅网格和块规格

块规格#

BlockSpec 指定了每个子问题复制到内核的数据的大小和切片。pl.BlockSpec 的基本构造函数涉及指定 block_shape(数据切片的大小)和 index_map(一个函数,它接收当前子问题的程序 ID 并输出源缓冲区中的分块索引)。分块索引指定在每次迭代中要复制哪个块,假设源缓冲区已被切割成 block_shape 形状的块。memory_space 参数指定要将输入复制到哪个内存空间——默认情况下这将是 SRAM。

pl.BlockSpec(
  block_shape: tuple[int, ...],
  index_map: Callable,
  memory_space: pl.MemorySpace
)

内核的每个输入和每个输出都应该有一个 BlockSpec。有关更多详细信息,请参阅网格和块规格

内核#

内核函数指定在每个子问题上执行什么计算。内核函数不应返回任何输出,而是所有输出都应写入传递到内核的输出缓冲区中。默认情况下,所有输入和输出缓冲区都是 SRAM 缓冲区(除非用户通过在相应的 BlockSpec 上指定 memory_space 来覆盖此行为)。

def kernel(*input_buffers, *output_buffers):
  # ... perform compute
  # ... store result into output buffers

可以使用 pl.program_id(grid_axis: int) 在内核内部查询当前子问题的索引。

Pallas 调用#

函数 pl.pallas_call 是 Pallas 的主要入口点,当提供网格和 BlockSpecs 时,它会执行流水线操作。它具有以下签名:

def pallas_call(
  kernel,
  grid: tuple[int, ...],
  in_specs: Sequence[PyTree[BlockSpec]],
  out_specs: PyTree[BlockSpec],
  out_shape: PyTree[jax.ShapeDtypeStruct],
) -> Callable:

pallas_call 将返回一个可调用函数,该函数在用输入值调用时,将返回与 out_shape 具有相同形状的输出。

in_specsout_specsout_shape 是各自元素类型的 PyTrees。in_specs 的 PyTrees 应与提供给内核的输入缓冲区匹配,out_specsout_shape 的 PyTrees 也应匹配。

示例 - 逐元素内核回顾#

让我们重新审视本教程开头部分的 add_matrices_kernel,但这次使用流水线。我们将添加两个形状为 f32[4096, 4096] 的输入数组,它们位于 HBM 中。作为子问题,我们将输入切割成 block_shape=(512, 512) 的块,并且每次在内核中只将两个块相加。由于加法是逐元素的,每个 index_map 都是相同的,并在第 i, j 次迭代中选择出第 i, j 个块。

# Note: This is a TPU example.

total_shape = (4096, 4096)
block_shape = (512, 512)

def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

def add_matrices_pipelined(x: jax.Array, y: jax.Array):
  return pl.pallas_call(
    add_matrices_pipelined_kernel,
    grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),
    in_specs=[
      pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),
      pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))
    ],
    out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),
  )(x, y)

x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)
y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)
result = add_matrices_pipelined(x, y)
np.testing.assert_array_equal(
    result, x + y
)

事实证明,使用此 API 编写流水线内核的代码行数并不比编写我们最初的朴素加法内核多很多!

内核参数化#

在我们的内核中参数化块形状是很常见的。块大小也许是优化 Pallas 内核性能时最重要的参数!它们让我们能够控制流水线(例如,选择更小的块会为我们的流水线循环增加更多迭代,其中每次迭代的工作量更少)。我们来编写一个实现此功能的函数。

def add_matrices_pipelined_param(
    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
  m, n = x.shape
  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(m // bm, n // bn),
  )(x, y)

np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y
)

潜在问题#

虽然流水线提供了一种接近于在循环中简单调用内核函数的心理模型,但由于使用了未完全向用户隐藏的中间缓冲区,会出现一些潜在问题,并可能导致细微的错误。

缓冲区回顾#

通常,一个好的经验法则是:传递给内核函数的输入缓冲区应被解释为只读,而输出缓冲区是只写

在大多数情况下,写入输入和从输出读取会导致不正确。这是因为传递给内核的 SRAM 缓冲区只包含底层 HBM 缓冲区中数据的副本。如果更新了输入 SRAM 缓冲区,更新后的结果永远不会写回 HBM;如果更新了输出缓冲区,其更新后的值永远不会读入 SRAM。这个问题类似于在使用缓存时普遍遇到的数据陈旧问题。

有两种情况下缓冲区支持读写:累加(接下来讨论)以及通过将 input_output_aliases 参数传递给 pallas_call 来将一对输入和输出缓冲区标记为输入-输出别名。

归约与累加#

归约/累加应仅在网格的最后(最内层)维度上执行,并且缓冲区应首先手动初始化。

归约是流水线支持对输出缓冲区进行读写的少数情况之一,但其工作原理是微妙的。Pallas 流水线发射器会执行一项优化:如果两次连续迭代之间的数据切片相同,流水线将不会对该缓冲区发出 copy_in/copy_out 操作。这意味着在先前迭代中使用的同一个 SRAM 缓冲区将在后续迭代中再次传递到内核,因此对输出缓冲区发出的任何写入都将在下一次迭代中可见。一旦数据切片发生变化,最终累加的 SRAM 缓冲区将被写入 HBM。这也是为什么归约必须在网格的最后维度上执行——我们希望在最内层循环中,当输出缓冲区在 SRAM 中时完成所有累加,然后将其写入 HBM,并且不再触碰该输出块。

作为一个具体示例,我们考虑执行以下计算,将一个 (8, 1024, 1024) 数组沿第一个轴归约成一个 (1024, 1024) 数组。

x = jnp.ones((8, 1024, 1024))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

要使用 pallas_call 执行此操作,我们可以使用大小为 (8,) 的网格,并在每次迭代 i 中将 x[i] 加载到 SRAM 中。然后我们可以将 x[i] 添加到一个输出 SRAM 缓冲区中。我们先来朴素地实现它。

# Note: This is a TPU example.

# Warning: this implementation is incorrect!
def incorrect_sum_kernel(x_ref, o_ref):
  o_ref[...] += x_ref[...]

def incorrect_sum(x: jax.Array,
              block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
  reduction_size, *out_shape = x.shape
  grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))
  return pl.pallas_call(
      incorrect_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],
      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)

result = incorrect_sum(x)
print(result)
[[65. 65. 65. ... 66. 66. 66.]
 [65. 65. 65. ... 66. 66. 66.]
 [65. 65. 65. ... 66. 66. 66.]
 ...
 [71. 71. 71. ... 72. 72. 72.]
 [71. 71. 71. ... 72. 72. 72.]
 [71. 71. 71. ... 72. 72. 72.]]

这个结果是完全错误的!

这个内核中有两个错误。首先,我们沿着第一个网格维度进行累加而不是最后一个网格维度。其次,o_ref 最初包含垃圾值,因此在开始累加之前我们需要将其初始化为零。

修正这两个问题后,我们得到了以下正确的内核。在这个新内核中,我们使用 @pl.when 来创建一个条件,检查当程序 ID 沿归约轴为 0 时,表明我们开始累加到一个新的输出块。我们还将归约维度移动到了 grid 的最后一个轴。

# Note: This is a TPU example.

def correct_sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
  o_ref[...] += x_ref[...]

def correct_sum(x: jax.Array,
              block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
  reduction_size, *out_shape = x.shape
  # We moved the reduction to the last axis of the grid.
  grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)
  return pl.pallas_call(
      correct_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],
      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)

result = correct_sum(x)
print(result)
[[8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]
 ...
 [8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]]

性能分析#

流水线内核的性能如何?这个问题会因硬件瓶颈所在而异。我们通常关注 3 个量:

  • 内存延迟 \(α\),即内存传输的最小延迟。

  • 内存带宽 \(β\),即我们可以从 HBM 传输到 SRAM 的速率(字节/秒)。

  • FLOP/s \(F\),即每秒浮点运算次数,是处理器每秒可以执行的计算次数。

如果处理速度 FLOPs/s 是瓶颈,我们称一个程序为计算密集型 (compute-bound);如果带宽或延迟是瓶颈,则称为内存密集型 (memory-bound)。通常,我们的目标是优化内核使其成为计算密集型,这意味着我们正在充分利用硬件的所有可用处理能力。

假设我们正在运行一个程序,该程序在每次内核迭代中需要 \(X\) 字节的内存传输,并执行 \(Y\) 次浮点运算。\(X\)\(Y\) 的比率因计算类型而异——对于逐元素操作(如加法或乘法),它们会同等比例扩展。然而,对于矩阵乘法等操作,计算量随问题大小呈立方级增长,而内存量呈平方级增长。

计算密集型情况下,一个运行 \(N\) 次迭代的流水线将花费 \((\alpha + X/\beta) + N (Y/F)\) 秒,其中第一项表示初始气泡的成本(如果末尾也有气泡,则乘以 2),第二项表示流水线稳态的总时间。假设 N 足够大且有足够的工作来产生一个长的流水线,则运行时的主导项是 \(F\),即加速器的处理速度。

pipelining_compute

内存密集型情况下,识别问题是延迟瓶颈还是带宽瓶颈很有用。如果带宽是瓶颈,那么总运行时间将需要 \(\alpha + N(X / \beta)\) 秒。与延迟密集型情况相比,内存复制是串行发生的,因为带宽已经饱和。成为内存密集型通常不是理想情况,因为处理器会有空闲时间,并且在大多数硬件配置中,内存带宽 \(\beta\) 比处理速度 \(F\) 慢几个数量级。

pipelining_bandwidth

如果瓶颈是具体的延迟而不是带宽,可以通过插入额外的流水线阶段来解决问题,但这会增加存储更多缓冲区所需的 SRAM。在足够多的阶段下,问题将再次变为计算密集型或带宽密集型,具体取决于我们在流水线稳态阶段首先遇到的瓶颈。然而,多阶段流水线的缺点是气泡的大小与阶段数量成正比,因此确保流水线足够长以使气泡不会占据总运行时间的大部分至关重要。

pipelining_latency

TPU 上的 Pallas 仅支持双缓冲,因为 TPU 程序可以操作更大的块大小,并且双缓冲通常足以弥补延迟。在 GPU 上,流水线阶段的数量可以在 Triton(通过 CompilerParams)和 Mosaic GPU 后端(通过流水线发射器的参数)中指定。有关更多详细信息,请参阅特定于平台的流水线文档。