软件流水线#
软件流水线是一种重要的性能优化技术,通过重叠多个异步操作来提高效率,即使这些操作之间存在数据依赖。在编写内核的上下文中,流水线最常见的形式涉及将通信和内存传输与计算重叠,从而使硬件加速器在等待数据到达时不会停顿。因此,本教程将仅关注通信-计算流水线问题。我们将首先从概念上介绍该问题,概述用于编写流水线的 Pallas API,并介绍使用该 API 的一些实际示例。
本教程仅涵盖流水线的概念基础。有关特定平台的参考,请参阅 TPU 流水线 或 Mosaic GPU 流水线。
import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
import numpy as np
内存层次结构#
从概念上理解流水线的第一步是理解可用的不同内存形式及其权衡。大多数硬件架构(包括 CPU、GPU 和 TPU)使用各种内存空间,这些空间在容量与延迟/带宽之间进行权衡。对于 Pallas 而言,我们通常关心寄存器、SRAM、DRAM 以及可能的网络通信。
寄存器 是物理上最接近处理器的内存,通常在对值进行任何计算之前,必须将其直接加载到寄存器中。
SRAM(在 GPU 上也称为共享内存/L1 和 L2 缓存,或在 TPU 上称为 VMEM)也相当靠近处理器,但容量比寄存器大。现代机器学习加速器上的 SRAM 容量通常在 10-100MB 范围内(TPU v5p 包含 96MB VMEM,H100 GPU 包含约 30MB L1 缓存和 50MB L2 缓存)。可以合理地预期访问 SRAM 的延迟大约比访问寄存器长 10 倍。
DRAM(也称为 HBM)的容量远大于 SRAM,通常在现代机器学习加速器上为 10-100GB 范围。然而,与 SRAM 相比,访问延迟大约长 10 倍。
当单个设备的 DRAM 容量不足以处理大型工作负载,或者我们希望利用并行计算时,网络通信变得至关重要。本教程不介绍分布式流水线,但有关编写跨多个设备的流水线,请参阅 分布式 TPU 内核 指南。
为了对位于 HBM 中的值 X 和 Y 执行计算,我们需要:
将值 x 和 y 复制到 SRAM。
将值从 SRAM 加载到寄存器。
执行计算并将结果存储在寄存器中。
将输出寄存器中的值存储到 SRAM。
将 SRAM 中的输出值复制回 HBM。
让我们来实现一个执行此操作的 Pallas 函数!
# Note: This is a TPU example.
def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):
# Load x and y from SRAM into registers
x_regs = x_sram_ref[:, :]
y_regs = y_sram_ref[:, :]
# Execute a vectorized add
z_regs = x_regs + y_regs
# Store the output values in registers back into SRAM
z_sram_ref[:, :] = z_regs
def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
# pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.
# It will then copy `x` and `y` from HBM into SRAM.
z = pl.pallas_call(
add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
# pallas_call will also copy the output from SRAM back into HBM.
return z
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
我们编写了两个函数:add_matrices_kernel
和 add_matrices
。
add_matrices_kernel
使用存在于 SRAM 中的 Refs
。加载 SRAM Ref 会产生一个存在于寄存器中的值。寄存器中的值表现得像 jax.Arrays,我们可以对它们使用 jnp
和 jax.lax
操作来生成存在于寄存器中的新值。当我们生成想要返回的值时,我们将其存储在输出 SRAM Ref
中。
add_matrices
函数作用于 jax.Array
并返回一个 jax.Array
。在其中,我们将 x 和 y 传递给 pallas_call。 pallas_call
负责将 x 和 y 复制到 SRAM,并分配内核操作的 SRAM 缓冲区(包括分配 z_vmem_ref
,即输出 SRAM 缓冲区)。内核函数运行完成后,pallas_call
还会将 z_vmem_ref
中的值复制回 HBM,从而得到一个输出 jax.Array
。
Pallas 暴露了对 SRAM 等低级内存空间的访问,但编写高性能内核需要更仔细地利用各种内存空间。例如,我们需要同时考虑:
内存容量。 SRAM 很小!如果我们的数组太大,上述内核将无法工作,因为我们无法将输入装入 SRAM。作为参考,一个
f32[2048, 2048]
数组为 16MiB,因此上述内核只能处理中等大小的数组。内存带宽。与大多数计算指令相比,与 HBM 和 SRAM 之间的复制耗时很长。上面的
add_matrices
函数很可能花费更多时间在 HBM 和 SRAM 之间的复制上,而不是实际执行加法本身。
牢记这两个限制,我们将不得不重新考虑我们的策略,以从加速器中获得性能。
流水线基础#
我们如何利用内存层次结构中各种内存类型的优势,并能够在操作存储在 HBM 中的大数组的同时,仍然利用快速的 SRAM 进行计算?流水线是一种非常通用的编程模式,可以让我们做到这一点,但它需要将问题分解为可以并行重叠的更小子问题。
流水线的第一步是将我们的问题划分为可以容纳在 SRAM 中的更小的子问题。例如,一个元素级操作可以通过一次操作源数组的一个切片来轻松转换,从而产生以下 3 个步骤(也称为阶段):
copy_in:将切片
A[i]
从 HBM 复制到 SRAMX
。compute:将
X
加载到寄存器,计算结果,并存储在 SRAMY
中。copy_out:将结果
Y
复制回 HBMA[i]
。
请注意,步骤 1-3 之间存在数据依赖,我们无法轻易地重叠它们,因为我们需要先完成步骤 (1) 才能开始步骤 (2),以此类推。然而,跨子问题多次调用的之间没有数据依赖——也就是说,我们可以在执行块 A[i+1]
的步骤 (1) 的同时,执行块 A[i]
的步骤 (2) 和块 A[i-1]
的步骤 (3)。
上图描绘了理想化的流水线程序如何在时间上进行调度。关键的见解是,在内核的大部分时间里,复制操作与计算操作并行执行,这意味着我们可以理想地用计算“隐藏”HBM/SRAM 之间传输的成本,并尽可能地使处理器保持忙碌状态。
初始启动时间和最终结束时间称为“气泡”,此时只有一部分阶段被执行,同时流水线正在“填充”或“排空”。大部分时间花在流水线的“稳态”阶段,其中每个流水线阶段跨子问题的不同迭代并行执行。虽然对于更通用的流水线方法,目标是实现 N 路并行(其中 N 是阶段数),但对于内核流水线,我们通常受内存带宽或处理速度的瓶颈。因此,我们通过内核流水线的目标通常是实现处理器 FLOP/s 的完全利用,这意味着在任何时候都有一个 compute
块处于活动状态。在上图中,计算块在 8 个时间槽中有 6 个是活动的,假设我们在每个计算时间槽中都完全利用了处理器,那么我们将实现 75% 的处理器利用率。
推导双缓冲流水线#
现在让我们看看如何在伪代码中实现流水线。考虑以下元素级程序,我们使用 copy_in
指令从 HBM(A[i]
)加载值,将结果加 1,然后使用 copy_out
将结果写回 HBM。
for i in range(N): copy_in(A[i], X) Y = X + 1 copy_out(Y, A[i])
这种方法的问题在于 copy_in
和 copy_out
通常是阻塞操作。因此,我们被迫在 GPU/TPU 空闲时等待复制完成,然后执行计算,而内存则空闲。我们希望做的是异步“预取”循环下一次迭代所需的输入值,同时执行当前循环的计算,以便计算和内存通信同时发生。
为了能够理解我们将要进行的的代码转换,让我们将循环展开 N=4 次,并将复制指令分解为单独的 copy_start
和 copy_wait
操作,以便能够表达异步性。
# Itr 1 copy_in_start(A[0], X) copy_in_wait(X) Y = X + 1 copy_out_start(Y, A[0]) copy_out_wait(Y) # Itr 2 copy_in_start(A[1], X) copy_in_wait(X) Y = X + 1 copy_out_start(Y, A[1]) copy_out_wait(Y) # Itr 3 copy_in_start(A[2], X) copy_in_wait(X) Y = X + 1 copy_out_start(Y, A[2]) copy_out_wait(Y) # Itr 4 copy_in_start(A[3], X) copy_in_wait(X) Y = X + 1 copy_out_start(Y, A[3]) copy_out_wait(Y)
循环展开后,流水线转换仅涉及尽可能早地发出 copy_start
指令,并在需要该值之前尽可能晚地 copy_wait
值。然而,在当前循环状态下,X 存在一个虚假的数据依赖——我们不能同时对 X 进行异步复制并将其用于计算,否则可能会出现竞争条件。因此,我们可以使用多缓冲区技术,为每个输入 X 和每个输出 Y 保留 2 个缓冲区。有了 2 个缓冲区,我们可以将 copy_in_start
推迟一个迭代(有 3 个缓冲区可以推迟 2 个迭代,依此类推),并重写我们的循环如下:
# Prologue copy_in_start(A[0], X[0]) # Itr 1 copy_in_start(A[1], X[1]) copy_in_wait(X[0]) Y[0] = X[0] + 1 copy_out_start(Y[0], A[0]) copy_out_wait(Y[0]) # Itr 2 - Steady state copy_in_start(A[2], X[0]) copy_in_wait(X[1]) Y[1] = X[1] + 1 copy_out_start(Y[1], A[1]) copy_out_wait(Y[1]) # Itr 3 - Steady state copy_in_start(A[3], X[1]) copy_in_wait(X[0]) Y[0] = X[0] + 1 copy_out_start(Y[0], A[2]) copy_out_wait(Y[0]) # Itr 4 - No copy-in copy_in_wait(X[1]) Y[1] = X[1] + 1 copy_out_start(Y[1], A[3]) copy_out_wait(Y[1])
接下来,我们可以在下一个循环迭代中写 Y 之前,将 copy_out_wait
推迟到尽可能晚。
# Prologue copy_in_start(A[0], X[0]) # Itr 1 copy_in_start(A[1], X[1]) copy_in_wait(X[0]) Y[0] = X[0] + 1 copy_out_start(Y[0], A[0]) # Itr 2 - Steady state copy_in_start(A[2], X[0]) copy_in_wait(X[1]) Y[1] = X[1] + 1 copy_out_start(Y[1], A[1]) copy_out_wait(Y[0]) # Itr 3 - Steady state copy_in_start(A[3], X[1]) copy_in_wait(X[0]) Y[0] = X[0] + 1 copy_out_start(Y[0], A[2]) copy_out_wait(Y[1]) # Itr 4 - No copy-in copy_in_wait(X[1]) Y[1] = X[1] + 1 copy_out_start(Y[1], A[3]) copy_out_wait(Y[0]) # Epilogue copy_out_wait(Y[1])
最后,将循环重新折叠成一个 for 循环,我们得到以下流水线循环:
# Prologue
copy_in_start(A[0], X[0])
# Main loop
for i in range(N):
cur_slot = i % 2
next_slot = (i + 1) % 2
if i+1 < N:
copy_in_start(A[i+1], X[next_slot])
copy_in_wait(X[cur_slot])
Y[cur_slot] = X[cur_slot] + 1
copy_out_start(Y[cur_slot], A[i])
if i > 0:
copy_out_wait(Y[next_slot])
# Epilogue
copy_out_wait(Y[1])
如果我们想将此循环推广以处理更广泛的计算,请注意,我们基本上需要为流水线指定 3 条信息:
grid (网格),即 for 循环的边界,它指定了要计算的子问题的数量。在我们的例子中,我们有一个大小为
(N,)
的一维网格。kernel (内核),即一旦输入加载到 SRAM 中所发生的实际计算。在我们的例子中,我们执行了一个元素级加法
Y = X + 1
。data_slices (数据切片),它将子问题映射到 HBM 缓冲区中相应的切片。在我们的例子中,数据切片是恒等函数
lambda i: i
。
通过允许用户指定这些信息,我们可以按照此模式编写各种程序。
def double_buffered_pipeline(
grid: tuple[int, ...],
kernel: Callable,
in_slices: Callable,
out_slices: Callable):
# Prologue
copy_in_start(in_hbm[in_slices(0)], in_sram[0])
# Main loop
grid_size = prod(grid)
for i in range(grid_size):
cur_slot = i % 2
next_slot = (i + 1) % 2
if (i + 1) < grid_size:
copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot])
copy_in_wait(in_sram[cur_slot])
kernel(in_sram[cur_slot], out_ram[cur_slot])
copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])
if i > 0:
copy_out_wait(out_sram[next_slot])
# Epilogue
last_slot = (grid_size - 1) % 2
copy_out_wait(out_sram[last_slot])
既然我们已经看到了如何手动实现流水线循环,现在让我们看看如何使用 Pallas API。
Pallas 流水线 API#
Pallas 提供了一个流水线 API,它抽象了维护多个缓冲区和将异步通信与计算重叠的样板代码。此 API 的基础知识已在 Pallas Quickstart 中介绍,因此我们在此处简要回顾一下 API 以确保完整性,并讨论使用流水线时出现的一些注意事项。
Grid (网格)#
程序 **grid** 是一个整数元组,指定子问题的数量作为一个数组。流水线的结构可以被解释为嵌套的 for 循环,其中每个循环的边界。
# For grid (N, M, K)
for n in range (N):
for m in range(M):
for k in range(K):
kernel()
内核将被调用总共 prod(grid)
次。有关更多详细信息,请参阅 grid 和 blockspecs。
BlockSpecs (块规格)#
BlockSpec 指定在每个子问题中复制到内核的数据的大小和切片。pl.BlockSpec
的基本构造函数涉及指定 block_shape
(数据切片的大小)和 index_map
(一个函数,接收当前子程序的程序 ID 并输出源缓冲区中的块索引)。块索引指定每次迭代要复制的块,假设源缓冲区已按 block_shape
的形状分割成块。memory_space
参数指定将输入复制到的内存空间——默认情况下是 SRAM。
pl.BlockSpec(
block_shape: tuple[int, ...],
index_map: Callable,
memory_space: pl.MemorySpace
)
内核的每个输入和输出都应该有一个 BlockSpec。有关更多详细信息,请参阅 grid 和 blockspecs。
Kernel (内核)#
内核函数指定每个子问题执行的计算。内核函数不应返回任何输出,而所有输出都应写入传递到内核的输出缓冲区中。默认情况下,所有输入和输出缓冲区都是 SRAM 缓冲区(除非用户通过在相应的 BlockSpec
上指定 memory_space
来覆盖该行为)。
def kernel(*input_buffers, *output_buffers):
# ... perform compute
# ... store result into output buffers
当前子程序的索引可以在内核中使用 pl.program_id(grid_axis: int)
进行查询。
Pallas Call#
函数 pl.pallas_call
是 Pallas 的主要入口点,当提供了 grid 和 BlockSpecs 时,它会执行流水线执行。它的签名如下:
def pallas_call(
kernel,
grid: tuple[int, ...],
in_specs: Sequence[PyTree[BlockSpec]],
out_specs: PyTree[BlockSpec],
out_shape: PyTree[jax.ShapeDtypeStruct],
) -> Callable:
pallas_call
将返回一个可调用函数,当使用输入值调用时,它将返回与 out_shape
相同的形状的输出。
in_specs
、out_specs
和 out_shape
是各自元素类型的 PyTrees。传递给内核的 in_specs
和输入缓冲区的 PyTrees 应该匹配,out_specs
和 out_shape
的 PyTrees 也应该匹配。
示例 - 元素级内核的重访#
让我们重访教程开头 add_matrices_kernel
,但使用流水线。我们将添加两个形状为 f32[4096, 4096]
且存在于 HBM 中的输入数组。作为子问题,我们将输入分割成 block_shape=(512, 512)
的块,并在内核中一次只将两个块相加。由于加法是元素级的,每个 index_map
都相同,并在 i, j
次迭代中选择 i, j
块。
# Note: This is a TPU example.
total_shape = (4096, 4096)
block_shape = (512, 512)
def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):
o_ref[...] = x_ref[...] + y_ref[...]
def add_matrices_pipelined(x: jax.Array, y: jax.Array):
return pl.pallas_call(
add_matrices_pipelined_kernel,
grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),
in_specs=[
pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),
pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))
],
out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),
out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),
)(x, y)
x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)
y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)
result = add_matrices_pipelined(x, y)
np.testing.assert_array_equal(
result, x + y
)
事实证明,使用此 API,编写流水线内核的代码行数并不比编写我们原始的简单加法内核多多少!
参数化内核#
在我们的内核中参数化块形状是很常见的。块大小可能是优化 Pallas 内核性能最重要的参数!它们让我们能够控制流水线(例如,选择更小的块会增加流水线循环的迭代次数,其中每次迭代的工作量更少)。让我们编写一个执行此操作的函数:
def add_matrices_pipelined_param(
x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
m, n = x.shape
block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(m // bm, n // bn),
)(x, y)
np.testing.assert_array_equal(
add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y
)
性能分析#
流水线内核的性能如何?这个问题可能会因硬件瓶颈所在而异。我们通常关心 3 个量:
内存延迟 \(α\),内存传输的最小延迟。
内存带宽 \(β\),从 HBM 到 SRAM 的传输速率(字节/秒)。
FLOP/s \(F\),即每秒浮点运算次数,处理器每秒可以执行的计算次数。
我们将程序称为计算密集型(compute-bound),如果瓶颈是处理速度 FLOPs/s;称为内存密集型(memory-bound),如果瓶颈是带宽或延迟。通常,我们的目标是优化内核,使其成为计算密集型,这意味着我们正在利用硬件的所有可用处理能力。
假设我们正在运行一个程序,该程序每个内核迭代需要 \(X\) 字节的内存传输,并且每个迭代运行 \(Y\) 次浮点运算。 \(X\) 与 \(Y\) 的比率取决于计算的类型——对于加法或乘法等元素级操作,两者都以相同的比例缩放。但是,对于矩阵乘法等操作,计算随问题大小呈立方增长,而内存随问题大小呈平方增长。
在计算密集型模式下,运行 \(N\) 次迭代的流水线将需要 \((\alpha + X/\beta) + N (Y/F)\) 秒,其中第一项表示初始气泡的成本(如果末尾也有气泡,则乘以 2),第二项表示流水线稳态的总时间。假设 N 很大且有足够的工作来产生长流水线,运行时最主要的项是 \(F\),即加速器的处理速度。
在内存密集型模式下,区分问题是延迟还是带宽很有用。如果带宽是瓶颈,那么总运行时间将为 \(\alpha + N(X / \beta)\) 秒。与延迟密集型模式不同,内存复制是串行的,因为带宽已经饱和。内存密集型通常不是理想的,因为会有处理器空闲的时间间隔,并且在大多数硬件配置中,内存带宽 \(\beta\) 比处理速度 \(F\) 慢几个数量级。
如果瓶颈是延迟而不是带宽,可以通过插入额外的流水线阶段来解决问题,但需要额外的 SRAM 来存储更多缓冲区。通过足够的阶段,问题将再次成为计算密集型或带宽密集型,具体取决于我们在流水线稳态阶段首先遇到的瓶颈。然而,多阶段流水线的缺点是气泡的大小与阶段数成正比,因此重要的是要确保流水线足够长,以免气泡占据总运行时间的很大一部分。
TPU 上的 Pallas 仅支持双缓冲,因为 TPU 程序可以处理更大的块大小,而双缓冲通常足以覆盖延迟。在 GPU 上,流水线阶段的数量可以在 Triton(通过 CompilerParams
)和 Mosaic GPU 后端(通过流水线发射器的参数)中指定。有关更多详细信息,请参阅特定平台的流水线文档。