矩阵乘法#
在本指南中,我们将使用 Pallas 编写一个矩阵乘法例程。我们还将讨论如何在 TPU 上思考矩阵乘法性能,以及如何模板化矩阵乘法内核以融合操作。
#@title Imports
import functools
from typing import Callable
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax import random
import jax.numpy as jnp
import numpy as np
背景#
矩阵乘法是现代深度学习和语言建模的核心基本线性代数运算。我们希望使用 TPU 和 GPU 等专用加速器尽可能快速地执行矩阵乘法,它们都具有用于快速矩阵乘法的专用单元。
为了有效地利用 TPU 进行矩阵乘法,我们需要了解一些背景概念:块矩阵乘法、分块和流水线。
块矩阵乘法#
假设我们想实现 matmul(x, y),它通用地将一个 (m, k) 的数组与一个 (k, n) 的数组相乘,但有一个变通。我们只允许使用原始函数 matmul_small,它将小矩阵相乘(例如 m, k, n <= 256)。我们该如何做到呢?
矩阵乘法的一个良好特性是,输出的每个块都可以表示为输入行块和列块的若干个小矩阵乘法的总和。形式上,如果我们有输入数组 \(x \in \mathbb{R}^{m \times k}\) 和 \(y \in \mathbb{R}^{k \times n}\) 以及输出 \(z \in \mathbb{R}^{m \times n}\),我们沿着大小为 \(b_m, b_k, b_n\) 的维度将它们分解成块。
例如,\(x\) 将被分解为
其中 \(x_{ik} \in \mathbb{R}^{b_m \times b_k}\)。(我们也可以类似地分解 \(y\) 和 \(z\)。)
对于一个特定的输出块 \(z_{ij}\),我们可以计算它为
因此,每个输出块 \(z_{ij}\) 是若干个小块矩阵乘法 \(x_{ik} y_{kj}\) 的总和。以下是我们如何在 NumPy 中实现此算法
def matmul_small(x: np.ndarray, y: np.ndarray) -> np.ndarray:
m, k, n = x.shape[0], x.shape[1], y.shape[0]
assert m <= 256
assert k <= 256
assert n <= 256
return np.matmul(x, y)
def block_matmul(
x: np.ndarray,
y: np.ndarray,
*,
bm: int = 256,
bk: int = 256,
bn: int = 256,
) -> np.ndarray:
m, k = x.shape
_, n = y.shape
z = np.zeros((m, n), dtype=x.dtype)
for m_i in range(m // bm):
for n_i in range(n // bn):
for k_i in range(k // bk):
m_slice = slice(m_i * bm, (m_i + 1) * bm)
k_slice = slice(k_i * bk, (k_i + 1) * bk)
n_slice = slice(n_i * bn, (n_i + 1) * bn)
x_block = x[m_slice, k_slice]
y_block = y[k_slice, n_slice]
z[m_slice, n_slice] += matmul_small(x_block, y_block)
return z
我们的 block_matmul 函数现在应该可以处理大于 256 的输入(尽管我们假设输入维度能被 256 整除)。
m, k, n = 4096, 4096, 4096
x = np.random.uniform(size=(m, k)).astype(np.float32)
y = np.random.uniform(size=(k, n)).astype(np.float32)
np.testing.assert_allclose(x @ y, block_matmul(x, y), atol=1e-6, rtol=1e-6)
block_matmul 通过观察每个大小为 (bm, bn) 的输出块可以通过累积几个大小为 (bm, bk) x (bk, bn) 的矩阵乘法来计算,从而将矩阵乘法分解为许多小的矩阵乘法。
TPU 和 GPU 就像这样执行矩阵乘法!它们原生支持类似于 matmul_small 的小矩阵乘法,因此为了在执行更大的矩阵乘法时利用这些硬件,我们将应用 block_matmul 分解。
分块和流水线#
在 上一篇指南 中,我们介绍了 Pallas 中计算的分块和流水线如何工作。为了确保我们的计算单元始终在工作,并且不会因内存传输而停顿,我们将下一轮内核的内存传输与当前一轮的传输重叠。
在 Pallas 中,我们通过 BlockSpec 和 grid 来指定这一点。请注意,在块矩阵乘法算法中,我们已经有一个嵌套的 for 循环。我们可以通过 grid 在 Pallas 中指定它。块矩阵乘法中的切片也可以通过 BlockSpec 指定。
你的第一个矩阵乘法内核#
将所有内容整合在一起,这里是一个块矩阵乘法内核的实现,它将内存传输与计算流水线化。我们创建了一个 3 维网格,对应于 NumPy 代码中的 3 个嵌套循环。请注意,虽然 MXU 只能执行小块的乘法,但 Pallas 会自动处理更大的块并自动将它们分块到 MXU 上。
网格的最后一个维度对应于矩阵乘法的收缩维度,并且是一个归约维度,因此我们需要确保初始化累加器。
def matmul_kernel(x_ref, y_ref, z_ref):
@pl.when(pl.program_id(2) == 0)
def _():
z_ref[...] = jnp.zeros_like(z_ref)
z_ref[...] += x_ref[...] @ y_ref[...]
def matmul(
x: jax.Array,
y: jax.Array,
*,
bm: int = 128,
bk: int = 128,
bn: int = 128,
):
m, k = x.shape
_, n = y.shape
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
in_specs=[pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
grid=(m // bm, n // bn, k // bk),
compiler_params=pltpu.CompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.float32)
y = random.normal(k2, (k, n), dtype=jnp.float32)
np.testing.assert_array_equal(x @ y, matmul(x, y))
矩阵乘法性能#
让我们来分析一下矩阵乘法的性能。当我们考虑矩阵乘法的性能时,通常关心两件事:浮点运算 (FLOPs) 的总数和内存带宽的使用量。从关于 TPU 和流水线 的指南中,我们看到为了使用 TPU(以及通用 ML 加速器)上的高效计算单元,我们需要将输入从 HBM 复制到 VMEM,更接近计算单元。这种 HBM 的复制需要时间,一个高效的内核应该花费大部分时间在计算上,而不是等待这些传输。内存带宽衡量数据传输的速率。
快速说明:在本指南中,我们将讨论浮点运算,但想区分 FLOPs 与 FLOP/s。当我们说“FLOPs”时,我们是指“浮点运算”,即运算次数。当我们说“FLOP/s”时,我们是指“每秒浮点运算次数”,即执行浮点运算的速率。
(m, k) x (k, n) 矩阵乘法的 FLOPs(大约)是 2 * m * k * n。(技术上是 n * m * (2k - 1),但对于足够大的 k,我们的近似是足够的。)
矩阵乘法的最小内存带宽使用量(假设 float32)是输入的总大小(复制到 VMEM)加上输出的大小(复制到 HBM)。因此,最小带宽使用量是 (m * k + k * n + m * n) * 4 bytes/float32。如果多次重新读取输入,内存使用量可能会更大,这通常是这种情况。
一个观察是,矩阵乘法的 FLOPs 随其输入呈三次增长,而最小带宽使用量随其输入呈二次增长。直观地说,这意味着 FLOPs 比带宽增长得更快,也就是说,我们的矩阵乘法越大,相对于复制的计算就越多。
def matmul_flops(m: int, k: int, n: int):
return 2 * m * k * n
def matmul_membw(m: int, k: int, n: int, dtype: jnp.dtype):
return (m * k + k * n + m * n) * np.dtype(dtype).itemsize
print(matmul_flops(1024, 1024, 1024))
print(matmul_membw(1024, 1024, 1024, jnp.float32))
2147483648
12582912
现在我们已经能够计算矩阵乘法的总 FLOPs 和(最小)内存带宽使用量,让我们看看实际的 TPU 能处理什么。
此笔记本运行在 TPU v5e 芯片上,因此我们将使用 v5e 的数字(如果您运行此笔记本,您的数字可能会有所不同)。TPU v5e 拥有 197 TFLOP/s 的 bf16/f32 计算能力和 819 GB/s 的内存带宽。通过查看这些数字的比率(称为算术强度),我们可以得到“FLOPs / 内存带宽使用量”比率在变为 IO 限制之前可以达到的下限(TPU v5e 上约为 240 FLOPs/byte)。
v5e_flops = 197e12
v5e_membw = 819e9
v5e_op_intensity = v5e_flops / v5e_membw # ~240.5
粗略地说,这些数字告诉我们,一个矩阵乘法的 FLOPs 应该花费 2 * m * k * n / (197 TFLOP/s) 秒,而到 VMEM 的复制应该花费 (m*k + k*n + m*n) * 4 bytes / 819GB/s 秒。
def matmul_flops_intensity(m: int, k: int, n: int, dtype: jnp.dtype):
flops = matmul_flops(m, k, n)
membw = matmul_membw(m, k, n, dtype)
return flops / membw
这个基本计算告诉我们,我们将能够多有效地使用我们的 MXU。如果我们的矩阵乘法运算强度低于芯片的能力,那么我们的计算将是 *内存受限* 的,即我们的计算单元在等待值传输时将处于空闲状态。如果矩阵乘法强度高于芯片的能力,那么我们将是 *计算受限* 的。
由于矩阵乘法的 FLOPs 随输入尺寸呈三次增长,而内存带宽使用量呈二次增长,我们预计随着尺寸越来越大,我们将变得计算受限,但这个交叉点非常重要!假设我们正在执行一个 (1024, 1024) x (1024, 1024) 的 float32 矩阵乘法。
print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.float32)} flops/byte")
170.66666666666666 flops/byte
我们的矩阵乘法 FLOPs 强度低于芯片的能力。这不好!这种类型的矩阵乘法很可能受内存限制。但是,如果我们的输入和输出更大呢?在某个时候,当我们的矩阵乘法足够大时,我们将从内存受限转变为计算受限。例如,如果我们有一个矩阵乘法,其中 m = k = n,当 2m**3 / 12m**2 > 240 或当 m = k = n > 1440 时(在 TPU v5e 上),我们将发生交叉。
bfloat16 矩阵乘法#
为了更容易使矩阵乘法在 TPU 上处于计算受限状态,我们也可以为输入和输出使用更小的数据类型。我们之前的例子使用了 float32 输入和输出,但 TPU v5e 也支持 bfloat16 数据类型(16 位浮点格式,也称为 bf16)用于矩阵乘法。在 TPU v5e 上,我们将获得相同的 FLOP/s,但将 *内存带宽使用量减半*。这使得对较小的矩阵更容易达到计算受限状态。让我们看看 1024 x 1024 x 1024 bf16 矩阵乘法的强度是多少
print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.bfloat16)} flops/byte")
341.3333333333333 flops/byte
我们现在有一个计算受限的矩阵乘法了!
让我们为矩阵乘法内核添加 bf16 支持。
原生的 MXU bf16 矩阵乘法例程接受两个 bf16 矩阵作为输入,并在 f32 中累加它们。我们将通过将 preferred_element_type=jnp.float32 传递给 jnp.matmul 来触发此例程。我们还需要一个 f32 类型的累加器 Ref。然后,在将输出写回 HBM 之前,我们将将其向下转换为 bf16。这样,我们既不会丢失精度,也不会进行额外的类型转换,并且仍然保留 bf16 的内存带宽节省。
请注意,目前分配临时空间的唯一方法是通过
pltpu.PrefetchScalarGridSpec。暂时不必担心它确切的功能——现在您只需要知道它允许您在 VMEM 中分配临时空间。
def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
@pl.when(pl.program_id(2) == 0)
def _():
acc_ref[...] = jnp.zeros_like(acc_ref)
acc_ref[...] += jnp.dot(
x_ref[...], y_ref[...], preferred_element_type=jnp.float32
)
@pl.when(pl.program_id(2) == nsteps - 1)
def _():
z_ref[...] = acc_ref[...].astype(z_ref.dtype)
@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul(
x: jax.Array,
y: jax.Array,
*,
bm: int = 128,
bk: int = 128,
bn: int = 128,
):
m, k = x.shape
_, n = y.shape
return pl.pallas_call(
functools.partial(matmul_kernel, nsteps=k // bk),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
],
out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
compiler_params=pltpu.CompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.bfloat16)
y = random.normal(k2, (k, n), dtype=jnp.bfloat16)
np.testing.assert_array_equal(x @ y, matmul(x, y))
流水线化内核的性能#
我们上面的关于 FLOPs 与内存使用的分析适用于粗粒度级别,即当我们查看整个矩阵乘法的大小时。然而,请记住,实际上,我们正在流水线化块矩阵乘法的执行,这意味着我们有一个循环,在这个循环中我们使用更小的块进行矩阵乘法。
这意味着我们实际上关心的是每个内核实例的 FLOPs 与内存带宽的使用量,而不是全局的 FLOPs 与内存带宽的使用量。
此外,在分块矩阵乘法运算时,相同的数值可能会被多次从内存中读取。具体来说,内核第一个操作数的内存带宽是 (bm * bk),它需要乘以网格的维度,即 (bm * bk) * m // bm * n // bn * k // bk = m * k * n // bn。第二个操作数类似,总带宽使用量为 (m * k * n // bn + k * n * m // bm + m * n) * element_size。
因此,块大小 bm、bk、bn 对于性能至关重要。即使我们拥有世界上最大的矩阵,如果我们选择非常小的 bm、bk 和 bn,我们将受到内存限制,因为每次调用内核时,我们拥有的 FLOPs 不足以隐藏后台进行的内存传输。
因此,直觉应该是:为了达到计算受限,使块尽可能大!有两个主要限制
VMEM 使用量:块越大,VMEM 使用量越多。当块足够大时,我们会用完。
流水线气泡:块相对于矩阵大小越大,流水线中的迭代次数就越少。这将导致流水线开始和结束时的气泡大小相对于整个流水线更大,并且这种开销可能并非微不足道。
在 Pallas 中获得良好的矩阵乘法性能归结为选择好的块大小来平衡这个优化问题。实际上,我们经常在大量候选块大小上进行扫描,分析内核性能,然后选择最佳的。
现在,让我们进行一些非常简单的计时实验。我们将使用 timeit 来测量运行每个内核所需的时间。请注意,这是内核实际运行时间的上限,因为我们使用 timeit 来测量 Python 分派和其他开销。我们将计算由此获得的 FLOP/s 量,并计算与芯片提供的性能相比的利用率百分比,我们将使用一些合理的块大小来验证我们的直觉。
import timeit
def benchmark(f, ntrials: int = 100):
def run(*args, **kwargs):
# Compile function first
jax.block_until_ready(f(*args, **kwargs))
# Time function
result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
number=ntrials)
time = result / ntrials
# print(f"Time: {time}")
return time
return run
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
mm_func):
x = jnp.ones((m, k), dtype=dtype)
y = jnp.ones((k, n), dtype=dtype)
time = benchmark(mm_func)(x, y)
print(f"----- {m} x {k} x {n} -----")
print("Matmul time: ", time)
mm_flops = matmul_flops(m, k, n) / time
print("Matmul FLOP/s: ", mm_flops)
print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
print()
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00029766598949208854
Matmul FLOP/s: 7214407167121.377
FLOP/s utilization: 3.6621%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.011771515250438824
Matmul FLOP/s: 11675553278230.387
FLOP/s utilization: 5.9267%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.09183577066054567
Matmul FLOP/s: 11972585626140.668
FLOP/s utilization: 6.0775%
================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00012708659982308746
Matmul FLOP/s: 16897797651282.135
FLOP/s utilization: 8.5776%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.00088908776990138
Matmul FLOP/s: 154584235803001.88
FLOP/s utilization: 78.4692%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.006099433819763363
Matmul FLOP/s: 180264539343531.62
FLOP/s utilization: 91.5048%
更大的块大小非常有帮助!在较大的矩阵乘法中,我们获得了相当不错的利用率(80-90%),但最小的矩阵乘法似乎很难获得良好的性能。
让我们将其与 XLA 的矩阵乘法进行比较。我们不期望 Pallas 比 XLA 更好,因为 XLA 非常擅长生成矩阵乘法,但希望我们能接近。通过更仔细的块大小调整(留作未来工作),我们也可以达到 XLA 的性能。
print("================ XLA matmul ===================")
mm = jnp.matmul
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================ XLA matmul ===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00011943008983507753
Matmul FLOP/s: 17981093801113.996
FLOP/s utilization: 9.1275%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.0008272899803705514
Matmul FLOP/s: 166131533963991.34
FLOP/s utilization: 84.3307%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.006047147869830951
Matmul FLOP/s: 181823175395037.44
FLOP/s utilization: 92.2960%
Pallas 经过一些非常基本的调整,性能非常接近 XLA 的性能!通过尝试更多的块大小,我们应该能够完全缩小差距。
矩阵乘法的模板化#
现在我们有了一个基本的矩阵乘法内核,我们可以尝试将操作融合到其中。
右侧转置融合#
一个常见的第一个操作是融合转置。这是什么意思?假设我们要计算 x @ y.T 而不是 x @ y。我们可以先计算 y.T,然后将其传递给我们的高效矩阵乘法内核。然而,y.T 操作本身并非免费的——它涉及复制 O(n^2) 的数据。理想情况下,我们可以在进行矩阵乘法的同时,在单个内核中计算转置,即将其与矩阵乘法“融合”。
加速器通常支持原生的矩阵乘法例程,这些例程可以融合右侧转置。例如,TPU v5e 的 MXU 允许我们为小数组执行 x @ y.T。我们可以使用 jax.lax.dot_general 调用此例程,这比分别进行转置然后矩阵乘法更有效。
def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs):
@pl.when(pl.program_id(2) == 0)
def _():
acc_ref[...] = jnp.zeros_like(acc_ref)
# dot_general expects a data structure (contraction_dims, batch_dims),
# where contraction_dims are the set of dimensions for LHS and RHS that will
# be contracted (reduced) in the matmul; batch_dims, on the other hand, are
# looped over. The remaining dimensions will be the input and output dimension
# of the matmul.
if transpose_rhs:
dims = ((1,), (1,)), ((), ())
else:
dims = ((1,), (0,)), ((), ())
acc_ref[...] += jax.lax.dot_general(
x_ref[...], y_ref[...], dims, preferred_element_type=jnp.float32,
)
@pl.when(pl.program_id(2) == nsteps - 1)
def _():
z_ref[...] = acc_ref[...].astype(z_ref.dtype)
@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'transpose_rhs'])
def matmul(
x: jax.Array,
y: jax.Array,
*,
bm: int = 128,
bk: int = 128,
bn: int = 128,
transpose_rhs: bool = False,
):
if transpose_rhs:
y = y.swapaxes(0, 1)
y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
else:
y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
m, k = x.shape
_, n = y.shape
return pl.pallas_call(
functools.partial(matmul_kernel, nsteps=k // bk, transpose_rhs=transpose_rhs),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
y_block_spec,
],
out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
compiler_params=pltpu.CompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
我们在 matmul 函数内部进行转置(y = y.swapaxes(0, 1))。这是因为在 JIT 编译的 JAX 计算内部,维度顺序纯粹是 *逻辑上* 的,而不是物理上的,因此重新排列维度并不意味着物理布局的差异。然而,当我们传递一个数组到一个 pallas_call 时,我们会强制执行从主到次的维度排序约束。通过在 matmul 函数内部转置 y,我们要求 y 采用转置布局 (n, k) 而不是通常的 (k, n)。用户仍然会以(逻辑上的)(k, n) 维度传入数组。
注意:为了基准测试转置,我们实际上希望在将 y 传递给内核时,它处于物理转置布局,这样我们就不会衡量重排布局的时间。在包装函数中,我们将在(逻辑上)将其转置回 (k, n),然后再传递给 matmul,因为 matmul 期望逻辑上的 (k, n) 维度排序。
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
mm_func, transpose_rhs: bool = False):
x = jnp.ones((m, k), dtype=dtype)
if transpose_rhs:
y = jnp.ones((n, k), dtype=dtype)
@jax.jit
def _wrapper(x, y):
y = y.swapaxes(0, 1)
return mm_func(x, y, transpose_rhs=True)
else:
y = jnp.ones((k, n), dtype=dtype)
_wrapper = mm_func
time = benchmark(_wrapper)(x, y)
print(f"----- {m} x {k} x {n} -----")
print("Matmul time: ", time)
mm_flops = matmul_flops(m, k, n) / time
print("Matmul FLOP/s: ", mm_flops)
print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
print()
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)
print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.0003029372810851783
Matmul FLOP/s: 7088872126624.065
FLOP/s utilization: 3.5984%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.012017967159627005
Matmul FLOP/s: 11436123235026.848
FLOP/s utilization: 5.8051%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.09500920018996112
Matmul FLOP/s: 11572685861765.383
FLOP/s utilization: 5.8745%
================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00012131539988331496
Matmul FLOP/s: 17701657415839.363
FLOP/s utilization: 8.9856%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.0008790623804088682
Matmul FLOP/s: 156347213275211.03
FLOP/s utilization: 79.3641%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.006107717020204291
Matmul FLOP/s: 180020067095253.78
FLOP/s utilization: 91.3807%
看看我们在额外的转置下获得了相同的利用率!
激活函数融合#
融合激活函数也很常见。这可以确保我们不会用一个缓慢的、内存受限的激活函数内核跟在一个高效的、计算受限的矩阵乘法内核后面。
def matmul_kernel(
x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs, activation
):
@pl.when(pl.program_id(2) == 0)
def _():
acc_ref[...] = jnp.zeros_like(acc_ref)
if transpose_rhs:
dims = ((1,), (1,)), ((), ())
else:
dims = ((1,), (0,)), ((), ())
acc_ref[...] += jax.lax.dot_general(
x_ref[...],
y_ref[...],
dims,
preferred_element_type=jnp.float32,
)
@pl.when(pl.program_id(2) == nsteps - 1)
def _():
z_ref[...] = activation(acc_ref[...]).astype(z_ref.dtype)
@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'activation'])
def matmul(
x: jax.Array,
y: jax.Array,
*,
bm: int = 128,
bk: int = 128,
bn: int = 128,
transpose_rhs: bool = False,
activation: Callable[[jax.Array], jax.Array] = lambda x: x,
):
if transpose_rhs:
y = y.swapaxes(0, 1)
y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
else:
y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
m, k = x.shape
_, n = y.shape
return pl.pallas_call(
functools.partial(
matmul_kernel,
nsteps=k // bk,
transpose_rhs=transpose_rhs,
activation=activation,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
y_block_spec,
],
out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
compiler_params=pltpu.CompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
mm_func, transpose_rhs: bool = False,
activation = lambda x: x):
x = jnp.ones((m, k), dtype=dtype)
if transpose_rhs:
y = jnp.ones((n, k), dtype=dtype)
@jax.jit
def _wrapper(x, y):
y = y.swapaxes(0, 1)
return mm_func(x, y, transpose_rhs=True, activation=activation)
else:
y = jnp.ones((k, n), dtype=dtype)
_wrapper = functools.partial(mm_func, activation=activation)
time = benchmark(_wrapper)(x, y)
print(f"----- {m} x {k} x {n} -----")
print("Matmul time: ", time)
mm_flops = matmul_flops(m, k, n) / time
print("Matmul FLOP/s: ", mm_flops)
print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
print()
activation = jax.nn.relu
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)
print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00030103540048003196
Matmul FLOP/s: 7133658182976.541
FLOP/s utilization: 3.6211%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.011807117109419778
Matmul FLOP/s: 11640348122095.826
FLOP/s utilization: 5.9088%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.09181861146935262
Matmul FLOP/s: 11974823079773.941
FLOP/s utilization: 6.0786%
================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time: 0.00012622540001757442
Matmul FLOP/s: 17013086492108.6
FLOP/s utilization: 8.6361%
----- 4096 x 4096 x 4096 -----
Matmul time: 0.000896632740041241
Matmul FLOP/s: 153283442968721.44
FLOP/s utilization: 77.8089%
----- 8192 x 8192 x 8192 -----
Matmul time: 0.006130605939542875
Matmul FLOP/s: 179347953304919.88
FLOP/s utilization: 91.0396%
额外的融合激活几乎不影响我们的利用率!
结论#
在本指南中,我们介绍了如何在 TPU 上使用 Pallas 编写高效的矩阵乘法。我们讨论了块矩阵乘法和流水线,如何分析 TPU 矩阵乘法的性能,以及如何编写高效的 bf16 矩阵乘法。最后,我们通过模板化矩阵乘法来支持融合转置和融合激活函数。
留给读者的练习
添加对输入融合的支持。有时我们希望将操作融合到矩阵乘法的输入中。尝试进一步模板化矩阵乘法以支持此功能。
添加对
int8矩阵乘法的支持。TPU v5 支持原生的int8矩阵乘法,其 FLOPs 是bf16的两倍。尝试为其添加支持,并查看可能的利用率。为
matmul函数添加反向传播支持。您可以使用jax.custom_vjp来实现此目的。