矩阵乘法#

在本指南中,我们将使用 Pallas 编写一个矩阵乘法例程。我们还将讨论如何考虑 TPU 上的矩阵乘法性能以及如何将矩阵乘法核模板化以融合操作。

#@title Imports
import functools
from typing import Callable

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax import random
import jax.numpy as jnp
import numpy as np

背景#

矩阵乘法是现代深度学习和语言模型核心的基础线性代数运算。我们希望通过使用 TPU 和 GPU 等专用加速器,使矩阵乘法尽可能快,这两种加速器都具有用于快速矩阵乘法的专用单元。

为了有效利用 TPU 进行矩阵乘法,我们需要涵盖几个背景概念:分块矩阵乘法、分块和平行流水线。

分块矩阵乘法#

假设我们想实现一个 matmul(x, y) 函数,它通常将一个 (m, k) 数组与一个 (k, n) 数组相乘,但有一个限制。我们只允许使用原语 matmul_small,它只能乘以小矩阵(例如 m, k, n <= 256)。我们该怎么做呢?

矩阵乘法的一个很好的性质是,输出的每个块都可以表示为输入行块和列块的几个较小矩阵乘法之和。形式上,如果我们有输入数组 \(x \in \mathbb{R}^{m \times k}\)\(y \in \mathbb{R}^{k \times n}\) 以及输出 \(z \in \mathbb{R}^{m \times n}\),我们沿维度将它们分解为大小为 \(b_m, b_k, b_n\) 的块。

例如,\(x\) 将被分解为

\[\begin{split} \begin{bmatrix} x_{0, 0} & \cdots & x_{0, i_k} \\ x_{1, 0} & \cdots & x_{1, i_k} \\ \vdots & \ddots & \vdots \\ x_{i_m, 0} & \cdots & x_{i_m, i_k} \\ \end{bmatrix} \end{split}\]

其中 \(x_{ik} \in \mathbb{R}^{b_m \times b_k}\)。(我们可以类似地分解 \(y\)\(z\)。)

对于特定的输出块 \(z_{ij}\),我们可以将其计算为

\[ z_{ij} = \sum_k x_{ik} y_{kj} \]

因此,每个输出块 \(z_{ij}\) 是几个较小的分块矩阵乘法 \(x_{ik} y_{kj}\) 的和。以下是我们在 NumPy 中实现此算法的方式

def matmul_small(x: np.ndarray, y: np.ndarray) -> np.ndarray:
  m, k, n = x.shape[0], x.shape[1], y.shape[0]
  assert m <= 256
  assert k <= 256
  assert n <= 256
  return np.matmul(x, y)

def block_matmul(
    x: np.ndarray,
    y: np.ndarray,
    *,
    bm: int = 256,
    bk: int = 256,
    bn: int = 256,
) -> np.ndarray:
  m, k = x.shape
  _, n = y.shape

  z = np.zeros((m, n), dtype=x.dtype)
  for m_i in range(m // bm):
    for n_i in range(n // bn):
      for k_i in range(k // bk):
        m_slice = slice(m_i * bm, (m_i + 1) * bm)
        k_slice = slice(k_i * bk, (k_i + 1) * bk)
        n_slice = slice(n_i * bn, (n_i + 1) * bn)
        x_block = x[m_slice, k_slice]
        y_block = y[k_slice, n_slice]
        z[m_slice, n_slice] += matmul_small(x_block, y_block)
  return z

我们的 block_matmul 函数现在应该可以处理大于 256 的输入(尽管我们假设我们的输入维度可以被 256 整除)。

m, k, n = 4096, 4096, 4096
x = np.random.uniform(size=(m, k)).astype(np.float32)
y = np.random.uniform(size=(k, n)).astype(np.float32)
np.testing.assert_allclose(x @ y, block_matmul(x, y), atol=1e-6, rtol=1e-6)

block_matmul 将矩阵乘法分解为许多较小的乘法,它观察到每个大小为 (bm, bn) 的输出块可以通过累加几个 (bm, bk) x (bk, bn) 大小的矩阵乘法来计算。

TPU 和 GPU 就是这样进行矩阵乘法的!它们原生支持类似于 matmul_small 的小矩阵乘法,因此为了在进行更大的矩阵乘法时利用此硬件,我们将应用 block_matmul 分解。

分块与流水线#

上一个指南中,我们介绍了 Pallas 中的计算分块和流水线如何工作。为了确保我们的计算单元始终工作并且不会被内存传输阻塞,我们将下一个内核迭代的内存传输与当前迭代的内存传输重叠。

在 Pallas 中,我们通过 BlockSpecgrid 来指定这一点。请注意,我们已经在分块矩阵乘法算法中有一个嵌套的 for 循环。我们可以在 Pallas 中通过 grid 指定它。分块矩阵乘法中的切片也可以通过 BlockSpec 指定。

你的第一个矩阵乘法核#

总而言之,以下是一个分块矩阵乘法核的实现,它将内存传输与计算并行。我们创建了一个 3 维网格,对应于 NumPy 代码中的 3 个嵌套循环。请注意,虽然 MXU 只能乘以小块,但 Pallas 会自动获取更大的块并自动将它们分块到 MXU 上。

网格的最后一个维度对应于矩阵乘法的收缩维度,是一个归约维度,因此我们需要确保初始化累加器。

def matmul_kernel(x_ref, y_ref, z_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    z_ref[...] = jnp.zeros_like(z_ref)

  z_ref[...] += x_ref[...] @ y_ref[...]

def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      in_specs=[pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
                pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
      out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
      grid=(m // bm, n // bn, k // bk),
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.float32)
y = random.normal(k2, (k, n), dtype=jnp.float32)
np.testing.assert_array_equal(x @ y, matmul(x, y))

矩阵乘法性能#

让我们思考如何分析矩阵乘法性能。当我们考虑矩阵乘法性能时,我们通常关心两件事:浮点运算 (FLOPs) 的总数和内存带宽使用量。从TPU 和流水线指南中,我们看到为了使用 TPU(以及一般 ML 加速器)上高效的计算单元,我们需要将输入从 HBM 复制到 VMEM,更接近计算单元。这种往返 HBM 的复制需要时间,高效的内核应该将其大部分时间用于实际计算,而不是等待这些传输。内存带宽衡量数据传输速率。

快速说明:在本指南中,我们将讨论浮点运算,但要区分 FLOPs 和 FLOP/s。当我们说“FLOPs”时,我们指的是“浮点运算”,即操作的数量。当我们说“FLOP/s”时,我们指的是“每秒浮点运算”,即执行浮点运算的速率

(m, k) x (k, n) 矩阵乘法中的 FLOPs 数量(大约)为 2 * m * k * n。(严格来说是 n * m * (2k - 1),但对于足够大的 k,我们的近似值就足够了。)

矩阵乘法(假设 float32)的最小内存带宽使用量是输入总大小(复制到 VMEM)加上输出大小(复制到 HBM)。因此,最小带宽使用量为 (m * k + k * n + m * n) * 4 bytes/float32。如果多次重新读取输入,内存使用量可能会更大,这通常是情况。

一个观察是,矩阵乘法的 FLOPs 数量与其输入的三次方成正比,而最小带宽使用量与其输入的二次方成正比。直观地说,这意味着 FLOPs 比带宽使用量增长得更快,这意味着我们的矩阵乘法越大,我们相对于复制的计算量就越多。

def matmul_flops(m: int, k: int, n: int):
  return 2 * m * k * n

def matmul_membw(m: int, k: int, n: int, dtype: jnp.dtype):
  return (m * k + k * n + m * n) * np.dtype(dtype).itemsize

print(matmul_flops(1024, 1024, 1024))
print(matmul_membw(1024, 1024, 1024, jnp.float32))
2147483648
12582912

现在我们已经可以计算矩阵乘法的 FLOPs 总数和(最小)内存带宽使用量,让我们看看真实的 TPU 可以处理什么。

此笔记本是在 TPU v5e 芯片上运行的,因此我们将使用 v5e 的数字(如果您正在运行此笔记本,您的数字可能有所不同)。TPU v5e 具有 197 TFLOP/s 的 bf16/f32 计算能力和 819 GB/s 的内存带宽。通过查看这些数字的比率(称为算术强度),我们可以得到“FLOPs / 内存带宽使用量”比率在变得 IO 绑定之前可以达到的下限(TPU v5e 上约为 240 FLOPs/byte)。

v5e_flops = 197e12
v5e_membw = 819e9
v5e_op_intensity = v5e_flops / v5e_membw  # ~240.5

大致来说,这些数字告诉我们矩阵乘法的 FLOPs 应该花费 2 * m * k * n / (197 TFLOP/s) 秒,往返 VMEM 的复制应该花费 (m*k + k*n + m*n) * 4 bytes / 819GB/s 秒。

def matmul_flops_intensity(m: int, k: int, n: int, dtype: jnp.dtype):
  flops = matmul_flops(m, k, n)
  membw = matmul_membw(m, k, n, dtype)
  return flops / membw

这种基本计算大致告诉我们能够多有效地使用 MXU。如果我们的矩阵乘法运算强度低于芯片所能承受的,那么我们的计算将是内存受限的,即我们的计算单元将在等待值传输时处于空闲状态。如果矩阵乘法强度高于芯片所能承受的,那么我们将是计算受限的。

由于矩阵乘法的 FLOPs 随输入大小的立方增长,而内存带宽使用量随输入大小的平方增长,我们预计随着矩阵乘法变得越来越大,我们将变得计算受限,但这个交叉点非常重要!假设我们正在进行 (1024, 1024) x (1024, 1024) 的 float32 矩阵乘法。

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.float32)} flops/byte")
170.66666666666666 flops/byte

我们的矩阵乘法 FLOPs 强度低于芯片所能承受的。这不好!我们很可能在这种类型的矩阵乘法中受到内存限制。但是,如果我们的输入和输出更大呢?在某个点,当我们的矩阵乘法足够大时,我们将从内存限制转变为计算限制。例如,如果我们有一个矩阵乘法,其中 m = k = n,我们将在 2m**3 / 12m**2 > 240m = k = n > 1440 时(在 TPU v5e 上)发生交叉。

bfloat16 矩阵乘法#

为了更容易地在 TPU 上使矩阵乘法受计算限制,我们还可以为输入和输出使用较小的数据类型。我们之前的示例使用了 float32 输入和输出,但 TPU v5e 也支持 bfloat16 数据类型(一种 16 位浮点格式,也称为 bf16)进行矩阵乘法。在 TPU v5e 上,我们将具有相同的 FLOP/s,但将内存带宽使用量减半。这使得对于较小的矩阵来说,受计算限制变得更容易。让我们看看 1024 x 1024 x 1024 bf16 矩阵乘法的强度是多少

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.bfloat16)} flops/byte")
341.3333333333333 flops/byte

我们现在有一个受计算限制的矩阵乘法!

让我们为我们的矩阵乘法核添加 bf16 支持。

原生 MXU bf16 矩阵乘法例程接受两个 bf16 输入矩阵并在 f32 中累积。我们将通过将 preferred_element_type=jnp.float32 传递给 jnp.matmul 来触发此例程。我们还需要一个 f32 类型的累加器 Ref。然后,我们将在将输出写回 HBM 之前将其向下转换为 bf16。这样,我们不会丢失任何精度,不会进行任何额外的类型转换,并且仍然保留 bf16 的内存带宽节省。

请注意,目前分配暂存空间的唯一方法是使用 pltpu.PrefetchScalarGridSpec。现在不用担心它具体做了什么——您目前只需要知道它允许您在 VMEM 中分配暂存空间。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  acc_ref[...] += jnp.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.bfloat16)
y = random.normal(k2, (k, n), dtype=jnp.bfloat16)
np.testing.assert_array_equal(x @ y, matmul(x, y))

流水线核的性能#

我们上面关于 FLOPs 与内存使用量的分析适用于粗粒度,即当我们查看整个矩阵乘法的大小时。但是,请记住,在实践中,我们正在流水线化分块矩阵乘法的执行,这意味着我们有一个循环,在其中我们正在进行较小块的矩阵乘法。

这意味着我们实际上关心的是内核的每个单独实例的 FLOPs 与内存带宽使用量,而不是全局的 FLOPs 与内存带宽使用量。

此外,在分块矩阵乘法运算时,相同的值可能会从内存中多次读取。具体来说,内核第一个操作数的内存带宽是 (bm * bk),这需要乘以网格维度,即 (bm * bk) * m // bm * n // bn * k // bk = m * k * n // bn。第二个操作数也是如此,总带宽使用量为 (m * k * n // bn + k * n * m // bm + m * n) * element_size

因此,块大小 bmbkbn 对于性能至关重要。即使我们拥有世界上最大的矩阵,如果我们选择非常小的 bmbkbn,我们也会受到内存限制,因为每次调用内核时,我们都会有太少的 FLOPs 来隐藏后台发生的内存传输。

因此,直观的理解应该是:为了受计算限制,请尽可能增大块!主要有两个限制

  1. VMEM 使用:块越大,我们使用的 VMEM 就越多。块足够大时,我们将耗尽 VMEM。

  2. 流水线气泡:块相对于矩阵大小越大,我们的流水线中的循环迭代次数就越少。这将使流水线开始和结束时的气泡大小相对于整个流水线更大,并且这种开销可能不可忽略。

在 Pallas 中获得良好的矩阵乘法性能归结为选择合适的块大小来平衡此优化问题。在实践中,我们通常会遍历大量的候选块大小,分析内核,然后选择最佳的一个。

现在,让我们做一些非常简单的计时实验。我们将使用 timeit 来测量每个内核运行所需的时间。请注意,这是内核实际运行时间的上限,因为我们使用 timeit 测量了 Python 分派和其他开销。我们将计算我们以这种方式获得的 FLOP/s 量,并计算我们相对于芯片所提供性能的利用率百分比,我们将使用一些合理的块大小来验证我们的直觉。

import timeit

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # Compile function first
    jax.block_until_ready(f(*args, **kwargs))
    # Time function
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    # print(f"Time: {time}")
    return time
  return run

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func):
  x = jnp.ones((m, k), dtype=dtype)
  y = jnp.ones((k, n), dtype=dtype)
  time = benchmark(mm_func)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00029766598949208854
Matmul FLOP/s:  7214407167121.377
FLOP/s utilization: 3.6621%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011771515250438824
Matmul FLOP/s:  11675553278230.387
FLOP/s utilization: 5.9267%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09183577066054567
Matmul FLOP/s:  11972585626140.668
FLOP/s utilization: 6.0775%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012708659982308746
Matmul FLOP/s:  16897797651282.135
FLOP/s utilization: 8.5776%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.00088908776990138
Matmul FLOP/s:  154584235803001.88
FLOP/s utilization: 78.4692%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006099433819763363
Matmul FLOP/s:  180264539343531.62
FLOP/s utilization: 91.5048%

更大的块大小帮助很大!在更大的矩阵乘法中,我们获得了相当不错的利用率(80-90%),但最小的矩阵乘法似乎很难获得良好的性能。

让我们将其与 XLA 的矩阵乘法进行比较。我们不期望 Pallas 比 XLA 更好,因为 XLA 在生成矩阵乘法方面非常出色,但希望我们能接近。通过更仔细的块大小调整(留作未来工作),我们也可以达到 XLA 的性能。

print("================ XLA matmul ===================")
mm = jnp.matmul
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================ XLA matmul ===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00011943008983507753
Matmul FLOP/s:  17981093801113.996
FLOP/s utilization: 9.1275%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008272899803705514
Matmul FLOP/s:  166131533963991.34
FLOP/s utilization: 84.3307%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006047147869830951
Matmul FLOP/s:  181823175395037.44
FLOP/s utilization: 92.2960%

Pallas,经过一些非常基本的调整,已经非常接近 XLA 的性能数据了!通过尝试更多的块大小,我们应该能完全弥补差距。

矩阵乘法的模板化#

现在我们有了一个基本的矩阵乘法核,我们可以尝试将操作融合到其中。

融合右侧转置#

一个常见的第一个操作是融合转置。这是什么意思呢?假设我们想计算 x @ y.T 而不是 x @ y。天真地,我们可以首先计算 y.T,然后将其传递到我们的高效矩阵乘法核中。然而,操作 y.T 本身并非免费——它涉及复制 O(n^2) 数据。理想情况下,我们可以在进行矩阵乘法同时计算转置,只在一个内核中完成,即将其与矩阵乘法“融合”。

加速器通常支持融合右侧转置的原生矩阵乘法例程。例如,TPU v5e 的 MXU 允许我们对小数组执行 x @ y.T。我们可以使用 jax.lax.dot_general 调用此例程,这将比单独进行转置和矩阵乘法更高效。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  # dot_general expects a data structure (contraction_dims, batch_dims),
  # where contraction_dims are the set of dimensions for LHS and RHS that will
  # be contracted (reduced) in the matmul; batch_dims, on the other hand, are
  # looped over. The remaining dimensions will be the input and output dimension
  # of the matmul.
  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...], y_ref[...], dims, preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'transpose_rhs'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk, transpose_rhs=transpose_rhs),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            y_block_spec,
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)

我们在 matmul 函数内部进行转置(y = y.swapaxes(0, 1))。这是因为在 JIT 编译的 JAX 计算中,维度顺序纯粹是逻辑上的,而不是物理上的,因此重新排列维度并不意味着物理布局的改变。但是,当我们将数组传递给 pallas_call 时,我们确实强制执行主要到次要的维度顺序约束。通过在 matmul 函数内部转置 y,我们要求 y 采用转置布局 (n, k) 而不是通常的 (k, n)。然而,用户仍将以(逻辑)(k, n) 维度传递数组。

注意:为了对转置进行基准测试,我们实际上希望 y 在将其传递给内核时处于物理转置布局中,这样我们就不会测量重新布局时间。在包装函数中,我们将在将其传递给 matmul 之前(逻辑上)将其转置回 (k, n),因为 matmul 期望逻辑上的 (k, n) 维度顺序。

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = mm_func
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.0003029372810851783
Matmul FLOP/s:  7088872126624.065
FLOP/s utilization: 3.5984%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.012017967159627005
Matmul FLOP/s:  11436123235026.848
FLOP/s utilization: 5.8051%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09500920018996112
Matmul FLOP/s:  11572685861765.383
FLOP/s utilization: 5.8745%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012131539988331496
Matmul FLOP/s:  17701657415839.363
FLOP/s utilization: 8.9856%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008790623804088682
Matmul FLOP/s:  156347213275211.03
FLOP/s utilization: 79.3641%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006107717020204291
Matmul FLOP/s:  180020067095253.78
FLOP/s utilization: 91.3807%

看,尽管进行了额外的转置,我们仍然获得了相同的利用率!

融合激活函数#

融合激活函数也非常常见。这确保我们不会在高效的计算限制型矩阵乘法核之后紧接着一个慢速的内存限制型激活核。

def matmul_kernel(
    x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs, activation
):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...],
      y_ref[...],
      dims,
      preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = activation(acc_ref[...]).astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'activation'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
    activation: Callable[[jax.Array], jax.Array] = lambda x: x,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(
          matmul_kernel,
          nsteps=k // bk,
          transpose_rhs=transpose_rhs,
          activation=activation,
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          in_specs=[
              pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
              y_block_spec,
          ],
          out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
          scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
          grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False,
                   activation = lambda x: x):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True, activation=activation)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = functools.partial(mm_func, activation=activation)
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()


activation = jax.nn.relu
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00030103540048003196
Matmul FLOP/s:  7133658182976.541
FLOP/s utilization: 3.6211%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011807117109419778
Matmul FLOP/s:  11640348122095.826
FLOP/s utilization: 5.9088%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09181861146935262
Matmul FLOP/s:  11974823079773.941
FLOP/s utilization: 6.0786%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012622540001757442
Matmul FLOP/s:  17013086492108.6
FLOP/s utilization: 8.6361%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.000896632740041241
Matmul FLOP/s:  153283442968721.44
FLOP/s utilization: 77.8089%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006130605939542875
Matmul FLOP/s:  179347953304919.88
FLOP/s utilization: 91.0396%

额外融合的激活几乎不影响我们的利用率!

结论#

在本指南中,我们介绍了如何使用 Pallas 在 TPU 上编写高效的矩阵乘法。我们讨论了分块矩阵乘法和流水线技术,如何分析 TPU 矩阵乘法的性能,以及如何编写高效的 bf16 矩阵乘法。最后,我们通过模板化矩阵乘法来支持融合转置和融合激活函数。

留给读者的练习

  • 添加对输入融合的支持。有时我们希望将一个操作融合到矩阵乘法的输入中。尝试进一步模板化矩阵乘法以支持此功能。

  • 添加对 int8 矩阵乘法的支持。TPU v5 支持原生 int8 矩阵乘法,其 FLOPs 是 bf16 的两倍。尝试添加支持并查看可能的利用率。

  • 添加 matmul 函数的反向传播支持。您可以使用 jax.custom_vjp 来实现。