如何用 JAX 思考#

Open in Colab Open in Kaggle

JAX 提供了一个简单而强大的 API 用于编写加速的数值代码,但要在 JAX 中高效工作有时需要额外的考虑。本文档旨在帮助您从头开始了解 JAX 的运作方式,以便您可以更有效地使用它。

JAX 与 NumPy#

核心概念

  • JAX 为了方便起见,提供了受 NumPy 启发的接口。

  • 通过鸭子类型,JAX 数组通常可以用作 NumPy 数组的直接替代品。

  • 与 NumPy 数组不同,JAX 数组始终是不可变的。

NumPy 提供了一个众所周知的、强大的 API 用于处理数值数据。为了方便起见,JAX 提供了 jax.numpy,它与 numpy API 非常相似,并提供了轻松进入 JAX 的途径。几乎所有可以用 numpy 完成的事情都可以用 jax.numpy 完成。

import matplotlib.pyplot as plt
import numpy as np

x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);
../_images/b2db475a8afa1d2e364a801f61f7b347b75a355e9da0be2f015a2d1aefdea45c.png
import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
../_images/487cfe9c47318bd2e5849cf09dc8048af87a3364e9f0e0e524de8e950911888e.png

除了将 np 替换为 jnp 之外,代码块是相同的,并且结果也是相同的。正如我们所见,JAX 数组通常可以直接用于替代 NumPy 数组,例如绘图。

数组本身被实现为不同的 Python 类型

type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib.xla_extension.ArrayImpl

Python 的鸭子类型允许 JAX 数组和 NumPy 数组在许多地方互换使用。

然而,JAX 和 NumPy 数组之间有一个重要的区别:JAX 数组是不可变的,这意味着一旦创建,它们的内容就无法更改。

这是一个在 NumPy 中修改数组的示例

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10  1  2  3  4  5  6  7  8  9]

在 JAX 中,等效的操作会导致错误,因为 JAX 数组是不可变的

%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.net.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html

对于更新单个元素,JAX 提供了一个索引更新语法,该语法返回一个更新后的副本

y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]

NumPy、lax 和 XLA:JAX API 分层#

核心概念

  • jax.numpy 是一个提供熟悉接口的高级包装器。

  • jax.lax 是一个较低级的 API,它更严格且通常更强大。

  • 所有 JAX 操作都是根据 XLA(加速线性代数编译器)中的操作实现的。

如果您查看 jax.numpy 的源代码,您会看到所有操作最终都用 jax.lax 中定义的函数来表示。您可以将 jax.lax 视为一个更严格但通常更强大的 API,用于处理多维数组。

例如,虽然 jax.numpy 会隐式地提升参数以允许混合数据类型之间的操作,但 jax.lax 则不会

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.
TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).

如果直接使用 jax.lax,您必须在这些情况下显式地进行类型提升

lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)

除了这种严格性之外,jax.lax 还为一些比 NumPy 支持的更通用的操作提供了高效的 API。

例如,考虑一个一维卷积,它可以用 NumPy 这样表示

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

在底层,此 NumPy 操作被转换为由 lax.conv_general_dilated 实现的更通用的卷积。

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

这是一个批处理卷积操作,旨在高效处理深度神经网络中常用的卷积类型。它需要更多的样板代码,但比 NumPy 提供的卷积更灵活、更可扩展(有关 JAX 卷积的更多详细信息,请参阅JAX 中的卷积)。

本质上,所有 jax.lax 操作都是 XLA 中操作的 Python 包装器;例如,这里的卷积实现由 XLA:ConvWithGeneralPadding 提供。每个 JAX 操作最终都用这些基本的 XLA 操作表示,这使得即时 (JIT) 编译成为可能。

是否使用 JIT#

核心概念

  • 默认情况下,JAX 逐个按顺序执行操作。

  • 使用即时 (JIT) 编译装饰器,可以将一系列操作一起优化并一次运行。

  • 并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状在编译时是静态且已知的。

所有 JAX 操作都用 XLA 表示的事实允许 JAX 使用 XLA 编译器非常高效地执行代码块。

例如,考虑这个函数,它规范化一个二维矩阵的行,用 jax.numpy 操作表示

import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

可以使用 jax.jit 转换创建该函数的即时编译版本

from jax import jit
norm_compiled = jit(norm)

此函数返回的结果与原始函数相同,精度达到标准浮点精度

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True

但是由于编译(包括操作融合、避免分配临时数组以及其他许多技巧),在 JIT 编译的情况下,执行时间可能快几个数量级(请注意使用 block_until_ready() 来解释 JAX 的 异步调度

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
755 μs ± 4.35 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
367 μs ± 2.73 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

也就是说,jax.jit 确实有局限性:特别是,它要求所有数组都具有静态形状。这意味着某些 JAX 操作与 JIT 编译不兼容。

例如,此操作可以在逐个操作模式下执行

def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

但是,如果您尝试在 jit 模式下执行它,它会返回一个错误

jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

这是因为该函数生成了一个在编译时其形状未知的数组:输出的大小取决于输入数组的值,因此它与 JIT 不兼容。

JIT 机制:跟踪和静态变量#

核心概念

  • JIT 和其他 JAX 转换通过跟踪一个函数来确定其对特定形状和类型的输入的影响。

  • 您不希望被跟踪的变量可以标记为静态

要有效地使用 jax.jit,了解其工作原理很有用。让我们在 JIT 编译的函数中放置一些 print() 语句,然后调用该函数

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace>
Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32)

请注意,print 语句会执行,但它打印的不是我们传递给函数的数据,而是代替它们的跟踪器对象。

这些跟踪器对象是 jax.jit 用于提取函数指定的操作序列的内容。基本跟踪器是代替项,用于编码数组的形状dtype,但与值无关。然后,可以将此记录的计算序列在 XLA 中高效地应用于具有相同形状和 dtype 的新输入,而无需重新执行 Python 代码。

当我们在匹配的输入上再次调用编译后的函数时,不需要重新编译,也不会打印任何内容,因为结果是在编译后的 XLA 中而不是在 Python 中计算的

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)

提取的操作序列被编码在 JAX 表达式中,或简称为 jaxpr。您可以使用 jax.make_jaxpr 转换来查看 jaxpr

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

请注意此结果的一个结果:由于 JIT 编译是在没有数组内容信息的情况下完成的,因此函数中的控制流语句不能依赖于被跟踪的值。例如,这会失败

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_2639/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

如果存在您不希望被跟踪的变量,则可以将它们标记为 JIT 编译的静态变量

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

请注意,使用不同的静态参数调用 JIT 编译的函数会导致重新编译,因此该函数仍然按预期工作

f(1, False)
Array(1, dtype=int32, weak_type=True)

了解哪些值和操作将是静态的,哪些将被跟踪是有效使用 jax.jit 的关键部分。

静态操作与跟踪操作#

核心概念

  • 就像值可以是静态的或被跟踪的一样,操作也可以是静态的或被跟踪的。

  • 静态操作在编译时在 Python 中求值;跟踪操作在运行时在 XLA 中编译和求值。

  • 对于要静态化的操作,请使用 numpy;对于要跟踪的操作,请使用 jax.numpy

静态值和跟踪值之间的这种区别使得思考如何保持静态值保持静态变得很重要。考虑这个函数

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_2639/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /tmp/ipykernel_2639/1983583872.py:6 (f)

这失败并出现一个错误,指定找到了一个跟踪器,而不是一个整数类型的一维具体值序列。让我们向函数添加一些 print 语句,以了解为什么会发生这种情况

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>

请注意,尽管 x 被跟踪,但 x.shape 是一个静态值。但是,当我们在该静态值上使用 jnp.arrayjnp.prod 时,它会变成一个被跟踪的值,此时它不能在需要静态输入的函数(如 reshape())中使用(请回忆:数组形状必须是静态的)。

一个有用的模式是对于应该是静态的操作(即在编译时完成的操作),使用 numpy,对于应该跟踪的操作(即在运行时编译和执行的操作),使用 jax.numpy。对于此函数,它可能如下所示

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

因此,在 JAX 程序中的一个标准约定是 import numpy as npimport jax.numpy as jnp,这样两个接口都可用,从而可以更精细地控制操作是以静态方式(使用 numpy,在编译时执行一次)还是以跟踪方式(使用 jax.numpy,在运行时进行优化)执行。