使用 Pallas 编写 TPU 内核#

本页重点介绍在尝试于 Google TPU 上运行 Pallas 内核时重要的细节。首先,TPU 后端仍处于实验阶段,并且仅接受 JAX NumPy 的一部分。此外,为 TPU 编写高性能代码可能需要仔细考虑硬件的固有能力。虽然许多不符合硬件特性的模式将被接受,但它们最终可能需要软件仿真,并可能降低计算速度。

警告

此功能仍应被视为实验性的,因为工作仍在进行中(特别是在改进错误消息方面)。

注意

虽然这里描述的所有功能都是实验性的,但我们对维护其正确性非常认真。因此,在尝试编写 TPU 内核时,看到“未实现”错误并不少见。但是,如果编译器接受了某个内核,它必须返回预期的结果。

如果您看到意外的输出,请将其与通过 pallas_call 传递 interpret=True 运行的内核进行比较。如果结果出现分歧,请提交错误报告

什么是 TPU?#

A TPUv4 board

TPU 是 Google 开发的硬件加速器。您可以将 TPU 视为 GPU,但专门针对机器学习工作负载。因此,它们的架构差异很大。但是,我们相信 Pallas 可以让编写 TPU 内核变得容易,即使不完全了解底层硬件。话虽如此,充分了解硬件肯定会使编写高性能内核变得更容易。

简而言之,TPU 和 GPU 的主要区别在于 TPU 是顺序机器,具有非常宽的向量寄存器(有点像 CPU!)。同时,它们允许软件在后台调度某些操作,使其与主指令流异步执行。这包括 HBM 内存访问(不能直接发出,而必须通过 DMA 子单元预取到内存层次结构的较低级别)、矩阵乘法(由 MXU 单元支持)或矩阵转置和置换(由 XLU 单元支持)等。

如果您有兴趣详细了解 TPU 架构,我们建议阅读多年来发表的一系列论文。虽然其中许多讨论了特定的 TPU 代,但其中描述的许多概念也适用于后续代。

值得注意的属性和限制#

BlockSpec 和网格迭代#

BlockSpec(请参阅 BlockSpec,即如何分块输入)在 Pallas 中的行为通常符合预期 — 内核主体的每次调用都可以访问输入的切片,并用于初始化输出的切片。

注意

并非所有块形状都受支持。在 TPU 上,仅支持秩至少为 1 的块

。此外,您的块形状的最后两个维度必须能被 8 和 128 整除,或者等于整个数组的相应维度。

Pallas TPU 内核的一个有趣方面是它们处理内存空间的方式:虽然 pallas_call 的输入通常驻留在 HBM(TPU 主内存)中,但传递给内核主体的引用将指向内存层次结构较低层(VMEM 或 SMEM)中的缓冲区。这使得内核主体能够以非常高的速度读写它们,而与 HBM(具有非常高的延迟)的所有通信都由编译器处理,并与计算重叠。

此外,与 GPU 相比,TPU 实际上是高度顺序化的机器。因此,网格通常不是并行处理,而是顺序处理,按字典顺序(尽管有关例外情况,请参阅 多核 TPU 配置 部分)。这解锁了一些有趣的功能

  • 当两个(按字典顺序)连续的网格索引使用同一输入切片时,由于数据已可用,因此会跳过第二次迭代的 HBM 传输。

  • 内核主体的多次调用可以写入同一输出切片,而没有任何竞争条件风险。但是,我们要求所有写入特定切片的调用都是连续的。

输出上的“连续”限制通常意味着网格维度的某个前缀始终改变调用需要访问的输出切片,而输出窗口在剩余后缀上保持不变。

例如,在为矩阵乘法实现 Pallas TPU 内核时,通常会使用三维网格:前两个维度对应于沿第二个操作数的第一个轴和第二个轴进行切片。第三个和*最后一个*网格轴将对规约维度进行分块。对应于规约维度的网格轴必须是最后一个,因为输出窗口沿该轴不变化。然后可以将输出引用用作部分结果的累加器。

注意

VMEM 对于这种低级别内存层次结构相当大(16MB+),可以支持大窗口大小。而且,窗口越大,最终的硬件利用率往往越好。但是,有可能指定一个窗口大小(加上容纳溢出向量寄存器所需的空间)超过 VMEM 的大小。在这种情况下,您可能会看到一个低级编译器错误消息,抱怨内存不足。

数组布局#

数组的维度顺序在 Pallas 中是有意义的。在 JAX 程序中,jax.jit 内部中间数组的排序通常对性能没有影响,因为编译器可以自由地重新排列它们。但是,由于 Pallas 旨在公开更低级别的功能,维度顺序可能对生成的代码质量有很大影响。

TPU 在 2D 向量寄存器上执行大部分计算,这些寄存器通常对于 32 位值的大小为 8x128(截至 TPU v6)。当向量值从 VMEM 加载到寄存器时(例如,x = x_ref[...]),数组的最后两个维度将被分块到寄存器中。Pallas 只会考虑将中间数组的最后两个维度映射到 8x128 向量寄存器维度(分别为子通道和通道)。

这是一个 12x320 数组如何使用 6 个 8x128 块进行分块的图形示例

../../_images/vector_layout_example.svg

分块布局对内核编写者有几个重要的影响

  • 数组的最后两个轴的处理方式与其他轴不同。例如,当涉及最后两个轴时,规约、重塑和转置通常成本更高。一些涉及最后两个维度的重塑不受支持,并将导致编译器错误,但对于其他维度,它们是“免费”的,并在编译时执行。

  • 虽然有时不可避免,但最后两个轴上的单例维度通常是浪费的,因为它们将占用整个块维度的 1 个元素。消耗过多的寄存器也可能导致寄存器溢出到 VMEM,从而降低内核性能。

  • 与上述观点相关的是,所有向量计算都被填充到块大小。添加两个 1x1 数组的成本与添加两个 8x128 数组的成本相同,而添加两个 8x128x1x1 数组的成本将是添加两个 8x128 数组的 1024 倍,因为 8x128x1x1 数组将被填充到 8x128x8x128。

多核 TPU 配置#

在较新的 TPU 代中,芯片上的两个核心通常被抽象为一个设备。为了利用多个核心,Pallas 必须打破顺序网格执行保证,并需要将其中一个网格轴并行化到核心上。这是一个选择加入的过程。为了允许这样做,pallas_call 需要一个额外的参数,名为 dimension_semantics

pallas_call(
    ...,
    compiler_params=pltpu.CompilerParams(
        dimension_semantics=["parallel", "parallel", "arbitrary"]
    ),
  )

该参数是一个列表,其条目数量与网格中的轴数量相同。只有 parallel 维度可以跨核心分区。经验法则是,维度是并行的,除非输出窗口不发生变化。因此,dimension_semantics 始终是 parallel 轴的数量,后跟 arbitrary 轴的数量。

虽然将内核分区到 2 核 TPU 设备通常会导致 2 倍的加速,但实际上可能显著小于此。如果不同的主体实例成本差异很大,尤其如此。如果所有昂贵的步骤都映射到一个核心,而所有廉价步骤都分配给另一个核心,那么第二个核心将闲置,直到第一个核心完成其任务。

Pallas TPU 通常倾向于划分大小是 TPU 核心数量倍数的轴,并倾向于划分前导网格轴。

将操作数放置在 SMEM 中#

TPU 上的大部分计算将在向量单元上进行。然而,在许多情况下,执行一些标量操作是有用的,例如,执行控制流。因此,TPU 配备了一个独立的标量单元,以及连接到它的独立标量内存(SMEM)。经验法则是,用于执行控制流决策的任何数据都应放置在 SMEM 中。

SMEM 是一种低延迟内存,支持随机访问,但每次只能用一条指令读写 32 位值(与 VMEM 事务的 4KBi 粒度相比非常小,但由于缺乏对齐要求而更加灵活!)。

当实现不以规律模式访问输入块的内核(例如,编写块稀疏内核)时,标量内存也很有用。在 Pallas 中,这可以通过将 pallas_callgrid 参数替换为具有非零 num_scalar_prefetch 参数的 PrefetchScalarGridSpecgrid_spec 来实现。如果 num_scalar_prefetchn,那么 pallas_call 的前 n 个参数将放置在 SMEM 中。对于这些参数,不应指定 BlockSpec。但是,所有后续参数的 BlockSpec 将接收网格索引以及前导操作数的 SMEM 引用。

有关使用此功能的示例,请参阅 标量预取和块稀疏计算

支持的数据类型#

目前 Pallas TPU 支持以下数据类型

  • jnp.float32

  • jnp.bfloat16

  • jnp.int*(所有精度,除了 jnp.int4

  • jnp.uint*(所有精度)

  • jnp.bool_

计算放置#

所有标量(即 0D)数组将存储在标量寄存器中,并且对它们的运算将在标量核心上执行。所有其他运算(即使是单元素但 1D+ 数组)将在向量核心上执行。

支持的操作#

矩阵乘法#

矩阵乘法始终产生 float32 格式的结果。如果您的输入不是 float32,我们建议使用 lax.dot 并将 preferred_element_type 设置为 jnp.float32

使用 lax.dot_general 时,可以将矩阵乘法操作数最后两个维度的转置融合到操作中,这可以提高整体内核性能。

精度控制#

Pallas TPU 降低(lowering)意识到 jax.default_matmul_precision。为获得最佳性能(和最低精度),请使用 bfloat16。如果您关心数值精度,则可能需要将精度设置为 float32

警告

即使您将 32 位操作数传递给矩阵乘法,除非请求 float32 精度,否则它们将被舍入到 bfloat16

转置#

如果值至少有 4 个维度,则对除最后两个轴之外的所有轴进行的任意转置是免费的。否则,仅实现最后两个轴的转置。请注意,最后两个维度的某些转置可以融合到矩阵乘法中。

访问内存#

引用可以进行任意切片读写,但需遵守实现限制。目前,对 32 位宽度的输入没有限制,但对较窄类型的切片模式仅支持部分。读写对最后两个维度的 8 和 128 的倍数对齐且长度为相应倍数的读写始终受支持。

向量内存的读写通常发生在形状为 (8, 128) 的块上。因此,在读写至少有两个维度的引用时,当内存访问的基本偏移量的索引是分块的倍数,并且读取区域的大小是块大小的倍数时,可以获得最佳性能。

逐元素操作#

支持许多逐元素操作。值得注意的是,硬件通常仅支持使用 32 位类型的逐元素计算。加载使用较低精度类型的操作数时,通常应在应用逐元素操作之前将其提升到 32 位类型。

值得注意的是,它们的成本可能*有很大差异*。因此,我们概述了三种支持的操作类别:便宜(🟢)、中等(🌕)和昂贵(🔴)。

操作

成本

jnp.add+

🟢

jnp.sub-

🟢

jnp.mul*

🟢

/, //, %

🌕

jnp.maxjnp.min

🟢

jnp.where(选择)

🟢

jnp.abs

🟢

|, ^, &, ~

🟢

<<, >>

🟢

比较(==,...)

🟢

类型转换(.astype

🟢

jnp.exp

🌕

jnp.tanh

🌕

jnp.pow

🌕

jnp.sin

🔴

jnp.cos

🔴

许多 JAX 函数是根据其他 JAX 原始函数实现的,因此此列表可能不完整。例如,jax.nn.relu 是根据比较实现的,并且 jnp.where 也可以在 Pallas 内核中工作。

数组构造器#

所有常量数组构造器都受支持(jnp.onesjnp.zerosjnp.full)。

规约#

支持(浮点值)的 summaxmin 规约,以及(布尔值)的 anyall。不支持整数规约。

沿最后一个数组维度的规约通常最慢。沿倒数第二个维度的规约速度更快,但仍比沿前导维度的规约慢。

广播#

广播的性能特征与规约非常相似。沿除最后两个维度之外的所有维度的广播始终受支持且免费。沿倒数第二个维度的广播速度较慢,而沿最后一个维度的广播最慢。

重塑#

与往常一样,除最后两个维度之外的所有维度的重塑都受支持且免费。

当重塑可以修改数组的最后两个维度时,仅支持两种情况:(1)将某些前导维度展平到倒数第二个维度,或者(2)它添加了一个刚刚被规约移除的维度。

随机数生成#

Pallas 支持 jax.random 模块中最常用的函数,例如 uniformnormalbernoulli。密钥应该是 threefry2x32 密钥,这是 JAX 中的默认设置。密钥可以直接传递到内核,也可以在内核内部生成。

控制流#

目前,TPU 后端对控制流的支持有限。当前支持的函数是 condfori_loopfor_loop。但是,循环原语目前在编译期间完全展开,因此请尝试将循环计数保持在合理的小范围内。

过度使用控制流可能导致低级代码生成出现重大回归,建议尝试将尽可能多的计算密集型操作挤入单个基本块。