软件流水线#

软件流水线是一种重要的性能优化技术,通过重叠多个异步操作来提高效率,即使这些操作之间存在数据依赖。在编写内核的上下文中,流水线最常见的形式涉及将通信和内存传输与计算重叠,从而使硬件加速器在等待数据到达时不会停顿。因此,本教程将仅关注通信-计算流水线问题。我们将首先从概念上介绍该问题,概述用于编写流水线的 Pallas API,并介绍使用该 API 的一些实际示例。

本教程仅涵盖流水线的概念基础。有关特定平台的参考,请参阅 TPU 流水线Mosaic GPU 流水线

import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
import numpy as np

内存层次结构#

从概念上理解流水线的第一步是理解可用的不同内存形式及其权衡。大多数硬件架构(包括 CPU、GPU 和 TPU)使用各种内存空间,这些空间在容量与延迟/带宽之间进行权衡。对于 Pallas 而言,我们通常关心寄存器、SRAM、DRAM 以及可能的网络通信。

  • 寄存器 是物理上最接近处理器的内存,通常在对值进行任何计算之前,必须将其直接加载到寄存器中。

  • SRAM(在 GPU 上也称为共享内存/L1 和 L2 缓存,或在 TPU 上称为 VMEM)也相当靠近处理器,但容量比寄存器大。现代机器学习加速器上的 SRAM 容量通常在 10-100MB 范围内(TPU v5p 包含 96MB VMEM,H100 GPU 包含约 30MB L1 缓存和 50MB L2 缓存)。可以合理地预期访问 SRAM 的延迟大约比访问寄存器长 10 倍。

  • DRAM(也称为 HBM)的容量远大于 SRAM,通常在现代机器学习加速器上为 10-100GB 范围。然而,与 SRAM 相比,访问延迟大约长 10 倍。

  • 当单个设备的 DRAM 容量不足以处理大型工作负载,或者我们希望利用并行计算时,网络通信变得至关重要。本教程不介绍分布式流水线,但有关编写跨多个设备的流水线,请参阅 分布式 TPU 内核 指南。

memory_hierarchy

为了对位于 HBM 中的值 X 和 Y 执行计算,我们需要:

  1. 将值 x 和 y 复制到 SRAM。

  2. 将值从 SRAM 加载到寄存器。

  3. 执行计算并将结果存储在寄存器中。

  4. 将输出寄存器中的值存储到 SRAM。

  5. 将 SRAM 中的输出值复制回 HBM。

让我们来实现一个执行此操作的 Pallas 函数!

# Note: This is a TPU example.

def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):
  # Load x and y from SRAM into registers
  x_regs = x_sram_ref[:, :]
  y_regs = y_sram_ref[:, :]
  # Execute a vectorized add
  z_regs = x_regs + y_regs
  # Store the output values in registers back into SRAM
  z_sram_ref[:, :] = z_regs


def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.
  # It will then copy `x` and `y` from HBM into SRAM.
  z = pl.pallas_call(
      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
  # pallas_call will also copy the output from SRAM back into HBM.
  return z


x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

我们编写了两个函数:add_matrices_kerneladd_matrices

add_matrices_kernel 使用存在于 SRAM 中的 Refs。加载 SRAM Ref 会产生一个存在于寄存器中的值。寄存器中的值表现得像 jax.Arrays,我们可以对它们使用 jnpjax.lax 操作来生成存在于寄存器中的新值。当我们生成想要返回的值时,我们将其存储在输出 SRAM Ref 中。

add_matrices 函数作用于 jax.Array 并返回一个 jax.Array。在其中,我们将 x 和 y 传递给 pallas_call。 pallas_call 负责将 x 和 y 复制到 SRAM,并分配内核操作的 SRAM 缓冲区(包括分配 z_vmem_ref,即输出 SRAM 缓冲区)。内核函数运行完成后,pallas_call 还会将 z_vmem_ref 中的值复制回 HBM,从而得到一个输出 jax.Array

Pallas 暴露了对 SRAM 等低级内存空间的访问,但编写高性能内核需要更仔细地利用各种内存空间。例如,我们需要同时考虑:

  • 内存容量。 SRAM 很小!如果我们的数组太大,上述内核将无法工作,因为我们无法将输入装入 SRAM。作为参考,一个 f32[2048, 2048] 数组为 16MiB,因此上述内核只能处理中等大小的数组。

  • 内存带宽。与大多数计算指令相比,与 HBM 和 SRAM 之间的复制耗时很长。上面的 add_matrices 函数很可能花费更多时间在 HBM 和 SRAM 之间的复制上,而不是实际执行加法本身。

牢记这两个限制,我们将不得不重新考虑我们的策略,以从加速器中获得性能。

流水线基础#

我们如何利用内存层次结构中各种内存类型的优势,并能够在操作存储在 HBM 中的大数组的同时,仍然利用快速的 SRAM 进行计算?流水线是一种非常通用的编程模式,可以让我们做到这一点,但它需要将问题分解为可以并行重叠的更小子问题。

流水线的第一步是将我们的问题划分为可以容纳在 SRAM 中的更小的子问题。例如,一个元素级操作可以通过一次操作源数组的一个切片来轻松转换,从而产生以下 3 个步骤(也称为阶段):

  1. copy_in:将切片 A[i] 从 HBM 复制到 SRAM X

  2. compute:将 X 加载到寄存器,计算结果,并存储在 SRAM Y 中。

  3. copy_out:将结果 Y 复制回 HBM A[i]

请注意,步骤 1-3 之间存在数据依赖,我们无法轻易地重叠它们,因为我们需要先完成步骤 (1) 才能开始步骤 (2),以此类推。然而,跨子问题多次调用的之间没有数据依赖——也就是说,我们可以在执行块 A[i+1] 的步骤 (1) 的同时,执行块 A[i] 的步骤 (2) 和块 A[i-1] 的步骤 (3)。

pipelining_example

上图描绘了理想化的流水线程序如何在时间上进行调度。关键的见解是,在内核的大部分时间里,复制操作与计算操作并行执行,这意味着我们可以理想地用计算“隐藏”HBM/SRAM 之间传输的成本,并尽可能地使处理器保持忙碌状态。

初始启动时间和最终结束时间称为“气泡”,此时只有一部分阶段被执行,同时流水线正在“填充”或“排空”。大部分时间花在流水线的“稳态”阶段,其中每个流水线阶段跨子问题的不同迭代并行执行。虽然对于更通用的流水线方法,目标是实现 N 路并行(其中 N 是阶段数),但对于内核流水线,我们通常受内存带宽或处理速度的瓶颈。因此,我们通过内核流水线的目标通常是实现处理器 FLOP/s 的完全利用,这意味着在任何时候都有一个 compute 块处于活动状态。在上图中,计算块在 8 个时间槽中有 6 个是活动的,假设我们在每个计算时间槽中都完全利用了处理器,那么我们将实现 75% 的处理器利用率。

推导双缓冲流水线#

现在让我们看看如何在伪代码中实现流水线。考虑以下元素级程序,我们使用 copy_in 指令从 HBM(A[i])加载值,将结果加 1,然后使用 copy_out 将结果写回 HBM。

for i in range(N):
  copy_in(A[i], X)
  Y = X + 1
  copy_out(Y, A[i])

这种方法的问题在于 copy_incopy_out 通常是阻塞操作。因此,我们被迫在 GPU/TPU 空闲时等待复制完成,然后执行计算,而内存则空闲。我们希望做的是异步“预取”循环下一次迭代所需的输入值,同时执行当前循环的计算,以便计算和内存通信同时发生。

为了能够理解我们将要进行的的代码转换,让我们将循环展开 N=4 次,并将复制指令分解为单独的 copy_startcopy_wait 操作,以便能够表达异步性。

  # Itr 1
  copy_in_start(A[0], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[0])
  copy_out_wait(Y)

  # Itr 2
  copy_in_start(A[1], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[1])
  copy_out_wait(Y)

  # Itr 3
  copy_in_start(A[2], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[2])
  copy_out_wait(Y)

  # Itr 4
  copy_in_start(A[3], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[3])
  copy_out_wait(Y)

循环展开后,流水线转换仅涉及尽可能早地发出 copy_start 指令,并在需要该值之前尽可能晚地 copy_wait 值。然而,在当前循环状态下,X 存在一个虚假的数据依赖——我们不能同时对 X 进行异步复制并将其用于计算,否则可能会出现竞争条件。因此,我们可以使用多缓冲区技术,为每个输入 X 和每个输出 Y 保留 2 个缓冲区。有了 2 个缓冲区,我们可以将 copy_in_start 推迟一个迭代(有 3 个缓冲区可以推迟 2 个迭代,依此类推),并重写我们的循环如下:

  # Prologue
  copy_in_start(A[0], X[0])
  
  # Itr 1
  copy_in_start(A[1], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[0])
  copy_out_wait(Y[0])

  # Itr 2 - Steady state
  copy_in_start(A[2], X[0])
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[1])
  copy_out_wait(Y[1])

  # Itr 3 - Steady state
  copy_in_start(A[3], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[2])
  copy_out_wait(Y[0])

  # Itr 4 - No copy-in
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[3])
  copy_out_wait(Y[1])

接下来,我们可以在下一个循环迭代中写 Y 之前,将 copy_out_wait 推迟到尽可能晚。

  # Prologue
  copy_in_start(A[0], X[0])
  
  # Itr 1
  copy_in_start(A[1], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[0])

  # Itr 2 - Steady state
  copy_in_start(A[2], X[0])
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[1])
  copy_out_wait(Y[0])

  # Itr 3 - Steady state
  copy_in_start(A[3], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[2])
  copy_out_wait(Y[1])

  # Itr 4 - No copy-in
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[3])
  copy_out_wait(Y[0])

  # Epilogue
  copy_out_wait(Y[1])

最后,将循环重新折叠成一个 for 循环,我们得到以下流水线循环:

# Prologue
copy_in_start(A[0], X[0])

# Main loop
for i in range(N):
  cur_slot = i % 2
  next_slot = (i + 1) % 2

  if i+1 < N:
    copy_in_start(A[i+1], X[next_slot])
  
  copy_in_wait(X[cur_slot])
  Y[cur_slot] = X[cur_slot] + 1
  copy_out_start(Y[cur_slot], A[i])

  if i > 0:
    copy_out_wait(Y[next_slot])

# Epilogue
copy_out_wait(Y[1])

如果我们想将此循环推广以处理更广泛的计算,请注意,我们基本上需要为流水线指定 3 条信息:

  • grid (网格),即 for 循环的边界,它指定了要计算的子问题的数量。在我们的例子中,我们有一个大小为 (N,) 的一维网格。

  • kernel (内核),即一旦输入加载到 SRAM 中所发生的实际计算。在我们的例子中,我们执行了一个元素级加法 Y = X + 1

  • data_slices (数据切片),它将子问题映射到 HBM 缓冲区中相应的切片。在我们的例子中,数据切片是恒等函数 lambda i: i

通过允许用户指定这些信息,我们可以按照此模式编写各种程序。

def double_buffered_pipeline(
    grid: tuple[int, ...],
    kernel: Callable,
    in_slices: Callable,
    out_slices: Callable):
  # Prologue
  copy_in_start(in_hbm[in_slices(0)], in_sram[0])

  # Main loop
  grid_size = prod(grid)
  for i in range(grid_size):
    cur_slot = i % 2
    next_slot = (i + 1) % 2
    if (i + 1) < grid_size:
      copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot])
    copy_in_wait(in_sram[cur_slot])

    kernel(in_sram[cur_slot], out_ram[cur_slot])

    copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])
    if i > 0:
      copy_out_wait(out_sram[next_slot])

  # Epilogue
  last_slot = (grid_size - 1) % 2
  copy_out_wait(out_sram[last_slot])

既然我们已经看到了如何手动实现流水线循环,现在让我们看看如何使用 Pallas API。

Pallas 流水线 API#

Pallas 提供了一个流水线 API,它抽象了维护多个缓冲区和将异步通信与计算重叠的样板代码。此 API 的基础知识已在 Pallas Quickstart 中介绍,因此我们在此处简要回顾一下 API 以确保完整性,并讨论使用流水线时出现的一些注意事项。

Grid (网格)#

程序 **grid** 是一个整数元组,指定子问题的数量作为一个数组。流水线的结构可以被解释为嵌套的 for 循环,其中每个循环的边界。

# For grid (N, M, K)
for n in range (N):
  for m in range(M):
    for k in range(K):
      kernel()

内核将被调用总共 prod(grid) 次。有关更多详细信息,请参阅 grid 和 blockspecs

BlockSpecs (块规格)#

BlockSpec 指定在每个子问题中复制到内核的数据的大小和切片。pl.BlockSpec 的基本构造函数涉及指定 block_shape(数据切片的大小)和 index_map(一个函数,接收当前子程序的程序 ID 并输出源缓冲区中的索引)。块索引指定每次迭代要复制的块,假设源缓冲区已按 block_shape 的形状分割成块。memory_space 参数指定将输入复制到的内存空间——默认情况下是 SRAM。

pl.BlockSpec(
  block_shape: tuple[int, ...],
  index_map: Callable,
  memory_space: pl.MemorySpace
)

内核的每个输入和输出都应该有一个 BlockSpec。有关更多详细信息,请参阅 grid 和 blockspecs

Kernel (内核)#

内核函数指定每个子问题执行的计算。内核函数不应返回任何输出,而所有输出都应写入传递到内核的输出缓冲区中。默认情况下,所有输入和输出缓冲区都是 SRAM 缓冲区(除非用户通过在相应的 BlockSpec 上指定 memory_space 来覆盖该行为)。

def kernel(*input_buffers, *output_buffers):
  # ... perform compute
  # ... store result into output buffers

当前子程序的索引可以在内核中使用 pl.program_id(grid_axis: int) 进行查询。

Pallas Call#

函数 pl.pallas_call 是 Pallas 的主要入口点,当提供了 grid 和 BlockSpecs 时,它会执行流水线执行。它的签名如下:

def pallas_call(
  kernel,
  grid: tuple[int, ...],
  in_specs: Sequence[PyTree[BlockSpec]],
  out_specs: PyTree[BlockSpec],
  out_shape: PyTree[jax.ShapeDtypeStruct],
) -> Callable:

pallas_call 将返回一个可调用函数,当使用输入值调用时,它将返回与 out_shape 相同的形状的输出。

in_specsout_specsout_shape 是各自元素类型的 PyTrees。传递给内核的 in_specs 和输入缓冲区的 PyTrees 应该匹配,out_specsout_shape 的 PyTrees 也应该匹配。

示例 - 元素级内核的重访#

让我们重访教程开头 add_matrices_kernel,但使用流水线。我们将添加两个形状为 f32[4096, 4096] 且存在于 HBM 中的输入数组。作为子问题,我们将输入分割成 block_shape=(512, 512) 的块,并在内核中一次只将两个块相加。由于加法是元素级的,每个 index_map 都相同,并在 i, j 次迭代中选择 i, j 块。

# Note: This is a TPU example.

total_shape = (4096, 4096)
block_shape = (512, 512)

def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

def add_matrices_pipelined(x: jax.Array, y: jax.Array):
  return pl.pallas_call(
    add_matrices_pipelined_kernel,
    grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),
    in_specs=[
      pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),
      pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))
    ],
    out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),
  )(x, y)

x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)
y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)
result = add_matrices_pipelined(x, y)
np.testing.assert_array_equal(
    result, x + y
)

事实证明,使用此 API,编写流水线内核的代码行数并不比编写我们原始的简单加法内核多多少!

参数化内核#

在我们的内核中参数化块形状是很常见的。块大小可能是优化 Pallas 内核性能最重要的参数!它们让我们能够控制流水线(例如,选择更小的块会增加流水线循环的迭代次数,其中每次迭代的工作量更少)。让我们编写一个执行此操作的函数:

def add_matrices_pipelined_param(
    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
  m, n = x.shape
  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(m // bm, n // bn),
  )(x, y)

np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y
)

注意事项#

虽然流水线提供了对简单地在循环中调用内核函数的思维模型的近似,但由于使用了未完全对用户隐藏的中间缓冲区,因此出现了一些注意事项,这些注意事项可能导致细微的错误。

重新审视缓冲区#

总的来说,一个经验法则是:传递到内核函数的输入缓冲区应被解释为只读,输出缓冲区为只写

写入输入和读取输出在大多数情况下都会导致不正确的结果。这是因为传递给内核的 SRAM 缓冲区仅包含底层 HBM 缓冲区中数据的副本。如果更新了输入 SRAM 缓冲区,更新的结果将永远不会写回 HBM;如果更新了输出缓冲区,其更新值永远不会被读入 SRAM。这个问题类似于使用缓存时遇到的陈旧性问题。

在两种情况下,缓冲区同时支持读写:累加(稍后讨论)和通过向 pallas_call 传递 input_output_aliases 参数将一对输入和输出缓冲区标记为输入-输出别名。

归约和累加#

归约/累加应仅在网格的最后一个(最内层)维度上执行,并且应先手动初始化缓冲区。

归约是流水线支持读取和写入输出缓冲区的少数情况之一,但其原因很微妙。Pallas 流水线发射器执行一项优化,即如果两个连续迭代之间的数据切片相同,流水线将不会在该缓冲区上发出 copy_in/copy_out。这意味着在前一个迭代中使用的同一个 SRAM 缓冲区将在下一个迭代中再次传递到内核,因此对输出缓冲区发出的任何写入将在下一个迭代中可见。一旦数据切片发生变化,最终累加的 SRAM 缓冲区将被写回 HBM。这也是为什么归约必须在网格的最后一个维度上执行——我们希望在输出缓冲区位于最内层循环的 SRAM 中时完成所有累加,然后将其写回 HBM,并且不再接触该输出块。

作为具体示例,让我们考虑对 (8, 1024, 1024) 数组沿第一个轴进行归约,将其归约到一个 (1024, 1024) 数组。

x = jnp.ones((8, 1024, 1024))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

要使用 pallas_call 实现此目的,我们可以使用大小为 (8,) 的网格,并在每个迭代 i 中将 x[i] 加载到 SRAM。然后,我们可以将 x[i] 添加到一个输出 SRAM 缓冲区。让我们先 naively 实现这一点。

# Note: This is a TPU example.

# Warning: this implementation is incorrect!
def incorrect_sum_kernel(x_ref, o_ref):
  o_ref[...] += x_ref[...]

def incorrect_sum(x: jax.Array,
              block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
  reduction_size, *out_shape = x.shape
  grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))
  return pl.pallas_call(
      incorrect_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],
      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)

result = incorrect_sum(x)
print(result)
[[65. 65. 65. ... 66. 66. 66.]
 [65. 65. 65. ... 66. 66. 66.]
 [65. 65. 65. ... 66. 66. 66.]
 ...
 [71. 71. 71. ... 72. 72. 72.]
 [71. 71. 71. ... 72. 72. 72.]
 [71. 71. 71. ... 72. 72. 72.]]

结果完全错误!

此内核中有两个错误。首先,我们沿第一个网格维度而不是最后一个网格维度进行累加。其次,o_ref 最初包含垃圾值,因此我们需要在开始累加之前将其初始化为零。

修复这两个问题后,我们得到以下修正后的内核。在此新内核中,我们使用 @pl.when 创建一个条件,该条件在程序 ID 沿归约轴为 0 时检查,这表示我们正在开始累加到一个新的输出块。我们还将归约维度移到了 grid 的最后一个轴。

# Note: This is a TPU example.

def correct_sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
  o_ref[...] += x_ref[...]

def correct_sum(x: jax.Array,
              block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
  reduction_size, *out_shape = x.shape
  # We moved the reduction to the last axis of the grid.
  grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)
  return pl.pallas_call(
      correct_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],
      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)

result = correct_sum(x)
print(result)
[[8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]
 ...
 [8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]]

性能分析#

流水线内核的性能如何?这个问题可能会因硬件瓶颈所在而异。我们通常关心 3 个量:

  • 内存延迟 \(α\),内存传输的最小延迟。

  • 内存带宽 \(β\),从 HBM 到 SRAM 的传输速率(字节/秒)。

  • FLOP/s \(F\),即每秒浮点运算次数,处理器每秒可以执行的计算次数。

我们将程序称为计算密集型(compute-bound),如果瓶颈是处理速度 FLOPs/s;称为内存密集型(memory-bound),如果瓶颈是带宽或延迟。通常,我们的目标是优化内核,使其成为计算密集型,这意味着我们正在利用硬件的所有可用处理能力。

假设我们正在运行一个程序,该程序每个内核迭代需要 \(X\) 字节的内存传输,并且每个迭代运行 \(Y\) 次浮点运算。 \(X\)\(Y\) 的比率取决于计算的类型——对于加法或乘法等元素级操作,两者都以相同的比例缩放。但是,对于矩阵乘法等操作,计算随问题大小呈立方增长,而内存随问题大小呈平方增长。

计算密集型模式下,运行 \(N\) 次迭代的流水线将需要 \((\alpha + X/\beta) + N (Y/F)\) 秒,其中第一项表示初始气泡的成本(如果末尾也有气泡,则乘以 2),第二项表示流水线稳态的总时间。假设 N 很大且有足够的工作来产生长流水线,运行时最主要的项是 \(F\),即加速器的处理速度。

pipelining_compute

内存密集型模式下,区分问题是延迟还是带宽很有用。如果带宽是瓶颈,那么总运行时间将为 \(\alpha + N(X / \beta)\) 秒。与延迟密集型模式不同,内存复制是串行的,因为带宽已经饱和。内存密集型通常不是理想的,因为会有处理器空闲的时间间隔,并且在大多数硬件配置中,内存带宽 \(\beta\) 比处理速度 \(F\) 慢几个数量级。

pipelining_bandwidth

如果瓶颈是延迟而不是带宽,可以通过插入额外的流水线阶段来解决问题,但需要额外的 SRAM 来存储更多缓冲区。通过足够的阶段,问题将再次成为计算密集型或带宽密集型,具体取决于我们在流水线稳态阶段首先遇到的瓶颈。然而,多阶段流水线的缺点是气泡的大小与阶段数成正比,因此重要的是要确保流水线足够长,以免气泡占据总运行时间的很大一部分。

pipelining_latency

TPU 上的 Pallas 仅支持双缓冲,因为 TPU 程序可以处理更大的块大小,而双缓冲通常足以覆盖延迟。在 GPU 上,流水线阶段的数量可以在 Triton(通过 CompilerParams)和 Mosaic GPU 后端(通过流水线发射器的参数)中指定。有关更多详细信息,请参阅特定平台的流水线文档。