标量预取与块稀疏计算#
在本教程中,我们将介绍 Pallas 中块稀疏计算的基础知识。稀疏计算是编写 Pallas 自定义内核而非仅使用 JAX/XLA 的主要原因,因为在 XLA 中很难表达执行动态计算量的程序,这是由于静态数组形状的限制。在本教程中,我们将学习如何使用 Pallas 的标量预取功能来编写可以动态跳过计算和内存块的块稀疏内核。
import functools
import timeit
import numpy as np
import jax
from jax import numpy as jnp
from jax import lax
from jax.experimental import checkify
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print("Running on", jax.devices()[0].device_kind)
Running on TPU v5 lite
带标量预取的动态块索引#
我们将利用 Pallas 的“标量预取”功能来实现稀疏内核。标量预取允许您将少量数据传入 SMEM(“标量内存”),这些数据在流水线启动前加载(“预取”)。由于这些数据在流水线启动前加载,因此可用于每个 BlockSpec 的 index_map 中,从而允许您执行数据依赖的索引计算。本教程的主要目标是介绍利用此功能的常见编程模式。
要使用标量预取,请使用 pltpu.PrefetchScalarGridSpec 替换标准的 pl.GridSpec。
class PrefetchScalarGridSpec:
def __init__(self,
num_scalar_prefetch: int,
grid: tuple[int, ...],
in_specs: PyTree[BlockSpec],
out_specs: PyTree[BlockSpec],
scratch_shapes: tuple[MemorySpace, ...]):
...
num_scalar_prefetch 参数指示标量预取值的数量。当设置为非零值时,它会更改内核和索引映射的调用签名,以期望额外的预取值。传入 index_map 和内核的 Ref 都会在 SMEM 中分配,并且不会被分成块,因为它们没有定义 BlockSpec。此外,index_map 和内核的参数顺序始终固定,如下所述。
每个
BlockSpec的index_map现在期望预取Ref出现在网格索引之后。
def index_map(*grid_indices, *prefetch_refs):
...
用户定义的内核期望预取
Ref出现在输入Ref之前。此外,暂存Ref出现在输出Ref之后。
def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
...
使用
pallas_call调用新内核时,pallas_call返回的函数也期望标量预取参数出现在输入之前,例如:
kernel = pl.pallas_call(...)
result = kernel(*prefetch_args, *input_args)
示例:带标量预取的块动态切片#
让我们从一个演示如何使用标量预取功能的简单示例开始。我们将实现一个块对齐的动态切片内核,该内核根据用户指定的索引简单地从较大的数组中提取一个块。
在内核外部,我们计算要提取的块索引为:
block_idx = (start[0] // size[0], start[1] // size[1])我们将
block_idx作为标量预取参数传入pallas_call。在我们的索引映射中,我们使用块索引通过返回
(block_idx[0], block_idx[1])来选择相应的块。
当然,此内核的限制在于我们的切片大小必须适合内核块(由 VMEM 大小限制),并且我们只能从大小对齐的索引开始。更高级的内核会将内核块大小与切片大小解耦,并允许非对齐的起始索引。
def dynamic_slice_kernel(indices, x_ref, o_ref):
del indices
o_ref[...] = x_ref[...]
@checkify.checkify
@functools.partial(jax.jit, static_argnums=(2,))
def block_dynamic_slice(x, starts, sizes):
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
grid=(1, 1),
in_specs=[pl.BlockSpec(
sizes,
lambda i, j, block_idx: (block_idx[0], block_idx[1]))],
out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),
)
kernel = pl.pallas_call(
dynamic_slice_kernel,
grid_spec=grid_spec,
out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),
)
# Checkify inserts a runtime assert that starts are divisible by block size.
checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.")
checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.")
block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])
return kernel(block_idx, x)
shape = (512, 512)
x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)
err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))
err.throw()
ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))
diff = jnp.max(jnp.abs(result - ref))
print("Error |result - lax.dynamic_slice| =", diff)
Error |result - lax.dynamic_slice| = 0
稀疏内核:表示稀疏数据#
在深入实现稀疏内核之前,让我们先回顾一下稀疏矩阵的表示方法。虽然有几种流行的存储稀疏矩阵的格式,但我们将遵循坐标列表格式(COO)的块状变体,其中我们将矩阵存储为 (block_index, block_data) 对的列表。列表中未明确存储的所有块都假定为零,这意味着如果我们矩阵中有许多零块,我们可以节省大量内存。
下图演示了如何将一个 4x4 的稠密矩阵(左)转换为块大小为 2x2 的块-COO 格式(右)。请注意,在稀疏格式中,我们可以避免显式存储由所有零元素组成的右上块。
我们将使用以下辅助函数来采样一个块稀疏矩阵。它返回一个用于检查我们结果的稠密矩阵,以及每个轴的块数据和索引列表。
def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):
"""Returns a sampled matrix and its block-sparse representation.
Args:
key: RNG Key.
M: Major array dimension.
N: Minor array dimension.
blk_M: Block size along M dimension.
blk_N: Block size along N dimension.
p: Probability that a block will be non-zero.
dtype: dtype of the sampled matrix.
Returns:
dense_mat: A (M, N) dense sampled array.
block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing
the non-zero blocks of the matrix.
indices_i: A (num_blocks,) array of block indices for the first axis.
indices_j: A (num_blocks,) array of block indices for the second axis.
"""
mask_key, blocks_key = jax.random.split(key)
num_blocks = (M // blk_M, N // blk_N)
# We first sample a block mask, denoting which blocks are nonzero.
block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
num_blocks = jnp.sum(block_mask)
indices = jnp.where(block_mask)
# For each non-zero block, we sample a block of random values.
block_data = jax.random.uniform(blocks_key,
shape=(num_blocks, blk_M, blk_N),
dtype=dtype)
# For checking purposes, create the dense version of the sparse matrix.
dense_mat = jnp.zeros((M, N), dtype=dtype)
for blk in range(num_blocks):
idx_i = indices[0][blk]
idx_j = indices[1][blk]
slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
return dense_mat, block_data, indices[0], indices[1]
示例:稀疏 @ 稠密矩阵乘法#
在我们的第一个示例中,我们将稀疏 LHS 矩阵与稠密 RHS 矩阵相乘以生成稠密输出。
我们将内核网格结构化为 2 个循环 - 外循环遍历 RHS/输出的列,内循环遍历 LHS 的稀疏块。在每次内循环迭代中,我们从 LHS 加载一个块,并通过收缩维度(K)的块索引在 RHS 中查找相应的块。我们将两个块相乘并累加到正确的输出块中。一次外循环迭代将计算整个列的结果,如下图所示。
重要的是,我们将块索引按行分组(例如 [0, 0, 1, 2, 3, 3]),然后再将它们传入内核,原因有两个。首先,在我们的内核中,我们需要知道何时在输出 Ref 中初始将累加器清零,如果行索引是分组的,则很容易做到。其次,Pallas 的流水线逻辑不允许我们在非连续迭代上重新访问输出 Ref 中的块,因此我们需要在连续的内核迭代中完成所有累加到输出块的操作。这是因为流水线发射器会意识到我们在连续迭代中加载相同的输出块,并将其保存在 VMEM 中。当我们更改输出块时,Pallas 将最终将输出存储到 HBM 中,并假定我们不再使用它。未能连续访问输出块将导致错误的值,即使内核在逻辑上是正确的。
M = N = K = 16384
blk_M = blk_N = blk_K = 512
def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
x_ref, y_ref, _, o_ref, # Kernel inputs.
accum_scratch,
):
"""A DSD (Dense = Sparse @ Dense) matmul kernel."""
del idxs_k_ref
blk_idx = pl.program_id(1)
is_start = blk_idx == 0
changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
@pl.when(is_start | changed_blocks)
def _():
accum_scratch[...] = jnp.zeros_like(accum_scratch)
accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)
next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])
is_end = blk_idx == (num_blocks - 1)
@pl.when(is_end | next_block_change)
def _():
o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
del j, blk_idxs_i, blk_idxs_k
return (blk_idx, 0, 0)
def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
del blk_idxs_i
return (blk_idxs_k[blk_idx], j)
def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
del blk_idxs_k
return (blk_idxs_i[blk_idx], j)
(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(
jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)
num_blocks = X_blocks.shape[0]
Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)
out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=2,
# Note that while num_blocks is static here, Pallas does support
# dynamic grid sizes.
grid=(N // blk_N, num_blocks),
in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
pl.BlockSpec((blk_K, blk_N), y_map),
# Placeholder for a zeros-array used by input_output_aliases.
pl.BlockSpec((blk_M, blk_N), o_map),
],
out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
dsd_kernel,
grid_spec=grid_spec,
out_shape=out_shape,
# We use input-output aliases to zero-out o_ref for blocks that we never
# visit. By passing in an array of zeros we avoid having o_ref start with
# uninitialized values.
input_output_aliases={4: 0}, # Map zeros to o_ref.
)
args = (indices_i, indices_k, X_blocks, Y, zeros)
result = kernel(*args)
ref = X_dense @ Y
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 0
我们可以快速进行基准测试,将我们的稀疏内核与 JAX 中的稠密矩阵乘法进行性能比较。在 TPU v5e 芯片上,与理论上的稀疏因子带来的 10 倍速度提升相比,此内核实现了约 6 倍的速度提升。
这里有一些提高性能的主要技巧,主要集中在减少 HBM/VMEM 之间的通信开销上。
使用
dtype=jnp.bfloat16对于性能至关重要,因为它将内存带宽减半。使用更大的块大小也有帮助,因为矩阵乘法是一个
\(O(N^3)\)的计算和\(O(N^2)\)的内存操作。随着\(N\)的增大,内核将变得计算受限。然而,在实践中,这有一个反驳观点,即较小的块大小也能使数据更稀疏,因此这是一个应仔细选择的参数。
# Benchmark Sparse Pallas kernel vs reference JAX implementation
def benchmark(f, ntrials: int = 100):
def run(*args, **kwargs):
# Compile function first
jax.block_until_ready(f(*args, **kwargs))
# Time function
result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
number=ntrials)
time = result / ntrials
return time
return run
n_trials = 100
pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
ref_impl = jax.jit(lambda x, y: x @ y)
time = benchmark(ref_impl, n_trials)(X_dense, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 8.136 ms (avg over 100 trials)
Reference: 46.953 ms (avg over 100 trials)
稠密数据上的稀疏访问模式#
在我们之前的示例中,我们考虑了数据本身稀疏的情况。这在内核结构中体现为内核网格的一个维度,该维度是动态的,并遍历非零块的数量(num_blocks)。
当底层数据是稠密的,但我们希望对其进行稀疏计算时,会出现第二种有用的编程模式。在这种情况下,我们的内核网格将是稠密的,但我们希望跳过网格中的一些块,这由块稀疏掩码指示。这种编程模式通常在使用许多机器学习应用中的掩码时出现,例如自注意力中的因果掩码或局部掩码。在这些情况下,我们可以完全跳过掩码为零的块中的计算。这种编程模式的示例可以在 jax/experimental/pallas/ops/tpu 中的 Splash Attention 和 Grouped Matrix Multiplication 内核中找到,或者在 PyTorch 的 FlexAttention 中找到。
处理稠密数据上的稀疏访问模式的主要性能考虑是与流水线的交互。在任何给定的内核迭代中,Pallas 流水线发射器将尝试通过调用每个 BlockSpec 在网格的下一个迭代中的 index_map 来预取下一块数据。但是,如果我们的计算是稀疏的,我们可能会跳过下一个块的计算,因此我们需要一种方法来告诉流水线,开始获取 *我们没有跳过的下一个块*。为了做到这一点,我们需要构建 *预取映射*,其中包含每个内核输入的下一个非跳过块数据的索引。下图说明了如何为以 COO 格式存储的块稀疏掩码构建预取映射。
左:稀疏访问模式,其中蓝色表示我们需要计算的非零掩码块。右:预取映射,其中数组的每个元素包含下一个非零块数据的索引。
一旦构建了预取映射,我们就可以将该映射作为标量预取参数传入,并在 BlockSpec 的 index_map 函数中查询它。
def mask_index_map(prefetch_map, i, j, ...):
next_nonzero_block = prefetch_map[i, j]
return (next_nonzero_block, 0, 0)
我们可以为内核的其他输入构建类似的索引映射。对于稠密输入,您很可能需要构建指向网格中下一个非零块索引的预取映射。我们的下一个示例将提供一个使用这些预取映射的示例。
示例:带块稀疏输出掩码的稠密 @ 稠密矩阵乘法#
在我们的下一个示例中,我们将涵盖稠密矩阵乘法与稀疏输出掩码的融合,使用预取映射来提高流水线性能。我们将使用掩码选择性地跳过计算零输出块,从而节省计算成本。
由于我们将使用稀疏掩码,我们将首先实现一个函数,该函数将以稠密格式存储的 N x M 掩码转换为块稀疏格式。我们还需要计算预取映射,以帮助流水线发射器知道下一个要获取的块。总而言之,我们的 sparsify_mask 函数计算:
一个形状为
(num_N_blocks, num_M_blocks)的block_mask,指示一个块是全零(值为0)还是包含非零元素(值为1)。如果block_mask的值为 0,则我们可以跳过在内核中计算该块。一个形状为
(num_N_blocks, num_M_blocks)的prefetch_mask数组,包含下一个非零块的mask_data的索引。一个形状为
(num_N_blocks, num_M_blocks)的prefetch_i数组,包含下一个非掩码i索引的掩码。一个形状为
(num_N_blocks, num_M_blocks)的prefetch_j数组,包含下一个非掩码j索引的掩码。一个形状为
(num_blocks, blk_N, blk_M)的mask_data数组,其中包含掩码的非零块的数据。
def sparsify_mask(mask: jax.Array,
block_shape: tuple[int, int]):
"""Preprocesses a mask into a sparse representation.
Args:
mask: A boolean array of shape [M, N]
block_shape: The size of a single block.
Returns:
block_mask: A block_shape array of booleans indicating whether a block
is all-zeros (0) or contains non-zero elements (1).
prefetch_mask: A block_shape array of integers indicating the index of the
next non-zero block.
mask_data: A (num_blocks, block_shape) array containing
the data for non-zero blocks of the mask.
"""
M, N = mask.shape
bm, bn = block_shape
block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
mask_types_finder = []
mask_data = []
next_mask_type_idx = 0
prefetch_mask = jnp.zeros_like(block_mask)
next_i = (M // bm) - 1
next_j = (N // bn) - 1
prefetch_i = jnp.zeros_like(block_mask)
prefetch_j = jnp.zeros_like(block_mask)
for i in range(M // bm, -1, -1):
for j in range(N // bn, -1, -1):
mask_block = mask[i * bm :(i + 1) * bm,
j * bn :(j + 1) * bn]
is_nonzero = jnp.any(mask_block)
if is_nonzero:
try:
type_index = mask_types_finder.index(str(mask_block))
except ValueError:
type_index = len(mask_types_finder)
mask_types_finder.append(str(mask_block))
mask_data.append(mask_block)
next_mask_type_idx = type_index
next_i = i
next_j = j
else:
type_index = -1
block_mask = block_mask.at[i, j].set(is_nonzero)
prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
prefetch_i = prefetch_i.at[i, j].set(next_i)
prefetch_j = prefetch_j.at[i, j].set(next_j)
return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)
在内核结构方面,我们使用与前面教程中介绍的标准矩阵乘法内核相同的网格模式,具有 N、M 和 K 维度的 3 个循环。在内核本身内部,我们首先检查 block_mask 以查看当前输出块的掩码是否为全零。如果掩码为全零,我们可以跳过计算并转到下一个块;否则,我们需要计算矩阵乘法,然后对结果进行掩码。
M = N = K = 16384
blk_M = blk_N = 512
blk_K = 1024
def sparse_mask_matmul(
block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.
x_ref, y_ref, mask_ref, o_ref, # Kernel inputs.
accum_scratch
):
del prefetch_mask, prefetch_i, prefetch_j
i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)
should_compute = block_mask_ref[i, j] != 0
@pl.when(k == 0)
def _():
o_ref[...] = jnp.zeros_like(o_ref)
accum_scratch[...] = jnp.zeros_like(accum_scratch[...])
# We only compute the output for blocks with non-zero masks.
# Otherwise we skip the computation entirely.
@pl.when(should_compute)
def _():
result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
accum_scratch[...] += result
@pl.when(k == pl.num_programs(2) - 1)
def _():
o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)
X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)
Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
mask = jnp.ones((M, N), dtype=jnp.int32)
mask = jnp.tril(mask)
block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))
def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
del prefetch_mask, prefetch_j
# Zero-out the k index if the mask is zero, to avoid constantly fetching
# new blocks in the inner loop for blocks we are skipping.
k_fetch = (block_mask[i, j] != 0) * k
return (prefetch_i[i, j], k_fetch)
def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
del prefetch_mask, prefetch_i
k_fetch = (block_mask[i, j] != 0) * k
return (k_fetch, prefetch_j[i, j])
def mask_map(i, j, k, block_mask, prefetch_mask, *_):
del k, block_mask
return (prefetch_mask[i, j], 0, 0)
def o_map(i, j, k, *_):
del k
return (i, j)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=4,
grid=(M // blk_M, N // blk_N, K // blk_K),
in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),
pl.BlockSpec((blk_K, blk_N), y_map),
pl.BlockSpec((1, blk_M, blk_N), mask_map)],
out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
sparse_mask_matmul,
grid_spec=grid_spec,
out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
)
args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
result = kernel(*args)
ref = mask * (X @ Y)
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 1.0252e-05
现在,让我们将性能与朴素的稠密实现进行比较。在 TPU v5e 上,稀疏内核的性能比稠密实现提高了约 1.8 倍,而使用下三角掩码并仅访问一半可能输出的最佳理论情况是 2 倍。
我们通常期望随着输入增大,性能会更接近理论峰值,因为我们未能完全达到理论性能的几个主要原因是:
我们跳过的计算量略少于一半,因为对角线上的块是混合 0 和 1 的,对于混合块,我们需要计算整个块。随着输入增大,混合块的开销相对于总计算量变得更小。
随着输入增大,流水线气泡占总运行时间的百分比也减小。
n_trials = 100
pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))
time = benchmark(ref_impl, n_trials)(mask, X, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 28.648 ms (avg over 100 trials)
Reference: 49.988 ms (avg over 100 trials)