矩阵乘法#

在本指南中,我们将使用 Pallas 编写一个矩阵乘法例程。我们还将介绍如何思考 TPU 上 matmul 的性能,以及如何模板化 matmul 内核以融合操作。

#@title Imports
import functools
from typing import Callable

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax import random
import jax.numpy as jnp
import numpy as np

背景#

矩阵乘法是现代深度学习和语言建模核心的基本线性代数运算。我们希望使用像 TPU 和 GPU 这样的专用加速器尽可能地提高 matmul 的速度,它们都具有用于快速矩阵乘法的专用单元。

为了有效地利用 TPU 进行矩阵乘法,我们需要涵盖一些背景概念:分块矩阵乘法、平铺和流水线。

分块矩阵乘法#

假设我们要实现 matmul(x, y),它通常将一个 (m, k) 数组与一个 (k, n) 数组相乘,但有一个转折。我们只允许使用原语 matmul_small,它将小矩阵相乘(例如 m, k, n <= 256)。我们该怎么做呢?

矩阵乘法的一个很好的特性是,输出的每个块都可以表示为输入行块和列块的几个较小矩阵乘法的总和。形式上,如果我们有输入数组 \(x \in \mathbb{R}^{m \times k}\)\(y \in \mathbb{R}^{k \times n}\) 以及输出 \(z \in \mathbb{R}^{m \times n}\),我们将它们沿大小为 \(b_m, b_k, b_n\) 的维度分解为块。

例如,\(x\) 将被分解为

\[\begin{split} \begin{bmatrix} x_{0, 0} & \cdots & x_{0, i_k} \\ x_{1, 0} & \cdots & x_{1, i_k} \\ \vdots & \ddots & \vdots \\ x_{i_m, 0} & \cdots & x_{i_m, i_k} \\ \end{bmatrix} \end{split}\]

其中 \(x_{ik} \in \mathbb{R}^{b_m \times b_k}\)。(我们可以类似地分解 \(y\)\(z\)。)

对于特定的输出块 \(z_{ij}\),我们可以将其计算为

\[ z_{ij} = \sum_k x_{ik} y_{kj} \]

因此,每个输出块 \(z_{ij}\) 是几个较小的块矩阵乘法 \(x_{ik} y_{kj}\) 的总和。以下是我们如何在 NumPy 中实现此算法

def matmul_small(x: np.ndarray, y: np.ndarray) -> np.ndarray:
  m, k, n = x.shape[0], x.shape[1], y.shape[0]
  assert m <= 256
  assert k <= 256
  assert n <= 256
  return np.matmul(x, y)

def block_matmul(
    x: np.ndarray,
    y: np.ndarray,
    *,
    bm: int = 256,
    bk: int = 256,
    bn: int = 256,
) -> np.ndarray:
  m, k = x.shape
  _, n = y.shape

  z = np.zeros((m, n), dtype=x.dtype)
  for m_i in range(m // bm):
    for n_i in range(n // bn):
      for k_i in range(k // bk):
        m_slice = slice(m_i * bm, (m_i + 1) * bm)
        k_slice = slice(k_i * bk, (k_i + 1) * bk)
        n_slice = slice(n_i * bn, (n_i + 1) * bn)
        x_block = x[m_slice, k_slice]
        y_block = y[k_slice, n_slice]
        z[m_slice, n_slice] += matmul_small(x_block, y_block)
  return z

我们的 block_matmul 函数现在应该可以处理大于 256 的输入(尽管我们假设我们的输入维度可以均匀地除以 256)。

m, k, n = 4096, 4096, 4096
x = np.random.uniform(size=(m, k)).astype(np.float32)
y = np.random.uniform(size=(k, n)).astype(np.float32)
np.testing.assert_allclose(x @ y, block_matmul(x, y), atol=1e-6, rtol=1e-6)

block_matmul 通过观察到大小为 (bm, bn) 的每个输出块可以通过累积几个 (bm, bk) x (bk, bn) 大小的矩阵乘法来计算,从而将矩阵乘法分解为许多较小的矩阵乘法。

TPU 和 GPU 执行 matmul 的方式就像这样!它们原生支持类似于 matmul_small 的小矩阵乘法,因此为了在执行更大的矩阵乘法时利用此硬件,我们将应用 block_matmul 分解。

平铺和流水线#

之前的指南中,我们介绍了 Pallas 中的平铺计算和流水线工作原理。为了确保我们的计算单元始终工作,并且永远不会因内存传输而停顿,我们将内核的下一次迭代的内存传输与当前迭代重叠。

在 Pallas 中,我们通过 BlockSpecgrid 指定这一点。请注意,我们在分块矩阵乘法算法中已经有一个嵌套的 for 循环。我们可以通过 grid 在 Pallas 中指定这一点。分块矩阵乘法中的切片也可以通过 BlockSpec 指定。

您的第一个矩阵乘法内核#

将所有内容放在一起,这是一个分块矩阵乘法内核的实现,它将内存传输与计算流水线化。我们创建一个 3 维网格,对应于 NumPy 代码中的 3 层嵌套循环。请注意,虽然 MXU 只能执行小块矩阵乘法,但 Pallas 将自动获取更大的块,并自动将它们平铺到 MXU 上。

网格的最后一个维度对应于矩阵乘法的收缩维度,并且是一个归约维度,因此我们需要确保初始化累加器。

def matmul_kernel(x_ref, y_ref, z_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    z_ref[...] = jnp.zeros_like(z_ref)

  z_ref[...] += x_ref[...] @ y_ref[...]

def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      in_specs=[pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
                pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
      out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
      grid=(m // bm, n // bn, k // bk),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.float32)
y = random.normal(k2, (k, n), dtype=jnp.float32)
np.testing.assert_array_equal(x @ y, matmul(x, y))

矩阵乘法性能#

让我们思考如何分析矩阵乘法性能。当我们思考 matmul 性能时,我们通常关心两件事:浮点运算总数 (FLOPs) 和内存带宽使用量。从关于 TPU 和流水线的指南中,我们看到为了使用 TPU(以及通用 ML 加速器)上的高效计算单元,我们需要将输入从 HBM 复制到 VMEM,更靠近计算单元。复制到 HBM 和从 HBM 复制都需要时间,而高效的内核希望将其大部分时间实际用于计算,而不是等待这些传输。内存带宽衡量此数据传输的速率。

快速说明:在本指南中,我们将讨论浮点运算,但想要区分 FLOP 与 FLOP/s。当我们说 “FLOP” 时,我们指的是 “浮点运算”,即运算的数量。当我们说 “FLOP/s” 时,我们指的是 “每秒浮点运算”,即执行浮点运算的速率

(m, k) x (k, n) 矩阵乘法中的 FLOP 数量(近似)为 2 * m * k * n。(技术上是 n * m * (2k - 1),但对于足够大的 k,我们的近似值就足够了。)

矩阵乘法的最小内存带宽使用量(假设 float32)是输入的总大小(复制到 VMEM)加上输出的大小(复制到 HBM)。因此,最小带宽使用量为 (m * k + k * n + m * n) * 4 bytes/float32。如果我们多次重新读取输入,内存使用量可能会更大,这种情况经常发生。

一个观察结果是,matmul FLOP 的数量在其输入中是立方的,而最小带宽使用量在其输入中是二次的。直观地说,这意味着 FLOP 的增长速度快于带宽使用量,这意味着我们的 matmul 越大,我们相对于复制的计算就越多。

def matmul_flops(m: int, k: int, n: int):
  return 2 * m * k * n

def matmul_membw(m: int, k: int, n: int, dtype: jnp.dtype):
  return (m * k + k * n + m * n) * np.dtype(dtype).itemsize

print(matmul_flops(1024, 1024, 1024))
print(matmul_membw(1024, 1024, 1024, jnp.float32))
2147483648
12582912

现在我们可以计算矩阵乘法的 FLOP 总数和(最小)内存带宽使用量,让我们看看真正的 TPU 可以处理什么。

此 notebook 在 TPU v5e 芯片上运行,因此我们将使用 v5e 数字(如果您正在运行此 notebook,您的数字可能会有所不同)。TPU v5e 具有197 TFLOP/s 的 bf16/f32 计算能力和 819 GB/s 的内存带宽。通过查看这些数字的比率(称为算术强度),我们可以获得一个界限,即在我们变为 IO 绑定之前,这个 “FLOP/内存带宽使用量” 比率可以变得多低(在 TPU v5e 上约为 240 FLOP/字节)。

v5e_flops = 197e12
v5e_membw = 819e9
v5e_op_intensity = v5e_flops / v5e_membw  # ~240.5

粗略地说,这些数字告诉我们,matmul 的 FLOP 应花费 2 * m * k * n / (197 TFLOP/s) 秒,而复制到/从 VMEM 应花费 (m*k + k*n + m*n) * 4 bytes / 819GB/s 秒。

def matmul_flops_intensity(m: int, k: int, n: int, dtype: jnp.dtype):
  flops = matmul_flops(m, k, n)
  membw = matmul_membw(m, k, n, dtype)
  return flops / membw

这个基本计算大致告诉我们,我们将能够多有效地使用 MXU。如果我们的 matmul 运算强度低于我们的芯片的能力,那么我们的计算将是内存绑定,即我们的计算单元将在等待值传输时空闲。如果 matmul 强度高于芯片的能力,那么我们将是计算绑定

由于 matmul FLOP 在其输入大小中是立方的,而内存带宽使用量是二次的,因此我们预计随着我们变得越来越大,我们将变为计算绑定,但是这个交叉点非常重要!假设我们正在进行 (1024, 1024) x (1024, 1024) float32 矩阵乘法。

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.float32)} flops/byte")
170.66666666666666 flops/byte

我们的 matmul flops 强度低于我们的芯片的能力。这不好!对于这种类型的矩阵乘法,我们很可能受到内存限制。但是,如果我们的输入和输出更大呢?在某些时候,当我们的 matmul 变得足够大时,我们将从内存绑定变为计算绑定。例如,如果我们有一个 matmul,其中 m = k = n,当 2m**3 / 12m**2 > 240 或当 m = k = n > 1440 时,我们将交叉(在 TPU v5e 上)。

bfloat16 矩阵乘法#

为了使矩阵乘法更容易在 TPU 上进行计算绑定,我们还可以为输入和输出使用更小的数据类型。我们之前的示例使用了 float32 输入和输出,但 TPU v5e 也支持 bfloat16 数据类型(一种 16 位浮点格式,也称为 bf16)用于矩阵乘法。在 TPU v5e 上,我们将具有相同的 FLOP/s,但将内存带宽使用量减半。这使得较小的矩阵更容易进行计算绑定。让我们看看 1024 x 1024 x 1024 bf16 矩阵乘法的强度是多少

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.bfloat16)} flops/byte")
341.3333333333333 flops/byte

我们现在有一个计算绑定的 matmul!

让我们将 bf16 支持添加到我们的矩阵乘法内核。

原生 MXU bf16 matmul 例程接受两个输入 bf16 矩阵,并在 f32 中累积它。我们将通过将 preferred_element_type=jnp.float32 传递到 jnp.matmul 来触发此例程。我们还需要一个 f32 中的累加器 Ref。然后,我们将输出向下转换为 bf16,然后再将其写回 HBM。这样,我们不会丢失任何精度,不会进行任何额外的转换,并且仍然保留 bf16 内存带宽节省。

请注意,现在分配暂存空间的唯一方法是通过 pltpu.PrefetchScalarGridSpec。现在不要担心它具体做什么 – 您现在需要知道的所有信息是,它允许您在 VMEM 中分配暂存空间。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  acc_ref[...] += jnp.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.bfloat16)
y = random.normal(k2, (k, n), dtype=jnp.bfloat16)
np.testing.assert_array_equal(x @ y, matmul(x, y))

流水线内核的性能#

我们上面关于 FLOP 与内存使用量的分析适用于粗略的规模,即当我们查看矩阵乘法的总大小时。但是,请记住,在实践中,我们正在流水线化分块矩阵乘法的执行,这意味着我们有一个循环,我们在其中使用较小的块进行矩阵乘法。

这意味着我们实际上关心的是内核的每个单独实例的 FLOP 与内存带宽使用量,而不是全局 FLOP 与内存带宽使用量。因此,块大小 bmbkbn 对于性能至关重要。即使我们拥有世界上最大的矩阵,如果我们选择非常小的 bmbkbn,我们将受到内存限制,因为每次调用内核时,我们的 FLOP 太少而无法隐藏在后台发生的内存传输。

因此,直觉应该是:要进行计算绑定,请使块尽可能大!有两个主要约束

  1. VMEM 使用量:我们的块越大,我们使用的 VMEM 就越多。使用足够大的块,我们将耗尽 VMEM。

  2. 流水线气泡:我们的块相对于矩阵大小越大,我们在流水线中的循环迭代次数就越少。这将使流水线开始和结束时的气泡大小相对于总流水线更大,并且这种开销可能很重要。

在 Pallas 中获得良好的矩阵乘法性能归结为选择良好的块大小来平衡此优化问题。在实践中,我们通常扫描大量候选块大小,分析内核,然后选择最佳块大小。

现在,让我们进行一些非常简单的计时实验。我们将使用 timeit 来衡量运行每个内核所需的时间量。请注意,这是内核实际运行时的上限,因为我们正在使用 timeit 测量 Python 调度和其他开销。我们将计算通过这种方式获得的 FLOP/s 量,并计算我们获得的利用率与芯片提供的利用率的百分比,并且我们将使用一些合理的块大小来验证我们的直觉。

import timeit

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # Compile function first
    jax.block_until_ready(f(*args, **kwargs))
    # Time function
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    # print(f"Time: {time}")
    return time
  return run

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func):
  x = jnp.ones((m, k), dtype=dtype)
  y = jnp.ones((k, n), dtype=dtype)
  time = benchmark(mm_func)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00029766598949208854
Matmul FLOP/s:  7214407167121.377
FLOP/s utilization: 3.6621%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011771515250438824
Matmul FLOP/s:  11675553278230.387
FLOP/s utilization: 5.9267%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09183577066054567
Matmul FLOP/s:  11972585626140.668
FLOP/s utilization: 6.0775%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012708659982308746
Matmul FLOP/s:  16897797651282.135
FLOP/s utilization: 8.5776%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.00088908776990138
Matmul FLOP/s:  154584235803001.88
FLOP/s utilization: 78.4692%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006099433819763363
Matmul FLOP/s:  180264539343531.62
FLOP/s utilization: 91.5048%

更大的块大小有很大帮助!我们在更大的 matmul 中获得了相当不错的利用率 (80-90%),但最小的 matmul 似乎很难获得良好的性能。

让我们将 Pallas 的矩阵乘法与 XLA 的矩阵乘法进行比较。我们不期望 Pallas 比 XLA 表现更好,因为 XLA 非常擅长生成矩阵乘法,但我们希望能够接近 XLA 的性能。通过更仔细的块大小调整(留作未来工作),我们也可以达到 XLA 的性能水平。

print("================ XLA matmul ===================")
mm = jnp.matmul
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================ XLA matmul ===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00011943008983507753
Matmul FLOP/s:  17981093801113.996
FLOP/s utilization: 9.1275%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008272899803705514
Matmul FLOP/s:  166131533963991.34
FLOP/s utilization: 84.3307%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006047147869830951
Matmul FLOP/s:  181823175395037.44
FLOP/s utilization: 92.2960%

Pallas 通过一些基本的调整,就非常接近 XLA 的性能数字了!通过尝试更多块大小,我们应该有望完全消除差距。

模板化矩阵乘法#

现在我们有了一个基本的矩阵乘法内核,我们可以尝试将操作融合到其中。

融合右手侧转置#

常见的首要操作是融合转置。这是什么意思呢?假设我们想要计算 x @ y.T 而不是 x @ y。 朴素的方法可能是先计算 y.T,然后将其传递到我们高效的矩阵乘法内核中。然而,y.T 操作本身并非免费的 —— 它涉及到复制 O(n^2) 的数据。理想情况下,我们可以在执行矩阵乘法同时计算转置,只需一个内核即可完成,即将其与矩阵乘法“融合”。

加速器通常支持融合 RHS 转置的原生矩阵乘法例程。例如,TPU v5e 的 MXU 允许我们对小型数组执行 x @ y.T。我们可以使用 jax.lax.dot_general 调用此例程,这将比单独执行转置然后进行矩阵乘法更有效。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  # dot_general expects a data structure (contraction_dims, batch_dims),
  # where contraction_dims are the set of dimensions for LHS and RHS that will
  # be contracted (reduced) in the matmul; batch_dims, on the other hand, are
  # looped over. The remaining dimensions will be the input and output dimension
  # of the matmul.
  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...], y_ref[...], dims, preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'transpose_rhs'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk, transpose_rhs=transpose_rhs),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            y_block_spec,
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)

我们在 matmul 函数内部进行转置(y = y.swapaxes(0, 1))。这是因为在 JIT 编译的 JAX 计算中,维度顺序纯粹是逻辑上的,而不是物理上的,因此重新排列维度并不意味着物理布局的差异。但是,当我们将数组传递到 pallas_call 中时,我们确实强制执行了主维度到次维度的排序约束。通过在 matmul 函数内部转置 y,我们请求 y 采用转置布局 (n, k) 而不是通常的 (k, n)。用户仍然会以(逻辑)(k, n) 维度传入数组。

注意:为了基准测试转置,我们实际上希望将 y 以物理转置布局传递到内核中,这样我们就不会测量重布局时间。在包装器函数中,我们会将其(逻辑上)转置回 (k, n),然后再将其传递到 matmul 中,因为 matmul 期望逻辑 (k, n) 维度顺序。

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = mm_func
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.0003029372810851783
Matmul FLOP/s:  7088872126624.065
FLOP/s utilization: 3.5984%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.012017967159627005
Matmul FLOP/s:  11436123235026.848
FLOP/s utilization: 5.8051%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09500920018996112
Matmul FLOP/s:  11572685861765.383
FLOP/s utilization: 5.8745%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012131539988331496
Matmul FLOP/s:  17701657415839.363
FLOP/s utilization: 8.9856%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008790623804088682
Matmul FLOP/s:  156347213275211.03
FLOP/s utilization: 79.3641%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006107717020204291
Matmul FLOP/s:  180020067095253.78
FLOP/s utilization: 91.3807%

看看我们如何在额外的转置下获得相同的利用率!

融合激活函数#

融合激活函数也非常常见。这确保了我们不会在一个高效的、计算受限的矩阵乘法内核之后,紧跟着一个缓慢的、内存受限的激活函数内核。

def matmul_kernel(
    x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs, activation
):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...],
      y_ref[...],
      dims,
      preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = activation(acc_ref[...]).astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'activation'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
    activation: Callable[[jax.Array], jax.Array] = lambda x: x,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(
          matmul_kernel,
          nsteps=k // bk,
          transpose_rhs=transpose_rhs,
          activation=activation,
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          in_specs=[
              pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
              y_block_spec,
          ],
          out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
          scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
          grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False,
                   activation = lambda x: x):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True, activation=activation)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = functools.partial(mm_func, activation=activation)
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()


activation = jax.nn.relu
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00030103540048003196
Matmul FLOP/s:  7133658182976.541
FLOP/s utilization: 3.6211%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011807117109419778
Matmul FLOP/s:  11640348122095.826
FLOP/s utilization: 5.9088%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09181861146935262
Matmul FLOP/s:  11974823079773.941
FLOP/s utilization: 6.0786%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012622540001757442
Matmul FLOP/s:  17013086492108.6
FLOP/s utilization: 8.6361%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.000896632740041241
Matmul FLOP/s:  153283442968721.44
FLOP/s utilization: 77.8089%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006130605939542875
Matmul FLOP/s:  179347953304919.88
FLOP/s utilization: 91.0396%

额外的融合激活函数几乎不会影响我们的利用率!

结论#

在本指南中,我们介绍了如何使用 Pallas 在 TPU 上编写高效的矩阵乘法。我们讨论了分块矩阵乘法和流水线处理,如何分析 TPU 矩阵乘法的性能,以及如何编写高效的 bf16 矩阵乘法。最后,我们讨论了模板化矩阵乘法以支持融合转置和融合激活函数。

留给读者的练习

  • 添加对输入融合的支持。有时我们希望将操作融合到矩阵乘法的输入中。尝试更多地模板化矩阵乘法以支持这一点。

  • 添加对 int8 矩阵乘法的支持。TPU v5 支持原生 int8 矩阵乘法,其 FLOPs 是 bf16 的两倍。尝试添加对其的支持,看看可能的利用率是多少。

  • matmul 函数添加反向传播支持。您可以使用 jax.custom_vjp 来实现这一点。