软件流水线#

软件流水线是通过重叠多个异步操作(即使它们之间存在数据依赖性)来实现性能优化的重要技术。在内核编写的上下文中,最常见的流水线形式涉及将通信和内存传输与计算重叠,从而使硬件加速器永远不会在等待数据到达时停顿。因此,在本教程中,我们将只关注通信-计算流水线问题。我们将首先从概念上介绍该问题,概述用于编写流水线的 Pallas API,并介绍一些使用该 API 的实际示例。

本教程仅涵盖流水线的概念基础。有关平台特定参考资料,请参阅 TPU 或 GPU(即将推出!)特定流水线参考资料。

import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
import numpy as np

内存层级结构#

从概念上理解流水线的第一步涉及理解可用的不同形式的内存以及它们之间的权衡。大多数硬件架构(包括 CPU、GPU 和 TPU)都利用各种各样的内存空间,这些内存空间在容量与延迟/带宽之间进行权衡。对于 Pallas 的目的而言,我们通常对寄存器、SRAM、DRAM 以及可能的网络通信感兴趣

  • 寄存器 是物理上最接近处理器的内存,通常值必须直接加载到寄存器中,然后才能对其进行任何计算。

  • SRAM(也称为 GPU 上的共享内存/L1 和 L2 缓存,或 TPU 上的 VMEM)也离处理器相当近,但容量比寄存器更大。现代 ML 加速器上的 SRAM 容量通常在 10-100MB 范围内(TPU v5p 包含 96MB 的 VMEM,H100 GPU 包含约 30MB 的 L1 缓存和 50MB 的 L2)。可以合理地预期,访问 SRAM 的延迟比访问寄存器长约 10 倍。

  • DRAM(也称为 HBM)的容量比 SRAM 大得多,现代 ML 加速器的容量通常在 10-100GB 范围内。但是,与 SRAM 相比,访问延迟大约长 10 倍。

  • 当单个设备上的 DRAM 大小不足或我们想利用并行计算的优势时,网络 通信变得至关重要。在本教程中,我们不介绍分布式流水线,但有关跨多个设备编写流水线的信息,请参阅 分布式 TPU 内核 指南。

memory_hierarchy

为了对 HBM 中的值 X 和 Y 执行计算,我们需要

  1. 将值 x 和 y 复制到 SRAM 中。

  2. 将值从 SRAM 加载到寄存器中。

  3. 执行计算并将结果存储到寄存器中。

  4. 将输出寄存器中的值存储到 SRAM 中。

  5. 将 SRAM 中的输出值复制回 HBM。

让我们实现一个执行此操作的 Pallas 函数!

# Note: This is a TPU example.

def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):
  # Load x and y from SRAM into registers
  x_regs = x_sram_ref[:, :]
  y_regs = y_sram_ref[:, :]
  # Execute a vectorized add
  z_regs = x_regs + y_regs
  # Store the output values in registers back into SRAM
  z_sram_ref[:, :] = z_regs


def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.
  # It will then copy `x` and `y` from HBM into SRAM.
  z = pl.pallas_call(
      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
  # pallas_call will also copy the output from SRAM back into HBM.
  return z


x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

我们编写了两个函数:add_matrices_kerneladd_matrices

add_matrices_kernel 使用驻留在 SRAM 中的 Refs 进行操作。从 SRAM Ref 加载会生成一个驻留在寄存器中的值。寄存器中的值行为类似于 jax.Arrays,我们可以对它们使用 jnpjax.lax 操作,以生成驻留在寄存器中的新值。当我们生成想要返回的值时,我们会将它们存储在输出 SRAM Ref 中。

add_matrices 函数作用于 jax.Arrays 并返回一个 jax.Array。在其中,我们将 xy 传递到 pallas_call。pallas_call 负责将 xy 复制到 SRAM 中,并负责分配内核操作的 SRAM 缓冲区(包括分配 z_vmem_ref,输出 SRAM 缓冲区)。内核函数运行完成后,pallas_call 还会将 z_vmem_ref 中的值复制到 HBM,从而生成一个输出 jax.Array

Pallas 允许访问较低级别的内存空间(如 SRAM),但编写高性能内核需要在利用各种内存空间时更加小心。例如,我们需要同时考虑以下两点

  • 内存容量。SRAM 很小!如果我们的数组太大,上面的内核将无法工作,因为我们无法将输入放入 SRAM 中。作为参考,一个 f32[2048, 2048] 数组为 16MiB,因此我们上面的内核无法扩展到中等大小的数组之外。

  • 内存带宽。与大多数计算指令相比,在 HBM 和 SRAM 之间复制需要很长时间。上面的 add_matrices 函数可能会花费更多时间在 HBM 和 SRAM 之间复制,而不是实际执行加法运算本身。

考虑到这两个约束,我们将不得不重新思考从加速器中获得性能的策略。

流水线基础知识#

我们如何才能利用层级结构中每种类型内存的优势,并能够在 HBM 中存储大型数组的同时,仍然利用快速 SRAM 进行计算?流水线是一种非常通用的编程模式,它将使我们能够做到这一点,但这需要将你的问题转换为可以并行重叠的更小的子问题。

流水线的第一步是将我们的问题划分为可以放入 SRAM 中的更小的子问题。例如,可以通过一次操作源数组的一个切片来轻松转换元素级操作,从而产生以下 3 个步骤(也称为阶段)

  1. copy_in:将切片 A[i] 从 HBM 复制到 SRAM X

  2. compute:将 X 加载到寄存器中,计算结果,并存储在 SRAM Y

  3. copy_out:将结果 Y 复制回 HBM A[i]

请注意,步骤 1-3 之间存在数据依赖性,由于我们需要在开始步骤 (2) 之前完成步骤 (1),依此类推,因此我们不能轻易地重叠它们。但是,子问题的多次调用之间没有数据依赖性——也就是说,我们可以在为块 A[i] 执行步骤 (2) 和为块 A[i-1] 执行步骤 (3) 的同时,为块 A[i+1] 执行步骤 (1)。

pipelining_example

上图描绘了一个理想化的流水线程序如何在时间上进行调度。关键的见解是在内核的大部分时间里,复制操作与计算操作并行执行,这意味着我们可以理想地用计算“隐藏”在 HBM/SRAM 之间传输的成本,并尽可能保持处理器繁忙。

初始启动时间和最终拆卸时间称为“气泡”,在管道被“填充”或“排空”时,只有一部分阶段正在执行。大部分时间都花在了管道的“稳态”阶段,其中每个流水线阶段都在子问题的不同迭代中并行执行。虽然对于更通用的流水线方法,目标是实现 N 路并行(其中 N 是阶段数),但对于内核流水线,我们通常会受到内存带宽或处理速度的瓶颈。因此,我们内核流水线的目的通常是实现处理器 FLOP/秒的完全利用率,这意味着在任何时间点,始终都有一个 compute 块处于活动状态。在上图中,计算块在 6/8 个时隙中处于活动状态,假设我们在每个计算时隙中都充分利用了处理器,我们将实现处理器 75% 的利用率。

推导双缓冲流水线#

现在让我们看看如何在伪代码中实现流水线。考虑以下元素级程序,其中我们使用 copy_in 指令从 HBM (A[i]) 加载值,将结果加 1,然后使用 copy_out 将结果存储回 HBM

for i in range(N):
  copy_in(A[i], X)
  Y = X + 1
  copy_out(Y, A[i])

这种方法的问题在于 copy_incopy_out 通常是阻塞操作。因此,我们被迫等待复制完成,而 GPU/TPU 处于空闲状态,然后在内存空闲时执行计算。我们想要做的是“预取”循环的下一次迭代所需的输入值,并在为当前循环执行计算时异步进行,以便计算和内存通信同时发生。

为了理解我们将要进行的代码转换,让我们将循环展开 N=4,并将复制指令分解为单独的 copy_startcopy_wait 操作,以便能够表达异步性

  # Itr 1
  copy_in_start(A[0], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[0])
  copy_out_wait(Y)

  # Itr 2
  copy_in_start(A[1], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[1])
  copy_out_wait(Y)

  # Itr 3
  copy_in_start(A[2], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[2])
  copy_out_wait(Y)

  # Itr 4
  copy_in_start(A[3], X)
  copy_in_wait(X)
  Y = X + 1
  copy_out_start(Y, A[3])
  copy_out_wait(Y)

一旦循环展开,流水线转换就只涉及尽早发出 copy_start 指令,并尽可能晚地 copy_wait 值(就在我们需要该值之前)。但是,在循环的当前状态下,通过 X 存在伪数据依赖性——我们不能在将 X 用于计算的同时对其执行异步复制,否则可能会出现竞争条件。因此,我们可以使用多缓冲技术,其中我们为每个输入 X 和每个输出 Y 保留 2 个缓冲区。使用 2 个缓冲区,我们可以将 copy_in_start 提前一个迭代(使用 3 个缓冲区,您可以提前 2 个迭代,依此类推),我们将循环重写如下

  # Prologue
  copy_in_start(A[0], X[0])
  
  # Itr 1
  copy_in_start(A[1], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[0])
  copy_out_wait(Y[0])

  # Itr 2 - Steady state
  copy_in_start(A[2], X[0])
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[1])
  copy_out_wait(Y[1])

  # Itr 3 - Steady state
  copy_in_start(A[3], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[2])
  copy_out_wait(Y[0])

  # Itr 4 - No copy-in
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[2])
  copy_out_wait(Y[1])

接下来,我们可以尽可能晚地推送 copy_out_wait,就在我们需要在后续循环迭代中写入 Y 之前。

  # Prologue
  copy_in_start(A[0], X[0])
  
  # Itr 1
  copy_in_start(A[1], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[0])

  # Itr 2 - Steady state
  copy_in_start(A[2], X[0])
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[1])
  copy_out_wait(Y[0])

  # Itr 3 - Steady state
  copy_in_start(A[3], X[1])
  copy_in_wait(X[0])
  Y[0] = X[0] + 1
  copy_out_start(Y[0], A[2])
  copy_out_wait(Y[1])

  # Itr 4 - No copy-in
  copy_in_wait(X[1])
  Y[1] = X[1] + 1
  copy_out_start(Y[1], A[2])
  copy_out_wait(Y[0])

  # Epilogue
  copy_out_wait(Y[1])

最后,将我们的循环重新卷回 for 循环,我们得到以下流水线循环

# Prologue
copy_in_start(A[0], X[0])

# Main loop
for i in range(N):
  cur_slot = i % 2
  next_slot = (i + 1) % 2

  if i < N:
    copy_in_start(A[i+1], X[next_slot])
  
  copy_in_wait(X[cur_slot])
  Y[cur_slot] = X[cur_slot] + 1
  copy_out_start(Y[cur_slot], A[i])

  if i > 0:
    copy_out_wait(Y[next_slot])

# Epilogue
copy_out_wait(Y[1])

如果我们想概括这个循环以处理更广泛的计算,请注意我们基本上需要为流水线指定 3 条信息

  • 网格,或 for 循环的边界,用于指定要计算的子问题的数量。在我们的示例中,我们有一个大小为 (N,) 的一维网格。

  • 内核,或输入加载到 SRAM 后发生的实际计算。在我们的示例中,我们执行了元素级加法 Y = X + 1

  • data_slices,它将子问题映射到 HBM 缓冲区中的相应切片。在我们的示例中,数据切片是恒等函数 lambda i: i

通过允许用户指定这些信息,我们可以编写各种遵循此模式的程序

def double_buffered_pipeline(
    grid: tuple[int, ...],
    kernel: Callable,
    in_slices: Callable,
    out_slices: Callable):
  # Prologue
  copy_in_start(in_hbm[in_slices(0)], in_sram[0])

  # Main loop
  grid_size = prod(grid)
  for i in range(grid_size):
    cur_slot = i % 2
    next_slot = (i + 1) % 2
    if i < grid_size:
      copy_in_start(in_hbm[data_slices(i+1)], in_sram[next_slot])
    copy_in_wait(in_sram[cur_slot])

    kernel(inputs, outputs)

    copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])
    if i > 0:
      copy_out_wait(out_sram[next_slot])

  # Epilogue
  copy_out_wait(out_sram[1])

既然我们已经了解了如何手动实现流水线循环,让我们看看如何使用 Pallas API。

Pallas 流水线 API#

Pallas 提供了一个流水线 API,它抽象出了维护多个缓冲区以及将异步通信与计算重叠的样板代码。此 API 的基础知识已在 快速入门 中介绍,因此为了完整起见,我们将在此处简要介绍该 API,并讨论流水线使用中出现的一些尖锐边缘。

网格#

程序网格是一个整数元组,指定子问题的数量为数组。流水线的结构可以解释为嵌套的 for 循环,其中包含每个循环的边界。

# For grid (N, M, K)
for n in range (N):
  for m in range(M):
    for k in range(K):
      kernel()

内核将被调用总共 prod(grid) 次。有关更多详细信息,请参阅 网格和 blockspecs

BlockSpecs#

BlockSpec 指定在每个子问题上复制到内核的数据的大小和切片。pl.BlockSpec 的基本构造函数包括指定 block_shape(数据切片的大小)和 index_map(一个函数,它接收当前子问题的程序 ID,并输出到源缓冲区的分块索引)。分块索引指定在每次迭代时要复制哪个块,假设源缓冲区已被划分为形状为 block_shape 的块。memory_space 参数指定将输入复制到哪个内存空间——默认情况下,这将是 SRAM。

pl.BlockSpec(
  block_shape: tuple[int, ...],
  index_map: Callable,
  memory_space: pl.MemorySpace
)

内核的每个输入和每个输出都应有一个 BlockSpec。有关更多详细信息,请参阅 网格和 blockspecs

内核#

内核函数指定要在每个子问题上执行的计算。内核函数不应返回任何输出,而是所有输出都应写入传递到内核的输出缓冲区中。默认情况下,所有输入和输出缓冲区都是 SRAM 缓冲区(除非用户通过在相应的 BlockSpec 上指定 memory_space 来覆盖该行为)。

def kernel(*input_buffers, *output_buffers):
  # ... perform compute
  # ... store result into output buffers

当前子问题的索引可以在内核内部使用 pl.program_id(grid_axis: int) 查询。

Pallas 调用#

pl.pallas_call 函数是 Pallas 的主要入口点,并在提供网格和 BlockSpecs 时执行流水线式执行。它具有以下签名

def pallas_call(
  kernel,
  grid: tuple[int, ...],
  in_specs: Sequence[PyTree[BlockSpec]],
  out_specs: PyTree[BlockSpec],
  out_shape: PyTree[jax.ShapeDtypeStruct],
) -> Callable:

pallas_call 将返回一个可调用函数,该函数在用输入值调用时,将返回与 out_shape 形状相同的输出。

in_specsout_specsout_shape 是它们各自元素类型的 PyTrees。 in_specs 的 PyTrees 和提供给内核的输入缓冲区应该匹配,并且 out_specsout_shape 的 PyTrees 也应该匹配。

示例 - 重温元素级内核#

让我们重温教程开始时的初始 add_matrices_kernel,除了使用流水线。我们将添加两个形状为 f32[4096, 4096] 且位于 HBM 中的输入数组。作为子问题,我们将输入分解为 block_shape=(512, 512) 块,并且在内核中一次只将两个块相加。由于加法是元素级的,因此每个 index_map 都是相同的,并在第 i, j 次迭代中选择出第 i, j 个块。

# Note: This is a TPU example.

total_shape = (4096, 4096)
block_shape = (512, 512)

def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

def add_matrices_pipelined(x: jax.Array, y: jax.Array):
  return pl.pallas_call(
    add_matrices_pipelined_kernel,
    grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),
    in_specs=[
      pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),
      pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))
    ],
    out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),
  )(x, y)

x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)
y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)
result = add_matrices_pipelined(x, y)
np.testing.assert_array_equal(
    result, x + y
)

事实证明,使用此 API,编写流水线内核的代码行数与编写我们最初的朴素加法内核的代码行数相差无几!

参数化内核#

在我们的内核中参数化块形状是很常见的。块大小可能是优化 Pallas 内核性能时最重要的参数!它们让我们控制流水线(例如,选择较小的块会为我们的流水线循环添加更多迭代,其中每次迭代的工作量较少)。让我们编写一个执行此操作的函数

def add_matrices_pipelined_param(
    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
  m, n = x.shape
  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(m // bm, n // bn),
  )(x, y)

np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y
)

尖锐的边缘#

虽然流水线提供了对循环中简单调用内核函数的心理模型的近似,但由于中间缓冲区的使用并非完全对用户隐藏,因此会出现许多尖锐的边缘,并可能导致细微的错误。

缓冲区重访#

一般来说,要遵循的一个好的经验法则是:传递到内核函数中的输入缓冲区应解释为只读,而输出缓冲区是只写的

在大多数情况下,写入输入和从输出读取会导致不正确。这是因为传递给内核的 SRAM 缓冲区仅包含底层 HBM 缓冲区中包含的数据的副本。如果更新了输入 SRAM 缓冲区,则更新后的结果永远不会写回 HBM,如果更新了输出缓冲区,则其更新后的值永远不会读入 SRAM。此问题类似于通常在使用缓存时遇到的陈旧性问题。

有两种情况下缓冲区支持读取和写入 - 累积(接下来讨论),以及通过将 input_output_aliases 参数传递给 pallas_call 将一对输入和输出缓冲区标记为输入-输出别名。

归约和累积#

归约/累积应仅在网格的最后(最内层)维度上执行,并且应首先手动初始化缓冲区。

归约是流水线支持对输出缓冲区进行读取和写入的少数情况之一,但其工作原理很微妙。Pallas 流水线发射器执行一项优化,即如果两个连续迭代之间的数据切片相同,则流水线不会在该缓冲区上发出 copy_in/copy_out。这意味着在后续迭代中,将在内核中再次传递先前迭代中使用的相同 SRAM 缓冲区,因此对输出缓冲区发出的任何写入都将在下一次迭代中可见。一旦网格索引更改,最终累积的 SRAM 缓冲区将被写入 HBM。这也是为什么归约必须在网格的最后维度上执行的原因 – 我们希望在最内层循环中当输出缓冲区在 SRAM 中时完成所有累积,然后将其写入 HBM 并且不再接触该输出块。

作为一个具体的例子,让我们考虑执行以下计算,将 (8, 1024, 1024) 数组沿第一个轴归约到 (1024, 1024) 数组。

x = jnp.ones((8, 1024, 1024))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

为了使用 pallas_call 执行此操作,我们可以使用大小为 (8,) 的网格,并在每次迭代 i 中将 x[i] 加载到 SRAM 中。然后我们可以将 x[i] 添加到输出 SRAM 缓冲区。让我们先天真地实现它。

# Note: This is a TPU example.

# Warning: this implementation is incorrect!
def incorrect_sum_kernel(x_ref, o_ref):
  o_ref[...] += x_ref[...]

def incorrect_sum(x: jax.Array,
              block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
  reduction_size, *out_shape = x.shape
  grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))
  return pl.pallas_call(
      incorrect_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],
      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)

result = incorrect_sum(x)
print(result)
[[65. 65. 65. ... 66. 66. 66.]
 [65. 65. 65. ... 66. 66. 66.]
 [65. 65. 65. ... 66. 66. 66.]
 ...
 [71. 71. 71. ... 72. 72. 72.]
 [71. 71. 71. ... 72. 72. 72.]
 [71. 71. 71. ... 72. 72. 72.]]

此结果完全错误!

此内核内部存在两个错误。首先,我们沿第一个网格维度而不是最后一个网格维度进行累积。其次,o_ref 最初包含垃圾值,因此我们需要在开始累积之前将其初始化为零。

在修复这两个问题后,我们获得了以下更正后的内核。在这个新内核中,我们使用 @pl.when 创建一个条件,以检查程序 ID 何时沿归约轴为 0,表明我们开始累积到新的输出块中。我们还将归约维度移动到了 grid 的最后一个轴。

# Note: This is a TPU example.

def correct_sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
  o_ref[...] += x_ref[...]

def correct_sum(x: jax.Array,
              block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
  reduction_size, *out_shape = x.shape
  # We moved the reduction to the last axis of the grid.
  grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)
  return pl.pallas_call(
      correct_sum_kernel,
      grid=grid,
      # None in `block_shape` means we pick a size of 1 and squeeze it away
      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],
      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)

result = correct_sum(x)
print(result)
[[8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]
 ...
 [8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]
 [8. 8. 8. ... 8. 8. 8.]]

分析性能#

流水线内核的性能如何?这个问题可能因硬件的瓶颈所在而异。我们通常对 3 个量感兴趣

  • 内存延迟 \(α\),内存传输的最小延迟。

  • 内存带宽 \(β\),我们可以从 HBM 传输到 SRAM 的速率,单位为字节/秒。

  • FLOP/s \(F\),或每秒浮点运算次数,处理器每秒可以执行的计算次数。

如果处理速度 FLOPs/s 是瓶颈,我们将程序称为计算密集型,如果带宽或延迟是瓶颈,则称为内存密集型。一般来说,我们的目标是优化内核,使其成为计算密集型,这意味着我们正在利用硬件的所有可用处理能力。

假设我们正在运行一个程序,每个内核迭代需要 \(X\) 字节的内存传输,并且每个迭代运行 \(Y\) 次浮点运算。 \(X\)\(Y\) 的比率取决于计算类型 – 对于元素级运算(如加法或乘法),它们都将同等缩放。但是,对于矩阵乘法等运算,计算量随问题规模立方增长,而内存量随问题规模平方增长。

计算密集型状态下,运行 \(N\) 次迭代的流水线将花费 \((\alpha + X/\beta) + N (Y/F)\) 秒,其中第一项表示初始气泡的成本(如果末尾也有气泡,则乘以系数 2),第二项表示流水线稳态的总时间。假设 N 很大并且有足够的工作来产生长流水线,则运行时中的主导项是 \(F\),即加速器的处理速度。

pipelining_compute

内存密集型状态下,识别问题是延迟还是带宽很有用。如果带宽是瓶颈,则总运行时间将花费 \(\alpha + X / \beta\) 秒。与延迟密集型状态相比,内存复制是串行发生的,因为带宽已经饱和。内存密集型通常不是理想的,因为在某些时间段内处理器将处于空闲状态,并且在大多数硬件配置中,内存带宽 \(\beta\) 比处理速度 \(F\) 慢几个数量级。

pipelining_bandwidth

如果瓶颈专门是延迟而不是带宽,则可以通过插入额外的流水线阶段来解决问题,但代价是需要额外的 SRAM 来存储更多缓冲区。在有足够的阶段的情况下,问题将再次变为计算密集型或延迟密集型,具体取决于我们在流水线稳态阶段首先遇到的瓶颈。然而,多级流水线的缺点是气泡的大小与阶段数成正比,因此重要的是要确保流水线足够长,以使气泡不会占用总运行时间的很大一部分。

pipelining_latency

TPU 上的 Pallas 仅支持双缓冲,因为 TPU 程序可以处理更大的块大小,并且双缓冲通常足以覆盖延迟。在 GPU 上,可以在 Triton 中(通过 TritonCompilerParams)和 Mosaic GPU 后端中(通过流水线发射器的参数)指定流水线阶段的数量。有关更多详细信息,请参阅特定于平台的流水线文档。