SparseCore 内核编写#
SparseCore 专门用于稀疏内存访问和运算,在多个版本的现代 TPU 中一直是核心组件。虽然大多数矩阵乘法和繁重的计算任务将在 TensorCore 上执行,但将部分计算卸载到 SparseCore 可以提高整体性能。
本指南将概述 SparseCore 架构,并展示如何编写在 TPU SparseCore 上运行的 Pallas 内核。
from functools import partial
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas import tpu_sc as plsc
import jax.numpy as jnp
import numpy as np
assert pltpu.get_tpu_info().sparse_core is not None, "No SparseCore found"
硬件概览#
根据版本的不同,TPU 芯片可能包含 2 或 4 个 SparseCore。每个 SparseCore 由多个子核心(subcore)组成,每个子核心拥有独立的 VMEM 空间。下图展示了 TPU 内部 SparseCore 的结构。
各组件说明
向量子核心 (tiles):SparseCore 的向量处理子核心。每个子核心拥有独立的内存,因此数据流是相互独立的。
通道 (Lanes, SIMD 宽度):SC 向量子核心以“单指令多数据”(SIMD) 方式对大小为 N 的向量进行计算。单条指令即可对通道内的所有数值进行计算。
标量子核心:SparseCore 的标量处理子核心。它能够执行标量运算、动态索引以及发起 DMA 和流操作。
内存空间:每个向量子核心都有自己的 VMEM 和 SMEM(图中未显示)空间。它们还可以访问共享的 VMEM 空间。标量子核心拥有自己的 SMEM。所有这些空间都与 TPU 的 HBM 相连。
在 Pallas 中,VMEM 空间表示为
pltpu.VMEM和pltpu.VMEM_SHARED,SMEM 表示为pltpu.SMEM。在其他文档中,共享 VMEM 通常被称为“SPMEM”,而每个子核心的 VMEM 被称为“TileSPMEM”或“本地 SPMEM”。
具体规格因 TPU 版本而异。以下是一些已发布的 TPU 规格:
属性 |
TPU v4 |
TPU v5p |
TPU v6e (Trillium) |
TPU 7x (Ironwood) |
|---|---|---|---|---|
SparseCore / 芯片 |
4 |
4 |
2 |
2 (4 个物理核心) |
向量子核心 / SparseCore |
16 |
16 |
16 |
16 |
SIMD 宽度 |
8 |
8 |
8 (F32) |
16 (F32) |
HBM 容量 |
32 GiB |
96 GiB |
32 GiB |
192 GB |
你也可以使用 pltpu.get_tpu_info() 快速获取当前硬件的规格信息。
# Quick way to query basic SC info
assert (sc_info := pltpu.get_tpu_info().sparse_core)
print(f"SparseCore info for TPU {pltpu.get_tpu_info().chip_version}:")
print(sc_info)
SparseCore info for TPU 7x:
SparseCoreInfo(num_cores=2, num_subcores=16, num_lanes=16, dma_granule_size_bytes=64)
操作与工作负载#
SparseCore 由 16 个小型处理单元组成,每个单元都有自己的数据流。这使得它非常适合具有以下特征的工作负载:
高度并行且不规则
随机数据访问
中低计算量
频繁的数据通信
SparseCore 上一些有用的操作包括:
小型向量算术
收集 (Gather) 和散射 (Scatter)(索引取值与发送)
排序、去重、计数、直方图
不规则 (Ragged) 操作
表达 SparseCore 硬件#
与 TensorCore 类似,Pallas 使用 mesh(网格)来表达 SparseCore 中的计算单元。根据你想要使用的处理单元,创建 ScalarSubcoreMesh 或 VectorSubcoreMesh。
请注意,VectorSubcoreMesh 有两个维度——core(对应不同的 SparseCore)和 subcore(对应每个 SparseCore 上的多个子核心)。
这允许你应用与 TensorCore 相同的编程模型来编写 SparseCore 上的集合通信(collectives)。如果想了解更多信息,请查阅我们的集合通信指南。
scalar_mesh = plsc.ScalarSubcoreMesh(
axis_name="core", num_cores=sc_info.num_cores
)
print(scalar_mesh)
vector_mesh = plsc.VectorSubcoreMesh(
core_axis_name="core", subcore_axis_name="subcore"
)
print(vector_mesh)
ScalarSubcoreMesh(axis_name='core', num_cores=2)
VectorSubcoreMesh(core_axis_name='core', subcore_axis_name='subcore', num_cores=2, num_subcores=16)
基础 SparseCore 内核#
请参见下方的简单标量子核心内核示例,其中包含 DMA、核心自定义和计算操作。请注意,标量子核心只能执行标量运算。
@jax.jit
def cumsum(x):
@pl.kernel(
out_shape=x,
mesh=scalar_mesh,
scratch_shapes=[
pltpu.SMEM((x.shape[1],), x.dtype),
pltpu.SemaphoreType.DMA,
],
)
def kernel(x_ref, o_ref, tmp_ref, sem):
idx = jax.lax.axis_index('core')
pltpu.async_copy(x_ref.at[idx], tmp_ref, sem).wait()
@pl.loop(1, x.shape[1])
def _(i):
tmp_ref[i] += tmp_ref[i - 1]
pltpu.async_copy(tmp_ref, o_ref.at[idx], sem).wait()
return kernel(x)
x_shape = (sc_info.num_cores, sc_info.num_lanes)
x = jax.random.randint(jax.random.key(0), x_shape, 0, 64, jnp.int32)
np.testing.assert_array_equal(cumsum(x), jnp.cumsum(x, axis=1))
SparseCore 内核中的流水线#
你可以使用 pltpu.emit_pipeline 来编写流水线化的 SparseCore 内核。emit_pipeline 的 core_axis_name 和 dimension_semantics 参数支持在 SparseCore/子核心之间划分流水线。
SC_REG_OP_SHAPE = (1, sc_info.num_lanes)
dma_block = (8, 128)
@jax.jit
def sc_add_one(x):
@pl.kernel(out_shape=x, mesh=vector_mesh, scratch_shapes=[])
def sc_add_one_kernel(x_hbm_ref, o_hbm_ref):
in_shape = x_hbm_ref.shape
def sc_add_one_body(in_vmem, out_vmem):
@pl.loop(0, in_vmem.shape[0], step=SC_REG_OP_SHAPE[0])
def _(c0):
@pl.loop(0, in_vmem.shape[1], step=SC_REG_OP_SHAPE[1])
def _(c1):
slc = (pl.ds(c0, SC_REG_OP_SHAPE[0]), pl.ds(c1, SC_REG_OP_SHAPE[1]))
out_vmem.at[*slc][...] = in_vmem.at[*slc][...] + 1
pltpu.emit_pipeline(
sc_add_one_body,
grid=(in_shape[0] // dma_block[0], in_shape[1] // dma_block[1]),
in_specs=[
pl.BlockSpec(block_shape=dma_block, index_map=lambda i, j: (i, j))
],
out_specs=[
pl.BlockSpec(block_shape=dma_block, index_map=lambda i, j: (i, j))
],
core_axis_name=('core', 'subcore'),
dimension_semantics=(pltpu.PARALLEL, pltpu.PARALLEL),
)(x_hbm_ref, o_hbm_ref)
return sc_add_one_kernel(x)
x = jax.random.randint(jax.random.key(0), (4096, 128), 0, 64, jnp.int32)
y = sc_add_one(x)
np.testing.assert_array_equal(y, x + 1)
或者,你可以使用 axis_index 计算核心索引,并将其用于跨核心分配任务(示例在此处)。
重叠执行 TensorCore 和 SparseCore#
重叠执行 TensorCore 和 SparseCore 内核非常简单:只需将它们放在同一个 jax.jit 中即可。XLA 编译器会自动处理它们的调度。
@jax.jit
def tc_add_one(x):
return x + 1
np.testing.assert_array_equal(tc_add_one(x), jnp.add(x, 1))
@jax.jit
def two_add_ones(x):
return sc_add_one(x), tc_add_one(x)
jax.tree.map(np.testing.assert_array_equal, two_add_ones(x), (x + 1, x + 1));
此处的基准测试显示,总执行时间小于两个函数分别运行的时间之和。
%timeit sc_add_one(x).block_until_ready()
%timeit tc_add_one(x).block_until_ready()
%timeit jax.block_until_ready(two_add_ones(x))
120 µs ± 2.46 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
113 µs ± 5.61 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
199 µs ± 2.24 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
收集与散射#
SparseCore 拥有针对索引取值和更新的特定优化操作。给定 HBM 中的输入或输出引用(命名为 data)以及 VMEM 中的索引数组(命名为 indices),它可以快速从 data[indices] 读取(“收集”)或写入(“散射”)。
我们可以通过在 async_copy 或 sync_copy 中将 Ref 用 indices Ref 进行索引来使用这些 gather/scatter 操作。例如,sync_copy(data_ref.at[indices_ref], target_ref) 将触发一次收集操作。
下方是一个将索引加载到向量子核心 VMEM 的流水线内核。在内核主体中,我们使用这些索引执行收集操作。
batch_size = 4096
value_dim = 128
gather_window_size = 128
num_steps = 1024
sc_num_cores, sc_num_subcores = sc_info.num_cores, sc_info.num_subcores
num_indices = gather_window_size * sc_num_cores * sc_num_subcores * num_steps
x = jnp.arange(batch_size * value_dim).reshape(batch_size, value_dim)
indices = jax.random.randint(
jax.random.key(0), (num_indices,), 0, batch_size, jnp.int32
)
@jax.jit
def gather(x, indices):
indices = indices.reshape((1, num_indices))
@pl.kernel(
out_shape=jax.ShapeDtypeStruct((num_indices, value_dim), x.dtype),
mesh=vector_mesh,
)
def kernel(x_hbm, i_hbm, o_hbm):
def body(i_vmem, o_vmem):
pltpu.sync_copy(x_hbm.at[i_vmem.at[0]], o_vmem) # The gather op
pltpu.emit_pipeline(
body,
grid=(num_indices // gather_window_size,),
in_specs=[
pl.BlockSpec((1, gather_window_size), index_map=lambda i: (0, i))
],
out_specs=[
pl.BlockSpec(
(gather_window_size, value_dim), index_map=lambda i: (i, 0)
)
],
core_axis_name='subcore',
dimension_semantics=(pltpu.PARALLEL,),
)(i_hbm, o_hbm)
return kernel(x, indices)
out = gather(x, indices)
np.testing.assert_array_equal(out, jnp.take(x, indices, axis=0))
如果你在内核开始时进行索引取值,可以在顶层的 pl.pallas_call 中对 plsc.BlockSpec 使用 indexed_by 和 indexed_dim 参数,将另一个输入作为该轴上当前输入的索引。
此调用将并行化从 HBM 到 VMEM 的 DMA 以及执行索引查找的收集操作,从而产生 4 个流水线阶段:索引拷贝输入、收集、内核计算以及输出拷贝输出。这允许你将收集操作与收集结果后的任何进一步计算进行重叠。
请注意,plsc.BlockSpec 尚处于实验阶段,未来可能会有变动。
@jax.jit
def gather_add_one(x, indices):
@partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct((num_indices, value_dim), x.dtype),
grid=(num_indices // gather_window_size,),
in_specs=(
plsc.BlockSpec(
(gather_window_size, value_dim), indexed_by=1, indexed_dim=0
),
pl.BlockSpec((gather_window_size,), lambda i: i),
),
out_specs=pl.BlockSpec((gather_window_size, value_dim), lambda i: (i, 0)),
compiler_params=pltpu.CompilerParams(
kernel_type=pltpu.CoreType.SC_VECTOR_SUBCORE,
dimension_semantics=(pltpu.PARALLEL,),
),
)
def kernel(gathered_ref, _, o_ref):
# gathered_ref is the gathered content of x[indices]
@pl.loop(0, gather_window_size)
def _(c0):
@pl.loop(0, o_ref.shape[1], step=16)
def _(c1):
slc = (pl.ds(c0, 1), pl.ds(c1, 16))
o_ref.at[*slc][...] = gathered_ref.at[*slc][...] + 1
return kernel(x, indices)
out = gather_add_one(x, indices)
np.testing.assert_array_equal(out, jnp.take(x, indices, axis=0) + 1)
散射(索引覆盖)是收集的逆操作。请查看下方示例内核。
@jax.jit
def scatter(x, indices):
indices = indices.reshape((1, num_indices))
@pl.kernel(
out_shape=jax.ShapeDtypeStruct((batch_size, value_dim), x.dtype),
mesh=vector_mesh,
scratch_shapes=[],
)
def kernel(x_hbm, i_hbm, o_hbm):
def body(x_vmem, i_vmem):
pltpu.sync_copy(x_vmem, o_hbm.at[i_vmem.at[0]]) # The scatter op
pltpu.emit_pipeline(
body,
grid=(num_indices // gather_window_size,),
in_specs=[
pl.BlockSpec(
(gather_window_size, value_dim), index_map=lambda i: (i, 0)
),
pl.BlockSpec(
(
1,
gather_window_size,
),
index_map=lambda i: (0, i),
),
],
out_specs=[],
core_axis_name='subcore',
dimension_semantics=(pltpu.PARALLEL,),
)(x_hbm, i_hbm)
return kernel(x, indices)
gathered = jnp.take(x, indices, axis=0)
out = scatter(gathered, indices)
np.testing.assert_array_equal(out, x)
与 TensorCore 的基准测试对比#
SparseCore 在收集和散射操作上表现尤为出色。我们可以使用原生 JAX API 实现相同功能(默认在 TensorCore 上运行),并对比结果。
%timeit jax.block_until_ready(gather(x, indices))
gather_tc = jax.jit(lambda x, i: jnp.take(x, i, axis=0))
gather_tc(x, indices).block_until_ready()
%timeit jax.block_until_ready(gather_tc(x, indices))
4.05 ms ± 2.02 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
18.1 ms ± 5.24 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)