快速入门:如何用 JAX 思考#
JAX 是一个用于面向数组的数值计算库(*类似于* NumPy),并支持自动微分和 JIT 编译,以实现高性能的机器学习研究。.
本文档提供了 JAX 核心功能的快速概览,以便您可以开始使用 JAX。
JAX 提供统一的类 NumPy 接口,用于在 CPU、GPU 或 TPU 上运行的计算,支持本地或分布式设置。
JAX 通过 Open XLA 内置即时(JIT)编译功能。Open XLA 是一个开源的机器学习编译器生态系统。
JAX 函数通过其自动微分转换支持高效的梯度评估。
JAX 函数可以自动向量化,以便高效地将它们映射到表示输入批次的数组上。
安装#
JAX 可直接从 Python 包索引 安装以在 Linux、Windows 和 macOS 的 CPU 上使用。
pip install jax
或者,用于 NVIDIA GPU。
pip install -U "jax[cuda12]"
有关更详细的特定平台安装信息,请参阅 安装。
JAX 与 NumPy 对比#
关键概念
JAX 提供了一个受 NumPy 启发的接口,方便使用。
通过 鸭子类型,JAX 数组通常可以作为 NumPy 数组的直接替代品。
与 NumPy 数组不同,JAX 数组始终是不可变的。
NumPy 为处理数值数据提供了一个知名且强大的 API。为了方便起见,JAX 提供了 jax.numpy,它紧密地模仿了 NumPy API,并提供了轻松入门 JAX 的方式。几乎所有用 numpy 可以做的事情,都可以用 jax.numpy 来做,通常将其导入并别名为 jnp。
import jax.numpy as jnp
有了这个导入,您可以立即以类似典型 NumPy 程序的方式使用 JAX,包括使用 NumPy 风格的数组创建函数、Python 函数和运算符,以及数组属性和方法。
import matplotlib.pyplot as plt
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
代码块与您期望的 NumPy 代码块相同,只是将 np 替换为 jnp,结果也是相同的。如我们所见,JAX 数组通常可以像 NumPy 数组一样直接用于绘图等用途。
数组本身实现为不同的 Python 类型。
import numpy as np
import jax.numpy as jnp
x_np = np.linspace(0, 10, 1000)
x_jnp = jnp.linspace(0, 10, 1000)
type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib._jax.ArrayImpl
Python 的鸭子类型允许 JAX 数组和 NumPy 数组在许多地方可以互换使用。然而,JAX 和 NumPy 数组之间有一个重要的区别:JAX 数组是不可变的,这意味着一旦创建,其内容就无法更改。
以下是在 NumPy 中修改数组的示例。
# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10 1 2 3 4 5 6 7 8 9]
在 JAX 中等效的操作会产生错误,因为 JAX 数组是不可变的。
%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.net.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html
要更新单个元素,JAX 提供了一个 索引更新语法,该语法返回一个更新后的副本。
y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10 1 2 3 4 5 6 7 8 9]
一旦开始深入研究,您会发现 JAX 数组和 NumPy 数组之间存在一些差异。另请参阅:
关键概念,用于介绍 JAX 的关键概念,如转换、跟踪、jaxprs 和 Pytrees。
🔪 JAX - 那些容易出错的地方 🔪,介绍使用 JAX 时的常见陷阱。
JAX 数组(`jax.Array`#
关键概念
使用 JAX API 函数创建数组。
JAX 数组对象具有一个 `devices` 属性,用于指示数组存储在哪里。
JAX 数组可以*分片*到多个设备上进行并行计算。
JAX 中的默认数组实现是 jax.Array。在很多方面,它与您可能从 NumPy 包中熟悉的 numpy.ndarray 类型相似,但也有一些重要的区别。
数组创建#
我们通常不直接调用 `jax.Array` 构造函数,而是通过 JAX API 函数创建数组。例如,jax.numpy 提供了熟悉的 NumPy 风格的数组构造功能,如 `jax.numpy.zeros`、`jax.numpy.linspace`、`jax.numpy.arange` 等。
import jax
import jax.numpy as jnp
x = jnp.arange(5)
isinstance(x, jax.Array)
True
如果您在代码中使用 Python 类型注解,`jax.Array` 是 jax 数组对象的适当注解(有关更多讨论,请参阅 jax.typing)。
数组设备与分片#
JAX 数组对象有一个 `devices` 方法,可以让我们检查数组内容存储的位置。在最简单的情况下,这将是单个 CPU 设备。
x.devices()
{CpuDevice(id=0)}
一般来说,数组可以*分片*到多个设备上,其方式可以通过 `sharding` 属性进行检查。
x.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
这里数组位于单个设备上,但通常 JAX 数组可以分片到多个设备,甚至多个主机上。要了解更多关于分片数组和并行计算的信息,请参阅 并行编程入门。
使用 `jax.jit` 进行即时编译#
关键概念
默认情况下,JAX 会逐个操作地按顺序执行。
使用即时(JIT)编译装饰器,可以将一系列操作一起优化并一次性运行。
并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状在编译时是静态且已知的。
JAX 在 GPU 或 TPU 上透明运行(如果没有,则回退到 CPU),所有 JAX 操作都通过 XLA 来表达。如果我们有一系列操作,我们可以使用 jax.jit 函数使用 XLA 编译器将这一系列操作一起编译。
例如,考虑这个函数,它使用 `jax.numpy` 操作对二维矩阵的行进行归一化。
import jax.numpy as jnp
def norm(X):
X = X - X.mean(0)
return X / X.std(0)
可以使用 `jax.jit` 转换创建该函数的即时编译版本。
from jax import jit
norm_compiled = jit(norm)
此函数返回的结果与原始函数相同,在标准浮点精度范围内。
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True
但由于编译(包括操作融合、避免分配临时数组以及许多其他技巧),JIT 编译的情况下的执行时间可以快几个数量级。我们可以使用 IPython 的 `%timeit` 来快速基准测试我们的函数,并使用 `block_until_ready()` 来考虑 JAX 的 异步分派。
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
225 μs ± 19.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
206 μs ± 3.77 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
即便如此,`jax.jit` 确实存在限制:特别是,它要求所有数组都具有静态形状。这意味着某些 JAX 操作与 JIT 编译不兼容。
例如,此操作可以在逐个操作模式下执行。
def get_negatives(x):
return x[x < 0]
x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)
但如果尝试在 jit 模式下执行,它会返回错误。
jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got bool[10]
See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
这是因为该函数生成了一个形状在编译时未知的数组:输出的大小取决于输入数组的值,因此与 JIT 不兼容。
有关 JAX 中 JIT 编译的更多信息,请参阅 即时编译。
使用 `jax.grad` 进行求导#
关键概念
JAX 通过 `jax.grad` 转换提供自动微分。
`jax.grad` 和 `jax.jit` 转换可以任意组合和混合。
除了通过 JIT 编译转换函数外,JAX 还提供其他转换。其中一种转换是 jax.grad,它执行 自动微分(autodiff)。
from jax import grad
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25 0.19661197 0.10499357]
让我们通过有限差分来验证我们的结果是否正确。
def first_finite_differences(f, x, eps=1E-3):
return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
for v in jnp.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1964569 0.10502338]
`jax.grad` 和 `jax.jit` 转换可以任意组合和混合。例如,虽然在上一个示例中直接对 `sum_logistic` 函数进行了微分,但也可以对其进行 JIT 编译,并且这些操作可以组合。我们可以做得更进一步。
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.0353256
除了标量值函数,还可以使用 `jax.jacobian` 转换来计算向量值函数的完整雅可比矩阵。
from jax import jacobian
print(jacobian(jnp.exp)(x_small))
[[1. 0. 0. ]
[0. 2.7182817 0. ]
[0. 0. 7.389056 ]]
对于更高级的自动微分操作,您可以使用 `jax.vjp` 进行反向模式向量-雅可比乘积,以及 `jax.jvp` 和 `jax.linearize` 进行前向模式雅可比-向量乘积。这两种可以任意组合,并与其他 JAX 转换组合。例如,`jax.jvp` 和 `jax.vjp` 用于定义前向模式的 `jax.jacfwd` 和反向模式的 `jax.jacrev`,分别用于计算雅可比矩阵。以下是一种组合它们以生成有效计算完整 Hessian 矩阵的函数的方法:
from jax import jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
[[-0. -0. -0. ]
[-0. -0.09085776 -0. ]
[-0. -0. -0.07996249]]
这种组合在实践中会产生高效的代码;这大致就是 JAX 内置的 `jax.hessian` 函数的实现方式。
有关 JAX 中自动微分的更多信息,请参阅 自动微分。
使用 `jax.vmap` 进行自动向量化#
关键概念
JAX 通过 `jax.vmap` 转换提供自动向量化。
`jax.vmap` 可以与 `jax.jit` 组合以生成高效的向量化代码。
另一个有用的转换是 `jax.vmap`,即向量化映射。它具有沿着数组轴映射函数的熟悉语义,但不是显式循环函数调用,而是将函数转换为原生向量化版本以获得更好的性能。当与 `jax.jit` 组合时,它可以和手动重写函数以处理额外的批次维度一样高效。
我们将使用一个简单的例子,并通过 `jax.vmap` 将矩阵-向量乘积提升为矩阵-矩阵乘积。虽然在这个特定情况下手动完成很容易,但相同的技术可以应用于更复杂的函数。
from jax import random
key = random.key(1701)
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))
def apply_matrix(x):
return jnp.dot(mat, x)
`apply_matrix` 函数将一个向量映射到一个向量,但我们可能希望在矩阵中逐行应用它。我们可以通过在 Python 中循环批次维度来做到这一点,但这通常会导致性能低下。
def naively_batched_apply_matrix(v_batched):
return jnp.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
393 μs ± 1.73 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
熟悉 `jnp.dot` 函数的程序员可能会认识到 `apply_matrix` 可以重写以避免显式循环,使用 `jnp.dot` 的内置批处理语义。
import numpy as np
@jit
def batched_apply_matrix(batched_x):
return jnp.dot(batched_x, mat.T)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
12.5 μs ± 160 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
然而,随着函数变得越来越复杂,这种手动批处理变得越来越困难且容易出错。`jax.vmap` 转换旨在自动将函数转换为支持批处理的版本。
from jax import vmap
@jit
def vmap_batched_apply_matrix(batched_x):
return vmap(apply_matrix)(batched_x)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
16.1 μs ± 79.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
正如您所期望的,`jax.vmap` 可以与 `jax.jit`、`jax.grad` 以及任何其他 JAX 转换任意组合。
有关 JAX 中自动向量化的更多信息,请参阅 自动向量化。
伪随机数#
关键概念
JAX 使用与 NumPy 不同的伪随机数生成模型。
JAX 随机函数使用一个必须进行拆分的随机 `key` 来生成新的独立 `key`。
JAX 的随机 key 模型是线程安全的,并避免了全局状态的问题。
总的来说,JAX 致力于与 NumPy 兼容,但伪随机数生成是一个显著的例外。NumPy 支持一种基于全局 `state` 的伪随机数生成方法,可以通过 `numpy.random.seed` 设置。全局随机状态与 JAX 的计算模型相互作用不佳,并且难以在不同线程、进程和设备之间强制执行可重现性。JAX 改为通过一个随机 `key` 来显式跟踪状态。
from jax import random
key = random.key(43)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 43]
这个 `key` 实际上是 NumPy 的隐藏状态对象的替代品,但我们将其显式传递给 `jax.random` 函数。重要的是,随机函数会消耗 `key`,但不会修改它:将相同的 `key` 对象传递给随机函数将始终生成相同的样本。
print(random.normal(key))
print(random.normal(key))
0.07520543
0.07520543
经验法则是:切勿重复使用 `key`(除非您想要相同的输出)。
为了生成不同且独立的样本,您必须在将其传递给随机函数之前显式地 jax.random.split `key`。
for i in range(3):
new_key, subkey = random.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = random.normal(subkey)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
draw 0: -1.9133632183074951
draw 1: -1.4749839305877686
draw 2: -0.36703771352767944
请注意,此代码是线程安全的,因为局部随机状态消除了涉及全局状态的潜在竞争条件。`jax.random.split` 是一个确定性函数,它将一个 `key` 转换为几个独立的(在伪随机意义上)`key`。
有关 JAX 中伪随机数的更多信息,请参阅 伪随机数教程。
这只是 JAX 功能的一小部分。我们非常期待看到您用它来做什么!