标量预取和块稀疏计算#

在本教程中,我们将介绍 Pallas 中块稀疏计算的基础知识。稀疏计算是编写自定义 Pallas 核而不是简单使用 JAX/XLA 的一个主要原因,因为由于静态数组形状,在 XLA 中通常很难表达执行动态计算量的程序。在本教程中,我们将学习如何使用 Pallas 的标量预取功能来编写可以动态跳过计算和内存块的块稀疏核。

import functools
import timeit
import numpy as np
import jax
from jax import numpy as jnp
from jax import lax
from jax.experimental import checkify
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print("Running on", jax.devices()[0].device_kind)
Running on TPU v5 lite

配合标量预取的动态块索引#

我们将利用 Pallas 的“标量预取”功能来编写稀疏核。标量预取允许您将少量数据传入 SMEM(“标量内存”),这些数据在流水线开始之前加载(“预取”)。由于这些数据在流水线之前加载,因此可供每个 BlockSpecindex_map 使用,从而允许您执行数据相关的索引计算。本教程的主要目标是介绍利用此功能的常见编程模式。

要使用标量预取,请使用 pltpu.PrefetchScalarGridSpec 代替标准的 pl.GridSpec

class PrefetchScalarGridSpec:
  def __init__(self,
    num_scalar_prefetch: int,
    grid: tuple[int, ...],
    in_specs: PyTree[BlockSpec],
    out_specs: PyTree[BlockSpec],
    scratch_shapes: tuple[MemorySpace, ...]):
      ...

参数 num_scalar_prefetch 指示标量预取值的数量。当此参数设置为非零值时,它会改变核和索引映射的调用签名,以期望额外的预取值。传入 index_map 和核的预取 Ref 都分配在 SMEM 中,并且由于它们没有定义 BlockSpec,因此不会被划分为块。此外,index_map 和核的参数顺序始终是固定的,并将在下面描述。

  • 每个 BlockSpecindex_map 现在期望预取 Refs 在网格索引之后。

def index_map(*grid_indices, *prefetch_refs):
    ...
  • 用户定义的核期望预取 Refs 在输入 Refs 之前。此外,暂存 Refs 在输出 Refs 之后。

def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
    ...
  • 当使用 pallas_call 调用新核时,pallas_call 返回的函数也期望标量预取参数在输入之前,例如:

kernel = pl.pallas_call(...)
result = kernel(*prefetch_args, *input_args)

示例:配合标量预取的块动态切片#

让我们从一个基本示例开始,演示如何使用标量预取功能。我们将实现一个块对齐的动态切片核,它根据用户指定的索引从更大的数组中简单地提取一个块。

  1. 在核之外,我们计算要提取的块索引为:block_idx = (start[0] // size[0], start[1] // size[1])

  2. 我们将 block_idx 作为标量预取参数传入 pallas_call

  3. 在我们的索引映射中,我们通过返回 (block_idx[0], block_idx[1]) 来使用块索引选择对应的块。

当然,这个核是有限制的,即我们的切片大小必须适合核块内部(受 VMEM 大小限制),并且我们只能从大小对齐的索引开始。一个更高级的核会将核块大小与切片大小解耦,并允许非对齐的起始索引。

def dynamic_slice_kernel(indices, x_ref, o_ref):
  del indices
  o_ref[...] = x_ref[...]

@checkify.checkify
@functools.partial(jax.jit, static_argnums=(2,))
def block_dynamic_slice(x, starts, sizes):
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=1,
      grid=(1, 1),
      in_specs=[pl.BlockSpec(
          sizes,
          lambda i, j, block_idx: (block_idx[0], block_idx[1]))],
      out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),
  )

  kernel = pl.pallas_call(
    dynamic_slice_kernel,
    grid_spec=grid_spec,
    out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),
  )
  # Checkify inserts a runtime assert that starts are divisible by block size.
  checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.")
  checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.")
  block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])
  return kernel(block_idx, x)

shape = (512, 512)
x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)
err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))
err.throw()
ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))
diff = jnp.max(jnp.abs(result - ref))
print("Error |result - lax.dynamic_slice| =", diff)
Error |result - lax.dynamic_slice| = 0

稀疏核:表示稀疏数据#

在深入实现稀疏核之前,我们首先回顾稀疏矩阵是如何表示的。尽管有几种流行的稀疏矩阵存储格式,但我们将遵循坐标列表格式(COO)的块变体,其中我们将矩阵存储为 (block_index, block_data) 对的列表。列表中未明确存储的所有块都被假定为零,这意味着如果矩阵中有许多零块,我们可以节省大量内存。

下图演示了我们如何将一个 4x4 的稠密矩阵(左)转换为块大小为 2x2 的块 COO 格式(右)。请注意,在稀疏格式中,我们可以避免显式存储由全零元素组成的右上角块。

block_coo

我们将使用以下辅助函数来采样块稀疏矩阵。它返回一个用于检查结果的稠密矩阵,以及一个包含每个轴的块数据和索引的列表。

def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):
  """Returns a sampled matrix and its block-sparse representation.

  Args:
    key: RNG Key.
    M: Major array dimension.
    N: Minor array dimension.
    blk_M: Block size along M dimension.
    blk_N: Block size along N dimension.
    p: Probability that a block will be non-zero.
    dtype: dtype of the sampled matrix.

  Returns:
    dense_mat: A (M, N) dense sampled array.
    block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing
      the non-zero blocks of the matrix.
    indices_i: A (num_blocks,) array of block indices for the first axis.
    indices_j: A (num_blocks,) array of block indices for the second axis.
  """
  mask_key, blocks_key = jax.random.split(key)
  num_blocks = (M // blk_M, N // blk_N)
  # We first sample a block mask, denoting which blocks are nonzero.
  block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
  num_blocks = jnp.sum(block_mask)
  indices = jnp.where(block_mask)
  # For each non-zero block, we sample a block of random values.
  block_data = jax.random.uniform(blocks_key,
                                  shape=(num_blocks, blk_M, blk_N),
                                  dtype=dtype)
  # For checking purposes, create the dense version of the sparse matrix.
  dense_mat = jnp.zeros((M, N), dtype=dtype)
  for blk in range(num_blocks):
    idx_i = indices[0][blk]
    idx_j = indices[1][blk]
    slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
    slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
    dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
  return dense_mat, block_data, indices[0], indices[1]

示例:稀疏矩阵乘以稠密矩阵#

在我们的第一个示例中,我们将一个稀疏的左侧(LHS)矩阵与一个稠密的右侧(RHS)矩阵相乘,以产生一个稠密输出。

我们将以 2 个循环构建核网格——外层循环遍历 RHS/输出的列,内层循环遍历 LHS 的稀疏块。在每次内层循环迭代中,我们从 LHS 加载一个块,并使用收缩维度(K)的块索引查找 RHS 中对应的块。我们将这两个块相乘并累加到正确的输出块中。一个外层循环迭代将计算整个列的结果,如下图所示:

sparse_matmul

在将块索引(例如 [0, 0, 1, 2, 3, 3])传入核之前,按行对其进行分组很重要,原因有二。首先,在我们的核中,我们需要知道何时初始化输出 Ref 中的累加器为零,如果行索引已分组,则很容易做到这一点。其次,Pallas 的流水线逻辑不允许我们在非连续迭代中重新访问输出 Ref 中的块,因此我们需要在连续的核迭代中将所有累加操作完成到一个输出块中。这是因为流水线发射器会意识到我们在连续迭代中加载相同的输出块,并将该块保留在 VMEM 中。当我们更改输出块时,Pallas 最终会将输出存储到 HBM 中,并假定我们不再触及它。未能连续访问输出块将导致不正确的值,即使核在逻辑上是正确的。

M = N = K = 16384
blk_M = blk_N = blk_K = 512


def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
               x_ref, y_ref, _, o_ref, # Kernel inputs.
               accum_scratch,
               ):
  """A DSD (Dense = Sparse @ Dense) matmul kernel."""
  del idxs_k_ref
  blk_idx = pl.program_id(1)
  is_start = blk_idx == 0
  changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
  @pl.when(is_start | changed_blocks)
  def _():
    accum_scratch[...] = jnp.zeros_like(accum_scratch)
  accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)

  next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])
  is_end = blk_idx == (num_blocks - 1)
  @pl.when(is_end | next_block_change)
  def _():
    o_ref[...] = accum_scratch[...].astype(o_ref.dtype)


def x_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
  del j, blk_idxs_i, blk_idxs_k
  return (blk_idx, 0, 0)
def y_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
  del blk_idxs_i
  return (blk_idxs_k[blk_idx], j)
def o_map(j, blk_idx, blk_idxs_i, blk_idxs_k):
  del blk_idxs_k
  return (blk_idxs_i[blk_idx], j)

(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(
    jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)
num_blocks = X_blocks.shape[0]
Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)
out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=2,
    # Note that while num_blocks is static here, Pallas does support
    # dynamic grid sizes.
    grid=(N // blk_N, num_blocks),
    in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
              pl.BlockSpec((blk_K, blk_N), y_map),
              # Placeholder for a zeros-array used by input_output_aliases.
              pl.BlockSpec((blk_M, blk_N), o_map),
              ],
    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
  dsd_kernel,
  grid_spec=grid_spec,
  out_shape=out_shape,
  # We use input-output aliases to zero-out o_ref for blocks that we never
  # visit. By passing in an array of zeros we avoid having o_ref start with
  # uninitialized values.
  input_output_aliases={4: 0},  # Map zeros to o_ref.
)
args = (indices_i, indices_k, X_blocks, Y, zeros)
result = kernel(*args)

ref = X_dense @ Y
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 0

我们可以进行快速基准测试,比较我们的稀疏核与 JAX 中的稠密矩阵乘法的性能。在 TPU v5e 芯片上,此核与稀疏因子带来的理论 10 倍加速相比,实现了大约 6 倍的加速。

这里有一些主要的性能优化技巧,主要集中在减少 HBM/VMEM 之间的通信开销。

  • 使用 dtype=jnp.bfloat16 对性能至关重要,因为它将内存带宽减少了一半。

  • 使用更大的块大小也有助于提升性能,因为矩阵乘法是 \(O(N^3)\) 的计算和 \(O(N^2)\) 的内存操作。随着 \(N\) 变大,核会变成计算密集型。然而,在实践中,一个反驳是更小的块大小也使得数据可以更稀疏,因此这是一个需要仔细选择的参数。

# Benchmark Sparse Pallas kernel vs reference JAX implementation

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # Compile function first
    jax.block_until_ready(f(*args, **kwargs))
    # Time function
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    return time
  return run


n_trials = 100

pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))

ref_impl = jax.jit(lambda x, y: x @ y)
time = benchmark(ref_impl, n_trials)(X_dense, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 8.136 ms (avg over 100 trials)
Reference: 46.953 ms (avg over 100 trials)

稠密数据上的稀疏访问模式#

在我们之前的示例中,我们考虑了数据本身是稀疏的情况。这在核结构中表现为一个动态的核网格维度,循环遍历非零块的数量(num_blocks)。

第二种有用的编程模式出现在底层数据是稠密的情况下,但我们希望对其执行稀疏计算。在这种情况下,我们的核网格将是稠密的,但我们希望根据块稀疏掩码跳过网格中的某些块。这种编程模式常见于许多机器学习应用中使用掩码的情况,例如自注意力中的因果或局部掩码。在这些情况下,我们可以完全跳过掩码为零的块中的计算。这种编程模式的示例可以在 jax/experimental/pallas/ops/tpu 中的 Splash Attention 和 Grouped Matrix Multiplication 核中找到,或者在 PyTorch 的 FlexAttention 中找到。

处理稠密数据上的稀疏访问模式时,主要的性能考量是与流水线的交互。在任何给定的核迭代中,Pallas 流水线发射器将尝试通过在网格的下一次迭代中调用每个 BlockSpecindex_map 来预取下一块数据。然而,如果我们的计算是稀疏的,我们可能会跳过网格中下一个块的计算,因此我们需要一种方法来告诉流水线转而开始获取 *我们没有跳过的下一个块*。为此,我们需要构建 *预取映射*,其中包含每个核输入的下一个未跳过数据块的索引。下图说明了如何为以 COO 格式存储的块稀疏掩码构建预取映射。

prefetch_map

左:一种稀疏访问模式,其中蓝色表示需要计算的非零掩码块。右:预取映射,其中数组的每个元素包含下一个非零块数据的索引。

预取映射构建完成后,我们可以将其作为标量预取参数传入,并在 BlockSpec 的 index_map 函数中查询它。

def mask_index_map(prefetch_map, i, j, ...):
  next_nonzero_block = prefetch_map[i, j]
  return (next_nonzero_block, 0, 0)

我们可以为核的其他输入构建类似的索引映射。对于稠密输入,您很可能需要构建指向网格中下一个非零块索引的预取映射。我们的下一个示例将提供使用这些预取映射的示例。

示例:配合块稀疏输出掩码的稠密矩阵乘法#

在我们的下一个示例中,我们将介绍稠密矩阵乘法与稀疏输出掩码的融合,并使用预取映射来提高流水线性能。我们将使用掩码选择性地跳过计算已归零的输出块,从而节省计算成本。

由于我们将使用稀疏掩码,我们将首先实现一个函数,将以稠密格式存储的 N x M 掩码转换为块稀疏格式。我们还需要计算预取映射,以帮助流水线发射器知道下一个要获取的块。总的来说,我们的 sparsify_mask 函数计算:

  • 一个形状为 (num_N_blocks, num_M_blocks)block_mask,指示一个块是否全为零(值为 0)或包含非零元素(值为 1)。如果 block_mask 的值为 0,我们可以在核中跳过该块的计算。

  • 一个形状为 (num_N_blocks, num_M_blocks)prefetch_mask 数组,其中包含指向 mask_data 中下一个非零块的索引。

  • 一个形状为 (num_N_blocks, num_M_blocks)prefetch_i 数组,其中包含掩码的下一个非掩码 i 索引。

  • 一个形状为 (num_N_blocks, num_M_blocks)prefetch_j 数组,其中包含掩码的下一个非掩码 j 索引。

  • 一个形状为 (num_blocks, blk_N, blk_M)mask_data 数组,其中包含掩码非零块的数据。

def sparsify_mask(mask: jax.Array,
                  block_shape: tuple[int, int]):
  """Preprocesses a mask into a sparse representation.

  Args:
    mask: A boolean array of shape [M, N]
    block_shape: The size of a single block.

  Returns:
    block_mask: A block_shape array of booleans indicating whether a block
      is all-zeros (0) or contains non-zero elements (1).
    prefetch_mask: A block_shape array of integers indicating the index of the
      next non-zero block.
    mask_data: A (num_blocks, block_shape) array containing
      the data for non-zero blocks of the mask.
  """
  M, N = mask.shape
  bm, bn = block_shape

  block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
  mask_types_finder = []
  mask_data = []

  next_mask_type_idx = 0
  prefetch_mask = jnp.zeros_like(block_mask)
  next_i = (M // bm) - 1
  next_j = (N // bn) - 1
  prefetch_i = jnp.zeros_like(block_mask)
  prefetch_j = jnp.zeros_like(block_mask)
  for i in range(M // bm, -1, -1):
    for j in range(N // bn, -1, -1):
      mask_block = mask[i * bm :(i + 1) * bm,
                        j * bn :(j + 1) * bn]
      is_nonzero = jnp.any(mask_block)
      if is_nonzero:
        try:
          type_index = mask_types_finder.index(str(mask_block))
        except ValueError:
          type_index = len(mask_types_finder)
          mask_types_finder.append(str(mask_block))
          mask_data.append(mask_block)
        next_mask_type_idx = type_index
        next_i = i
        next_j = j
      else:
        type_index = -1
      block_mask = block_mask.at[i, j].set(is_nonzero)
      prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
      prefetch_i = prefetch_i.at[i, j].set(next_i)
      prefetch_j = prefetch_j.at[i, j].set(next_j)
  return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)

就核的结构而言,我们使用与之前教程中介绍的标准矩阵乘法核相同的网格模式,对 NMK 维度进行 3 次循环。在核内部,我们首先检查 block_mask 以查看当前输出块的掩码是否全为零。如果掩码全为零,我们可以跳过计算并移动到下一个块;否则我们需要计算矩阵乘法,然后对结果进行掩码处理。

M = N = K = 16384
blk_M = blk_N = 512
blk_K = 1024

def sparse_mask_matmul(
    block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.
    x_ref, y_ref, mask_ref, o_ref,  # Kernel inputs.
    accum_scratch
    ):
  del prefetch_mask, prefetch_i, prefetch_j
  i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)
  should_compute = block_mask_ref[i, j] != 0
  @pl.when(k == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
    accum_scratch[...] = jnp.zeros_like(accum_scratch[...])

  # We only compute the output for blocks with non-zero masks.
  # Otherwise we skip the computation entirely.
  @pl.when(should_compute)
  def _():
    result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
    accum_scratch[...] += result
    @pl.when(k == pl.num_programs(2) - 1)
    def _():
      o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)

X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)
Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
mask = jnp.ones((M, N), dtype=jnp.int32)
mask = jnp.tril(mask)
block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))

def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
  del prefetch_mask, prefetch_j
  # Zero-out the k index if the mask is zero, to avoid constantly fetching
  # new blocks in the inner loop for blocks we are skipping.
  k_fetch = (block_mask[i, j] != 0) * k
  return (prefetch_i[i, j], k_fetch)

def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
  del prefetch_mask, prefetch_i
  k_fetch = (block_mask[i, j] != 0) * k
  return (k_fetch, prefetch_j[i, j])

def mask_map(i, j, k, block_mask, prefetch_mask, *_):
  del k, block_mask
  return (prefetch_mask[i, j], 0, 0)

def o_map(i, j, k, *_):
  del k
  return (i, j)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=4,
    grid=(M // blk_M, N // blk_N, K // blk_K),
    in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),
              pl.BlockSpec((blk_K, blk_N), y_map),
              pl.BlockSpec((1, blk_M, blk_N), mask_map)],
    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
)
kernel = pl.pallas_call(
  sparse_mask_matmul,
  grid_spec=grid_spec,
  out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
)
args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
result = kernel(*args)

ref = mask * (X @ Y)
diff = jnp.abs(ref - result)
print('mean |result - ref|:', jnp.mean(diff))
mean |result - ref|: 1.0252e-05

现在让我们将性能与朴素的稠密实现进行比较。在 TPU v5e 上,使用稀疏核我们实现了大约 1.8 倍的加速,而使用下三角掩码并仅访问一半可能输出的理论最佳情况是 2 倍。

通常,我们期望随着输入变大,性能会更接近理论峰值,因为我们未能精确达到理论性能的一些主要原因如下:

  • 我们跳过的计算量略少于一半,因为对角线上的块是 0 和 1 的混合,对于混合块,我们需要计算整个块。随着输入变大,混合块的开销相对于总计算量变得更小。

  • 随着输入变大,流水线气泡在总运行时中的占比也变得更小。

n_trials = 100

pallas_impl = lambda *args: kernel(*args)
time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))

ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))
time = benchmark(ref_impl, n_trials)(mask, X, Y)
print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
Sparse Kernel: 28.648 ms (avg over 100 trials)
Reference: 49.988 ms (avg over 100 trials)