快速入门:如何在 JAX 中思考#

Open in Colab Open in Kaggle

JAX 是一个用于面向数组数值计算(类似于 NumPy)的库,具有自动微分和 JIT 编译功能,可实现高性能机器学习研究。.

本文档快速概述了 JAX 的基本特性,以便您开始使用 JAX。

  • JAX 为在 CPU、GPU 或 TPU 上运行的计算(无论是本地还是分布式环境)提供了统一的类似 NumPy 的接口。

  • JAX 通过 Open XLA(一个开源机器学习编译器生态系统)提供内置的即时 (JIT) 编译功能。

  • JAX 函数通过其自动微分转换支持梯度的有效评估。

  • JAX 函数可以自动向量化,以便有效地将它们映射到表示输入批次的数组上。

安装#

JAX 可以在 Linux、Windows 和 macOS 上直接从 Python Package Index 为 CPU 安装。

pip install jax

或者,对于 NVIDIA GPU

pip install -U "jax[cuda12]"

有关更详细的平台特定安装信息,请查看安装

JAX 与 NumPy 对比#

关键概念

  • JAX 提供了一个受 NumPy 启发的接口,以方便使用。

  • 通过鸭子类型,JAX 数组通常可以用作 NumPy 数组的直接替代品。

  • 与 NumPy 数组不同,JAX 数组始终是不可变的。

NumPy 提供了一个众所周知、功能强大的 API 用于处理数值数据。为了方便起见,JAX 提供了 jax.numpy,它与 NumPy API 高度相似,并为 JAX 提供了简单的入口。几乎所有可以用 numpy 完成的操作都可以用 jax.numpy 完成,后者通常以 jnp 的别名导入。

import jax.numpy as jnp

通过这种导入方式,您可以立即以类似于典型的 NumPy 程序的方式使用 JAX,包括使用 NumPy 风格的数组创建函数、Python 函数和运算符以及数组属性和方法。

import matplotlib.pyplot as plt

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
../_images/79a168e290c4923d7f26712aea6aed0738549c914a2e1776f0471e40b4d6e894.png

除了将 np 替换为 jnp 之外,代码块与您在 NumPy 中预期的内容相同,并且结果也相同。正如我们所看到的,JAX 数组通常可以直接替代 NumPy 数组用于绘图等操作。

数组本身是作为不同的 Python 类型实现的。

import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0, 10, 1000)
x_jnp = jnp.linspace(0, 10, 1000)
type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib._jax.ArrayImpl

Python 的鸭子类型允许 JAX 数组和 NumPy 数组在许多地方互换使用。然而,JAX 数组和 NumPy 数组之间存在一个重要区别:JAX 数组是不可变的,这意味着一旦创建,其内容就不能更改。

这是在 NumPy 中修改数组的一个示例:

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10  1  2  3  4  5  6  7  8  9]

JAX 中的等效操作会导致错误,因为 JAX 数组是不可变的。

%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.net.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html

为了更新单个元素,JAX 提供了一种索引更新语法,它返回一个更新后的副本。

y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]

一旦您开始深入研究,会发现 JAX 数组和 NumPy 数组之间有一些区别。另请参阅:

JAX 数组 (jax.Array)#

关键概念

  • 使用 JAX API 函数创建数组。

  • JAX 数组对象有一个 devices 属性,指示数组存储在哪里。

  • JAX 数组可以在多个设备上进行分片以实现并行计算。

JAX 中的默认数组实现是 jax.Array。在许多方面,它与您可能熟悉的 NumPy 包中的 numpy.ndarray 类型相似,但它有一些重要的区别。

数组创建#

我们通常不直接调用 jax.Array 构造函数,而是通过 JAX API 函数创建数组。例如,jax.numpy 提供了熟悉的 NumPy 风格的数组构建功能,例如 jax.numpy.zerosjax.numpy.linspacejax.numpy.arange 等。

import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array)
True

如果您在代码中使用 Python 类型注解,jax.Array 是 JAX 数组对象的合适注解(更多讨论请参见 jax.typing)。

数组设备与分片#

JAX 数组对象有一个 devices 方法,允许您检查数组内容存储在哪里。在最简单的情况下,这将是一个单独的 CPU 设备。

x.devices()
{CpuDevice(id=0)}

通常,数组可能会在多个设备上进行分片,可以通过 sharding 属性进行检查。

x.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

这里数组位于单个设备上,但通常 JAX 数组可以跨多个设备甚至多个主机进行分片。要了解有关分片数组和并行计算的更多信息,请参阅并行编程简介

使用 jax.jit 进行即时编译#

关键概念

  • 默认情况下,JAX 按顺序逐个执行操作。

  • 使用即时 (JIT) 编译装饰器,可以将一系列操作一起优化并一次性运行。

  • 并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状是静态的且在编译时已知。

JAX 透明地在 GPU 或 TPU 上运行(如果没有则回退到 CPU),所有 JAX 操作都以 XLA 形式表达。如果我们有一系列操作,我们可以使用 jax.jit 函数使用 XLA 编译器将这些操作序列一起编译。

例如,考虑这个函数,它以 jax.numpy 操作的形式对 2D 矩阵的行进行归一化。

import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

可以使用 jax.jit 转换创建函数的即时编译版本。

from jax import jit
norm_compiled = jit(norm)

此函数返回与原始函数相同的结果,精度达到标准浮点数精度。

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True

但由于编译(包括操作融合、避免分配临时数组以及许多其他技巧),JIT 编译情况下的执行时间可以快几个数量级。我们可以使用 IPython 的 %timeit 快速基准测试我们的函数,并使用 block_until_ready() 来考虑 JAX 的异步调度

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
324 μs ± 5.29 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
274 μs ± 658 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

话虽如此,jax.jit 确实有局限性:特别是,它要求所有数组都具有静态形状。这意味着某些 JAX 操作与 JIT 编译不兼容。

例如,此操作可以在逐操作模式下执行:

def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

但如果您尝试在 JIT 模式下执行它,则会返回错误。

jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got bool[10]

See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

这是因为该函数生成一个在编译时形状未知的数组:输出的大小取决于输入数组的值,因此它与 JIT 不兼容。

有关 JAX 中 JIT 编译的更多信息,请查看即时编译

JIT 机制:追踪与静态变量#

关键概念

  • JIT 和其他 JAX 转换通过追踪函数来确定其对特定形状和类型输入的影响。

  • 您不希望被追踪的变量可以标记为静态

要有效使用 jax.jit,了解其工作原理很有用。让我们在一个 JIT 编译的函数中放置一些 print() 语句,然后调用该函数:

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
  x = JitTracer<float32[3,4]>
  y = JitTracer<float32[4]>
  result = JitTracer<float32[3]>
Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32)

请注意,print 语句执行了,但它打印的不是我们传递给函数的数据,而是它们的追踪器对象。

这些追踪器对象是 jax.jit 用于提取函数指定的操作序列的。基本追踪器是编码数组的形状dtype的替代品,但与值无关。然后,这个记录的计算序列可以在 XLA 中有效地应用于具有相同形状和 dtype 的新输入,而无需重新执行 Python 代码。

当我们再次在匹配的输入上调用编译后的函数时,无需重新编译,也不会打印任何内容,因为结果是在编译后的 XLA 中计算的,而不是在 Python 中。

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)

提取的操作序列编码在 JAX 表达式中,简称 jaxpr。您可以使用 jax.make_jaxpr 转换查看 jaxpr。

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0:f32[]
    d:f32[4] = add b 1.0:f32[]
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

请注意这一点的一个结果:由于 JIT 编译是在没有数组内容信息的情况下完成的,因此函数中的控制流语句不能依赖于追踪值。例如,这会失败:

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_4251/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

如果您不希望某些变量被追踪,可以为了 JIT 编译的目的将它们标记为静态。

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

请注意,使用不同的静态参数调用 JIT 编译的函数会导致重新编译,因此函数仍然按预期工作。

f(1, False)
Array(1, dtype=int32, weak_type=True)

理解哪些值和操作是静态的,哪些将被追踪,是有效使用 jax.jit 的关键部分。

使用 jax.grad 求导#

关键概念

  • JAX 通过 jax.grad 转换提供自动微分功能。

  • jax.gradjax.jit 转换可以组合并任意混合使用。

除了通过 JIT 编译转换函数外,JAX 还提供其他转换。其中一种转换是 jax.grad(),它执行自动微分(autodiff)

from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25       0.19661197 0.10499357]

让我们用有限差分验证结果是否正确。

def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1965761  0.10502338]

grad()jit() 转换可以组合并任意混合使用。例如,虽然 sum_logistic 函数在之前的示例中直接进行了微分,但它也可以进行 JIT 编译,并且这些操作可以组合。我们可以更进一步:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
-0.0353256

除了标量值函数,jax.jacobian() 转换还可以用于计算向量值函数的完整雅可比矩阵。

from jax import jacobian
print(jacobian(jnp.exp)(x_small))
[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]

对于更高级的自动微分操作,您可以使用 jax.vjp() 进行反向模式向量-雅可比积,以及使用 jax.jvp()jax.linearize() 进行正向模式雅可比-向量积。两者可以相互任意组合,也可以与其他 JAX 转换组合。例如,jax.jvp()jax.vjp() 用于定义正向模式 jax.jacfwd() 和反向模式 jax.jacrev(),分别用于计算正向和反向模式的雅可比矩阵。这里有一种组合它们的方法,可以高效地计算完整的 Hessian 矩阵:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]

这种组合在实践中会生成高效的代码;这或多或少是 JAX 内置的 jax.hessian() 函数的实现方式。

有关 JAX 中自动微分的更多信息,请查看自动微分

使用 jax.vmap() 进行自动向量化#

关键概念

  • JAX 通过 jax.vmap 转换提供自动向量化功能。

  • jax.vmap 可以与 jax.jit 组合以生成高效的向量化代码。

另一个有用的转换是 vmap(),即向量化映射。它具有沿数组轴映射函数的熟悉语义,但它不是显式地循环函数调用,而是将函数转换为原生向量化版本以获得更好的性能。当与 jit() 组合时,它的性能可以与手动重写函数以操作额外的批处理维度一样好。

我们将从一个简单的示例入手,使用 vmap() 将矩阵-向量积提升为矩阵-矩阵积。虽然在这种特定情况下手动操作很容易,但同样的技术可以应用于更复杂的函数。

from jax import random

key = random.key(1701)
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  return jnp.dot(mat, x)

apply_matrix 函数将一个向量映射到一个向量,但我们可能希望将其按行应用于一个矩阵。我们可以在 Python 中通过循环批处理维度来实现,但这通常会导致性能不佳。

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
738 μs ± 3.27 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

熟悉 jnp.dot 函数的程序员可能会认识到,apply_matrix 可以重写以避免显式循环,利用 jnp.dot 的内置批处理语义。

import numpy as np

@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, mat.T)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
27.9 μs ± 403 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

然而,随着函数变得更加复杂,这种手动批处理变得更加困难且容易出错。vmap() 转换旨在自动将函数转换为批处理感知版本。

from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
39 μs ± 227 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

正如您所料,vmap() 可以与 jit()grad() 以及任何其他 JAX 转换任意组合。

有关 JAX 中自动向量化的更多信息,请查看自动向量化

伪随机数#

关键概念

  • JAX 使用与 NumPy 不同的伪随机数生成模型。

  • JAX 随机函数消耗一个随机 key,该 key 必须被分割以生成新的独立密钥。

  • JAX 的随机密钥模型是线程安全的,并避免了全局状态问题。

通常,JAX 力求与 NumPy 兼容,但伪随机数生成是一个显著的例外。NumPy 支持一种基于全局 state 的伪随机数生成方法,可以使用 numpy.random.seed 进行设置。全局随机状态与 JAX 的计算模型交互不佳,并且难以在不同线程、进程和设备之间强制实现可重现性。JAX 转而通过随机 key 显式跟踪状态。

from jax import random

key = random.key(43)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 43]

这个密钥实际上是 NumPy 隐藏状态对象的替代品,但我们将其显式传递给 jax.random 函数。重要的是,随机函数会消耗密钥,但不会修改它:将相同的密钥对象传递给随机函数将始终生成相同的样本。

print(random.normal(key))
print(random.normal(key))
0.07520543
0.07520543

经验法则是:永远不要重复使用密钥(除非您想要相同的输出)。

为了生成不同且独立的样本,您必须在将密钥传递给随机函数之前显式地 jax.random.split 密钥。

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
draw 0: -1.9133632183074951
draw 1: -1.4749839305877686
draw 2: -0.36703771352767944

请注意,此代码是线程安全的,因为本地随机状态消除了涉及全局状态的可能竞态条件。jax.random.split 是一个确定性函数,它将一个 key 转换为多个独立(在伪随机性意义上)的密钥。

有关 JAX 中伪随机数的更多信息,请参阅伪随机数教程

这只是 JAX 功能的冰山一角。我们非常期待看到您能用它做些什么!