使用 Pallas 编写 TPU 内核#
本页面重点介绍尝试在 Google TPU 上运行 Pallas 内核时重要的细节。首先,TPU 后端仍处于实验阶段,并且只支持 JAX NumPy 的一个子集。此外,为 TPU 编写高性能代码可能需要仔细考虑硬件的本机功能。虽然许多不适合硬件的模式会被接受,但它们最终可能需要软件仿真,从而降低计算速度。
警告
此功能仍应视为实验性,因为工作仍在进行中(特别是改进错误消息)。
注意
虽然此处描述的所有功能都是实验性的,但我们仍然非常重视维护其正确性。因此,在尝试编写 TPU 内核时,遇到“未实现”错误可能并不少见。但是,如果编译器接受了某个内核,它必须返回预期结果。
如果您看到意外输出,请将其与在 pallas_call
中传递了 interpret=True
的内核运行结果进行比较。如果结果出现差异,请提交错误报告。
什么是 TPU?#
TPU 是 Google 开发的硬件加速器。您可以将 TPU 视为 GPU,但专门用于机器学习工作负载。因此,它们的架构差异很大。然而,我们相信 Pallas 可以让您轻松开始编写 TPU 内核,即使您不完全了解底层硬件。话虽如此,深入了解硬件肯定会更容易编写高性能内核。
简而言之,TPU 和 GPU 的主要区别在于 TPU 是具有非常宽的向量寄存器(有点像 CPU!)的顺序机器。同时,它们允许软件在后台调度某些操作,使其与主指令流异步执行。这包括 HBM 内存访问(不能直接发出,而是必须通过 DMA 子单元预取到内存层次结构的较低级别)、矩阵乘法(由 MXU 单元支持)或矩阵转置和置换(由 XLU 单元支持)等操作。
如果您有兴趣详细了解 TPU 架构,我们建议您阅读多年来发表的一系列论文。虽然其中许多论文谈论的是特定的 TPU 代,但所描述的许多想法也适用于后来的代。
值得注意的属性和限制#
BlockSpec
和网格迭代#
BlockSpec
(参见BlockSpec,即如何分块输入)在 Pallas 中的行为通常符合预期 — 内核体的每次调用都会访问输入的切片,并旨在初始化输出的切片。
注意
- 并非所有块形状都受支持。在 TPU 上,仅支持秩至少为 1 的块。
此外,您的块形状的最后两个维度必须分别可被 8 和 128 整除,或者与整个数组的相应维度相等。
Pallas TPU 内核的一个有趣之处在于它们处理内存空间的方式:虽然 pallas_call
的输入通常驻留在 HBM(主 TPU 内存)中,但传递给内核体的引用将指向内存层次结构较低层(VMEM 或 SMEM)中的缓冲区。这使得内核体能够以非常高的速度读写它们,而与 HBM 的所有通信(具有非常高的延迟)都由编译器处理并与计算重叠。
更重要的是,与 GPU 相比,TPU 实际上是高度顺序的机器。因此,网格通常不是并行处理,而是按字典序顺序处理(但例外情况请参见多核 TPU 配置部分)。这解锁了一些有趣的功能
当两个(字典序上)连续的网格索引使用输入的相同切片时,第二次迭代的 HBM 传输将被跳过,因为数据已经可用。
内核体的多次调用可以写入输出的相同切片,而不会有任何竞态条件的风险。但是,我们要求所有写入特定切片的调用都是连续的。
输出上的“连续”限制通常意味着网格维度的某个前缀总是改变调用需要访问的输出切片,而输出窗口对于剩余的后缀保持不变。
例如,当为矩阵乘法实现 Pallas TPU 内核时,通常会使用 3 维网格:前两个维度对应于沿左操作数的第一轴和第二个操作数的第二轴进行切片。第三个也是最后一个网格轴将对规约维度进行分块。与规约维度对应的网格轴必须是最后一个,因为输出窗口不沿此轴变化。然后输出引用可以用作部分结果的累加器。
注意
VMEM 对于如此低级的内存层次结构来说相当大(16MB+),这使得可以使用大窗口大小。而且,通常,窗口大小越大,最终的硬件利用率就越高。但是,可以指定一个(连同保存溢出向量寄存器所需的空间一起)超过 VMEM 大小的窗口大小。在这种情况下,您可能会看到低级编译器错误消息抱怨内存不足。
数组布局#
数组的维度顺序在 Pallas 中具有重要意义。在 JAX 程序中,jax.jit
内部中间数组的顺序通常对性能没有影响,因为编译器可以自由地重新排列它们。然而,由于 Pallas 旨在暴露更低级别的功能,维度顺序可能对生成代码的质量产生巨大影响。
TPU 在 2D 向量寄存器上执行大部分计算,对于 32 位值,这些寄存器通常大小为 8x128(截至 TPU v6)。当向量值从 VMEM 加载到寄存器中(例如 x = x_ref[...]
)时,数组的最后两个维度将被瓦片化到寄存器中。Pallas 将只考虑将中间数组的最后两个维度映射到 8x128 向量寄存器维度(分别为 sublanes 和 lanes)。
下面是使用 6 个 8x128 瓦片如何对 12x320 数组进行瓦片化的图形示例
瓦片化布局对内核编写者有几个重要的影响
数组的最后两个轴与其它轴的处理方式不同。例如,规约、重塑和转置在涉及最后两个轴时通常成本更高。一些涉及最后两个维度的重塑不受支持,并会导致编译器错误,但对于其它维度则是“免费”并在编译时执行。
虽然有时无法避免,但最后两个轴中存在单例维度通常是浪费的,因为它们将占用整个瓦片维度中的 1 个元素。消耗过多寄存器还可能导致寄存器溢出到 VMEM 中,从而降低内核性能。
与上述观点相关,所有向量计算都会填充到瓦片大小。添加两个 1x1 数组的成本与添加两个 8x128 数组的成本相同,添加两个 8x128x1x1 数组的成本将是添加两个 8x128 数组的 1024 倍,因为 8x128x1x1 数组将被填充到 8x128x8x128。
多核 TPU 配置#
在较新的 TPU 代中,芯片上的两个核心通常被抽象为一个设备。为了利用多核,Pallas 必须打破顺序网格执行保证,并且需要将其中一个网格轴并行化到多个核心。这是一个选择加入的过程。为此,pallas_call
需要一个名为 dimension_semantics
的额外参数
pallas_call(
...,
compiler_params=pltpu.CompilerParams(
dimension_semantics=["parallel", "parallel", "arbitrary"]
),
)
该参数是一个列表,其中包含的条目数量与网格中的轴数量相同。只有 parallel
维度可以在核心上进行分区。根据经验法则,除非输出窗口不变化,否则维度是并行的。因此,dimension_semantics
始终是一定数量的 parallel
轴,后跟一定数量的 arbitrary
轴。
虽然在 2 核 TPU 设备上划分内核通常会带来 2 倍的加速,但实际上可能会显著减小。如果主体不同实例的成本差异很大,则尤其如此。如果所有昂贵的步骤都映射到一个核心,而所有廉价的步骤都分配给另一个核心,则第二个核心将一直处于空闲状态,直到第一个核心完成其任务。
Pallas TPU 通常倾向于对大小是 TPU 核心数倍数的轴进行分区,并优先对前导网格轴进行分区。
将操作数放置在 SMEM 中#
TPU 上的大部分计算将在向量单元上进行。尽管如此,在许多情况下,执行一些标量操作很有用,例如,执行控制流。因此,TPU 配备了一个单独的标量单元和一个连接到它的独立标量内存 (SMEM)。根据经验法则,任何用于执行控制流决策的数据都应放置在 SMEM 中。
SMEM 是一种低延迟内存,支持随机访问,但只允许您通过一条指令读写 32 位值(与 VMEM 事务的 4KBi 粒度相比非常小,但由于没有对齐要求而更加灵活!)。
当实现不以规则模式访问输入瓦片的内核时,例如在编写块稀疏内核时,标量内存也非常有用。在 Pallas 中,这可以通过将 pallas_call
的 grid
参数替换为 grid_spec
的 PrefetchScalarGridSpec
(带非零 num_scalar_prefetch
参数)来实现。如果 num_scalar_prefetch
为 n
,则 pallas_call
的前 n
个参数将放置在 SMEM 中。不应为这些参数指定 BlockSpec
。但是,所有后续参数的 BlockSpec
不仅会接收网格索引,还会接收前导操作数的 SMEM 引用。
有关使用此功能的示例,请参阅标量预取和块稀疏计算。
支持的数据类型#
目前 Pallas TPU 支持以下数据类型
jnp.float32
jnp.bfloat16
jnp.int*
(所有精度,除了jnp.int4
)jnp.uint*
(所有精度)jnp.bool_
计算放置#
所有标量(即 0D)数组将存储在标量寄存器中,并对其执行的操作将在标量核心上执行。所有其他操作(即使是单元素但 1D+ 数组上的操作)都将在向量核心上执行。
支持的操作#
矩阵乘法#
矩阵乘法总是以 float32 格式生成结果。如果您的输入不是 float32,我们建议使用 lax.dot
并将 preferred_element_type
设置为 jnp.float32
。
使用 lax.dot_general
时,可以将矩阵乘法操作数最后两个维度的转置融合到操作中,这可以提高整体内核性能。
精度控制#
Pallas TPU 下层实现会考虑 jax.default_matmul_precision
。为了获得最佳性能(和最低精度),请使用 bfloat16
。如果您关心数值精度,您可能希望将精度设置为 float32
。
警告
即使您将 32 位操作数传递给矩阵乘法,除非请求 float32
精度,否则它们将被四舍五入为 bfloat16
。
转置#
如果值至少有 4 个维度,则除最后两个轴外的所有轴的任意转置都是免费的。否则,只实现最后两个轴的转置。请注意,最后两个维度的一些转置可以融合到矩阵乘法中。
访问内存#
可以读取或更新引用的任意切片,但受实现限制。目前,对 32 位宽的输入没有限制,但对于较窄的类型只支持某些切片模式。最后两个维度分别与 8 和 128 的倍数对齐且长度是 8 和 128 的倍数的读写总是受支持的。
对向量内存的读写通常发生在形状为 (8, 128)
的瓦片上。因此,当对至少有两个维度的引用进行读写时,当内存访问的基本偏移量具有可被瓦片化整除的索引,并且读取区域的大小是瓦片大小的倍数时,可以实现最佳性能。
逐元素操作#
支持许多逐元素操作。值得注意的是,硬件通常只支持使用 32 位类型进行逐元素计算。加载使用低精度类型的操作数时,通常应在应用逐元素操作之前将其向上转换为 32 位类型。
值得注意的是,它们的成本可能显著不同。因此,我们概述了支持的操作的三种类别:廉价(🟢)、中等(🌕)和昂贵(🔴)。
操作 |
成本 |
---|---|
|
🟢 |
|
🟢 |
|
🟢 |
|
🌕 |
|
🟢 |
|
🟢 |
|
🟢 |
|
🟢 |
|
🟢 |
比较 ( |
🟢 |
类型转换 ( |
🟢 |
|
🌕 |
|
🌕 |
|
🌕 |
|
🔴 |
|
🔴 |
许多 JAX 函数是根据其他 JAX 原语实现的,因此此列表可能不全面。例如,jax.nn.relu
是根据比较和 jnp.where
实现的,它们在 Pallas 内核中也能工作。
数组构造器#
所有常量数组构造器都受支持(jnp.ones
、jnp.zeros
、jnp.full
)。
规约#
支持 sum
、max
、min
(用于浮点值)规约,以及用于布尔值的 any
和 all
。不支持整数规约。
对数组最后维度的规约通常最慢。对倒数第二个维度的规约更快,但仍比对前导维度的规约慢。
广播#
广播的性能特性与规约非常相似。沿除了最后两个维度之外的所有维度的广播总是受支持且免费的。沿倒数第二个维度的广播较慢,而沿最后一个维度的广播最慢。
重塑#
和往常一样,除最后两个维度外的所有维度的重塑都受支持且免费。
当重塑可以修改数组的最后两个维度时,唯一支持的两种情况是:(1)一些前导维度被展平到倒数第二个维度,或者(2)它添加了一个刚刚通过规约删除的维度。
随机数生成#
Pallas 支持 jax.random
模块中最常用的函数,例如 uniform
、normal
和 bernoulli
。密钥应为 threefry2x32
密钥,这是 JAX 中的默认设置。密钥可以直接传递到内核中,也可以在内核内部生成。
控制流#
TPU 后端目前对控制流的支持有限。目前支持的函数有 cond
、fori_loop
和 for_loop
。但是,循环原语目前在编译期间会完全展开,所以请尽量保持循环迭代次数合理小。
过度使用控制流可能导致低级代码生成出现显著的性能退化,建议尝试将尽可能多的计算密集型操作挤压到一个基本块中。