标量预取和块稀疏计算#
在本教程中,我们将介绍 Pallas 中块稀疏计算的基础知识。稀疏计算是编写自定义 Pallas 内核而不是简单使用 JAX/XLA 的主要原因,因为由于静态数组形状,通常很难在 XLA 中表达执行动态计算量的程序。在本教程中,我们将学习如何使用 Pallas 的标量预取功能来编写可以动态跳过计算和内存块的块稀疏内核。
import functools
import timeit
import numpy as np
import jax
from jax import numpy as jnp
from jax import lax
from jax.experimental import checkify
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print("Running on", jax.devices()[0].device_kind)
Running on TPU v5 lite
使用标量预取的动态块索引#
我们将利用 Pallas 的“标量预取”功能来实现稀疏内核的编写。标量预取允许您将少量数据传递到 SMEM(“标量内存”),这些数据在管道开始之前加载(“预取”)。因为此数据在管道之前加载,所以它可以在每个 BlockSpec 的 index_map
中使用,允许您执行数据相关的索引计算。本教程的主要目标是介绍利用此功能的常见编程模式。
要使用标量预取,请使用 pltpu.PrefetchScalarGridSpec
代替标准 pl.GridSpec
class PrefetchScalarGridSpec:
def __init__(self,
num_scalar_prefetch: int,
grid: tuple[int, ...],
in_specs: PyTree[BlockSpec],
out_specs: PyTree[BlockSpec],
scratch_shapes: tuple[MemorySpace, ...]):
...
num_scalar_prefetch
参数表示标量预取值的数量。当此值设置为非零值时,它会更改内核和索引映射的调用签名,以期望额外的预取值。传递给 index_map
和内核的预取 Ref
都分配在 SMEM 中,并且由于没有定义 BlockSpec 而不会分区为块。此外,index_map
和内核的参数顺序始终是固定的,如下所述
每个
BlockSpec
的index_map
现在期望预取Ref
出现在网格索引之后
def index_map(*grid_indices, *prefetch_refs):
...
用户定义的内核期望预取
Ref
出现在输入Ref
之前。此外,暂存 ref 出现在输出Ref
之后。
def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
...
当使用
pallas_call
调用新内核时,pallas_call
返回的函数也期望标量预取参数出现在输入之前,例如
kernel = pl.pallas_call(...)
result = kernel(*prefetch_args, *input_args)
示例:使用标量预取的块动态切片#
让我们从一个基本示例开始,该示例演示如何使用标量预取功能。我们将实现一个块对齐的动态切片内核,该内核仅根据用户指定的索引从较大的数组中提取一个块
在内核之外,我们将要提取的块索引计算为:
block_idx = (start[0] // size[0], start[1] // size[1])
我们将
block_idx
作为标量预取参数传递到pallas_call
中。在我们的索引映射中,我们使用块索引通过返回
(block_idx[0], block_idx[1])
来选择相应的块。
当然,此内核的局限性在于我们的切片大小必须适合内核块(受 VMEM 大小的限制),并且我们只能从大小对齐的索引开始。更高级的内核会将内核块大小与切片大小解耦,并允许非对齐的起始索引。
def dynamic_slice_kernel(indices, x_ref, o_ref):
del indices
o_ref[...] = x_ref[...]
@checkify.checkify
@functools.partial(jax.jit, static_argnums=(2,))
def block_dynamic_slice(x, starts, sizes):
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
grid=(1, 1),
in_specs=[pl.BlockSpec(
sizes,
lambda i, j, block_idx: (block_idx[0], block_idx[1]))],
out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),
)
kernel = pl.pallas_call(
dynamic_slice_kernel,
grid_spec=grid_spec,
out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),
)
# Checkify inserts a runtime assert that starts are divisible by block size.
checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.")
checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.")
block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])
return kernel(block_idx, x)
shape = (512, 512)
x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)
err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))
err.throw()
ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))
diff = jnp.max(jnp.abs(result - ref))
print("Error |result - lax.dynamic_slice| =", diff)
Error |result - lax.dynamic_slice| = 0
稀疏内核:表示稀疏数据#
在我们深入研究实现稀疏内核之前,让我们首先回顾一下稀疏矩阵的表示方式。虽然有几种流行的存储稀疏矩阵的格式,但我们将遵循坐标列表格式(COO)的块变体,其中我们将矩阵存储为 (block_index, block_data)
对的列表。列表中未显式存储的所有块均假定为零,这意味着如果矩阵中有许多零块,我们可以节省大量内存。
下图演示了如何将 4x4 稠密矩阵(左)转换为块 COO 格式(右),块大小为 2x2。请注意,在稀疏格式中,我们可以避免显式存储由所有零元素组成的右上块。
我们将使用以下辅助函数来采样块稀疏矩阵。它返回用于检查结果的稠密矩阵,以及每个轴的块数据和索引列表。
def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):
"""Returns a sampled matrix and its block-sparse representation.
Args:
key: RNG Key.
M: Major array dimension.
N: Minor array dimension.
blk_M: Block size along M dimension.
blk_N: Block size along N dimension.
p: Probability that a block will be non-zero.
dtype: dtype of the sampled matrix.
Returns:
dense_mat: A (M, N) dense sampled array.
block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing
the non-zero blocks of the matrix.
indices_i: A (num_blocks,) array of block indices for the first axis.
indices_j: A (num_blocks,) array of block indices for the second axis.
"""
mask_key, blocks_key = jax.random.split(key)
num_blocks = (M // blk_M, N // blk_N)
# We first sample a block mask, denoting which blocks are nonzero.
block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
num_blocks = jnp.sum(block_mask)
indices = jnp.where(block_mask)
# For each non-zero block, we sample a block of random values.
block_data = jax.random.uniform(blocks_key,
shape=(num_blocks, blk_M, blk_N),
dtype=dtype)
# For checking purposes, create the dense version of the sparse matrix.
dense_mat = jnp.zeros((M, N), dtype=dtype)
for blk in range(num_blocks):
idx_i = indices[0][blk]
idx_j = indices[1][blk]
slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
return dense_mat, block_data, indices[0], indices[1]
示例:稀疏 @ 稠密矩阵乘法#
在我们的第一个示例中,我们将一个稀疏的 LHS 矩阵与一个稠密的 RHS 矩阵相乘,以生成一个稠密的输出。
我们将使用 2 个循环来构造我们的内核网格 - 外循环遍历 RHS/输出的列,内循环遍历 LHS 的稀疏块。在每次内循环迭代期间,我们从 LHS 加载一个块,并使用收缩维度 (K) 的块索引在 RHS 中查找相应的块。我们将两个块相乘,并累积到正确的输出块中。一个外循环迭代将计算整列的结果,如下图所示
重要的是,我们按行对块索引进行分组(例如 [0, 0, 1, 2, 3, 3]
),然后再将它们传递到内核中,原因有两个。首先,在我们的内核中,我们需要知道何时在输出 ref 中初始化累加器为零,如果行索引分组,则很容易做到这一点。其次,Pallas 的流水线逻辑不允许我们在非连续迭代中重新访问输出 Ref
中的块,因此我们需要在连续的内核迭代中完成对输出块的所有累积。这是因为流水线发射器将意识到我们在连续迭代中加载相同的输出块,并将该块保留在 VMEM 中。当我们更改输出块时,Pallas 最终会将输出存储到 HBM 中,并假设我们永远不会再次接触它。即使内核在逻辑上是正确的,不连续地访问输出块也会导致不正确的值。
M = N = K = 16384
blk_M = blk_N = blk_K = 512
def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
x_ref, y_ref, _, o_ref, # Kernel inputs.
accum_scratch,
):
"""A DSD (Dense = Sparse @ Dense) matmul kernel."""
del idxs_k_ref
blk_idx = pl.program_id(0)
is_start = blk_idx == 0
changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
@pl.when(is_start | changed_blocks)
def _():
accum_scratch[...] = jnp.zeros_like(accum_scratch)
accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)
next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])
is_end = blk_idx == (num_blocks - 1)
@pl.when(is_end | next_block_change)
def _():
o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
del j, blk_idxs_i, blk_idxs_k
return (blk_idx, 0, 0)
def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
del blk_idxs_i
return (blk_idxs_k[blk_idx], j)
def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
del blk_idxs_k
return (blk_idxs_i[blk_idx], j)
(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(
jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)
num_blocks = X_blocks.shape[0]
Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)
out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=2,
# Note that while num_blocks is static here, Pallas does support
# dynamic grid sizes.
grid=(num_blocks, N // blk_N),
in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
pl.BlockSpec((blk_K, blk_N), y_map),
# Placeholder for a zeros-array used by input_output_aliases.
pl.BlockSpec((blk_M, blk_N), o_map),
],
out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
dsd_kernel,
grid_spec=grid_spec,
out_shape=out_shape,
# We use input-output aliases to zero-out o_ref for blocks that we never
# visit. By passing in an array of zeros we avoid having o_ref start with
# uninitialized values.
input_output_aliases={4: 0}, # Map zeros to o_ref.
)
args = (indices_i, indices_k, X_blocks, Y, zeros)
result = kernel(*args)
ref = X_dense @ Y
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 0
我们可以做一个快速基准测试,以比较我们的稀疏内核与 JAX 中的稠密 matmul 的性能。在 TPU v5e 芯片上,与理论上由稀疏因子产生的 10 倍速度提升相比,此内核的速度提升约为 6 倍。
这里有一些主要的性能提示,主要围绕减少 HBM/VMEM 之间的通信开销
使用
dtype=jnp.bfloat16
对于性能至关重要,因为它将内存带宽减少了一半。使用更大的块大小也有帮助,因为矩阵乘法是 \(O(N^3)\) 计算和 \(O(N^2)\) 内存操作。随着 \(N\) 的增大,内核变为计算密集型。但是,在实践中,对此的反驳是较小的块大小也使数据更稀疏,因此这是一个应仔细选择的参数。
# Benchmark Sparse Pallas kernel vs reference JAX implementation
def benchmark(f, ntrials: int = 100):
def run(*args, **kwargs):
# Compile function first
jax.block_until_ready(f(*args, **kwargs))
# Time function
result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
number=ntrials)
time = result / ntrials
return time
return run
n_trials = 100
pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
ref_impl = jax.jit(lambda x, y: x @ y)
time = benchmark(ref_impl, n_trials)(X_dense, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 8.136 ms (avg over 100 trials)
Reference: 46.953 ms (avg over 100 trials)
稠密数据上的稀疏访问模式#
在我们之前的示例中,我们考虑了数据本身是稀疏的情况。这在内核结构中表现为内核网格中的一个维度,它是动态的,并且循环遍历非零块的数量(num_blocks
)。
当基础数据是稠密的,但我们希望对其执行稀疏计算时,会出现第二种有用的编程模式。在这种情况下,我们的内核网格将是稠密的,但我们希望根据块稀疏掩码跳过网格中的某些块。当在许多机器学习应用程序中使用掩码时,例如自注意力中的因果掩码或局部掩码时,通常会出现这种编程模式。在这些情况下,我们可以完全跳过掩码归零的块中的计算。可以在 jax/experimental/pallas/ops/tpu
中的 Splash Attention 和分组矩阵乘法内核中找到这种编程模式的示例,或者在 PyTorch 的 FlexAttention 中找到。
在处理稠密数据上的稀疏访问模式时,主要的性能考虑因素是与流水线的交互。在任何给定的内核迭代中,Pallas 流水线发射器将尝试通过在网格的下一次迭代中为每个 BlockSpec
调用 index_map
来预取下一个数据块。但是,如果我们的计算是稀疏的,我们可能会跳过网格中下一个块的计算,因此我们需要某种方法来告诉流水线开始获取我们不跳过的下一个块。为了做到这一点,我们需要构造预取映射,其中包含每个内核输入的下一个非跳过数据块的索引。下图说明了如何为以类似 COO 的格式存储的块稀疏掩码构建预取映射。
左图:稀疏访问模式,其中蓝色表示我们需要计算的具有非零掩码的块。右图:预取映射,其中数组的每个元素都包含下一个非零块数据的索引。
构造预取映射后,我们可以将该映射作为标量预取参数传递,并在 BlockSpec 的 index_map
函数中查询它。
def mask_index_map(prefetch_map, i, j, ...):
next_nonzero_block = prefetch_map[i, j]
return (next_nonzero_block, 0, 0)
我们可以为内核的其他输入构造类似的索引映射。对于稠密输入,您很可能需要构造预取映射,这些映射指向网格中下一个非零块索引。我们的下一个示例将提供使用这些预取映射的示例。
示例:使用块稀疏输出掩码的稠密 @ 稠密矩阵乘法#
在下一个示例中,我们将介绍密集矩阵乘法与稀疏输出掩码的融合,并使用预取映射来提高流水线性能。我们将使用掩码来有选择地跳过计算那些被置零的输出块,从而节省计算成本。
由于我们将使用稀疏掩码,我们将首先实现一个函数,将以密集格式存储的 N x M
掩码转换为块稀疏格式。我们还需要计算预取映射,以帮助流水线发射器知道接下来要获取哪个块。总而言之,我们的 sparsify_mask
函数会计算:
一个形状为
(num_N_blocks, num_M_blocks)
的block_mask
,指示一个块是否全为零(值为0
)或者包含非零元素(值为1
)。如果block_mask
的值为 0,我们可以跳过内核中该块的计算。一个形状为
(num_N_blocks, num_M_blocks)
的prefetch_mask
数组,其中包含mask_data
中下一个非零块的索引。一个形状为
(num_N_blocks, num_M_blocks)
的prefetch_i
数组,其中包含掩码的下一个非掩码i
索引。一个形状为
(num_N_blocks, num_M_blocks)
的prefetch_j
数组,其中包含掩码的下一个非掩码j
索引。一个形状为
(num_blocks, blk_N, blk_M)
的mask_data
数组,其中包含掩码的非零块的数据。
def sparsify_mask(mask: jax.Array,
block_shape: tuple[int, int]):
"""Preprocesses a mask into a sparse reprentation.
Args:
mask: A boolean array of shape [M, N]
block_shape: The size of a single block.
Returns:
block_mask: A block_shape array of booleans indicating whether a block
is all-zeros (0) or contains non-zero elements (1).
prefetch_mask: A block_shape array of integers indicating the index of the
next non-zero block.
mask_data: A (num_blocks, block_shape) array containing
the data for non-zero blocks of the mask.
"""
M, N = mask.shape
bm, bn = block_shape
block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
mask_types_finder = []
mask_data = []
mask_type_idxs = []
next_mask_type_idx = 0
prefetch_mask = jnp.zeros_like(block_mask)
next_i = (M // bm) - 1
next_j = (N // bn) - 1
prefetch_i = jnp.zeros_like(block_mask)
prefetch_j = jnp.zeros_like(block_mask)
for i in range(M // bm, -1, -1):
for j in range(N // bn, -1, -1):
mask_block = mask[i * bm :(i + 1) * bm,
j * bn :(j + 1) * bn]
is_nonzero = jnp.any(mask_block)
if is_nonzero:
try:
type_index = mask_types_finder.index(str(mask_block))
except ValueError:
type_index = len(mask_types_finder)
mask_types_finder.append(str(mask_block))
mask_data.append(mask_block)
next_mask_type_idx = type_index
next_i = i
next_j = j
else:
type_index = -1
mask_type_idxs.append(type_index)
block_mask = block_mask.at[i, j].set(is_nonzero)
prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
prefetch_i = prefetch_i.at[i, j].set(next_i)
prefetch_j = prefetch_j.at[i, j].set(next_j)
return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)
在内核的结构方面,我们使用与之前教程中介绍的标准矩阵乘法内核相同的网格模式,在 N
、M
和 K
维度上使用 3 个循环。在内核内部,我们首先检查 block_mask
,查看当前输出块的掩码是否全为零。如果掩码全为零,我们可以跳过计算并移动到下一个块;否则我们需要计算矩阵乘法,然后对结果进行掩码。
M = N = K = 16384
blk_M = blk_N = 512
blk_K = 1024
def sparse_mask_matmul(
block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.
x_ref, y_ref, mask_ref, o_ref, # Kernel inputs.
accum_scratch
):
del prefetch_mask, prefetch_i, prefetch_j
i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)
should_compute = block_mask_ref[i, j] != 0
@pl.when(k == 0)
def _():
o_ref[...] = jnp.zeros_like(o_ref)
accum_scratch[...] = jnp.zeros_like(accum_scratch[...])
# We only compute the output for blocks with non-zero masks.
# Otherwise we skip the computation entirely.
@pl.when(should_compute)
def _():
result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
accum_scratch[...] += result
@pl.when(k == pl.num_programs(2) - 1)
def _():
o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)
X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)
Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
mask = jnp.ones((M, N), dtype=jnp.int32)
mask = jnp.tril(mask)
block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))
def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
del prefetch_mask, prefetch_j
# Zero-out the k index if the mask is zero, to avoid constantly fetching
# new blocks in the inner loop for blocks we are skipping.
k_fetch = (block_mask[i, j] != 0) * k
return (prefetch_i[i, j], k_fetch)
def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
del prefetch_mask, prefetch_i
k_fetch = (block_mask[i, j] != 0) * k
return (k_fetch, prefetch_j[i, j])
def mask_map(i, j, k, block_mask, prefetch_mask, *_):
del k, block_mask
return (prefetch_mask[i, j], 0, 0)
def o_map(i, j, k, *_):
del k
return (i, j)
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=4,
grid=(M // blk_M, N // blk_N, K // blk_K),
in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),
pl.BlockSpec((blk_K, blk_N), y_map),
pl.BlockSpec((1, blk_M, blk_N), mask_map)],
out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
sparse_mask_matmul,
grid_spec=grid_spec,
out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
)
args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
result = kernel(*args)
ref = mask * (X @ Y)
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 1.0252e-05
现在,让我们将性能与朴素的密集实现进行比较。在 TPU v5e 上,与使用下三角掩码并且仅访问一半可能的输出的理论最佳情况 2 倍相比,我们使用稀疏内核实现了大约 1.8 倍的速度提升。
我们通常期望随着输入变得更大,性能会更接近理论峰值,因为我们没有完全达到理论性能的几个主要原因是:
我们跳过的计算略少于一半,因为沿对角线的块是 0 和 1 的混合,对于混合块,我们需要计算整个块。对于较大的输入,混合块的开销相对于整体计算变得更小。
随着输入变大,流水线气泡也占总运行时间的比例更小。
n_trials = 100
pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))
time = benchmark(ref_impl, n_trials)(mask, X, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 28.648 ms (avg over 100 trials)
Reference: 49.988 ms (avg over 100 trials)