JAX 思考方式#
JAX 提供了一个简单而强大的 API 用于编写加速的数值代码,但有效地使用 JAX 有时需要额外的考虑。本文档旨在帮助从根本上理解 JAX 的运作方式,以便您可以更有效地使用它。
JAX vs. NumPy#
关键概念
JAX 为了方便起见,提供了受 NumPy 启发的接口。
通过鸭子类型,JAX 数组通常可以用作 NumPy 数组的直接替代品。
与 NumPy 数组不同,JAX 数组始终是不可变的。
NumPy 提供了一个众所周知的、强大的 API 用于处理数值数据。为了方便起见,JAX 提供了 jax.numpy
,它密切镜像了 numpy API,并提供了进入 JAX 的简易入口。几乎所有可以使用 numpy
完成的操作都可以使用 jax.numpy
完成
import matplotlib.pyplot as plt
import numpy as np
x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);

import jax.numpy as jnp
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);

除了将 np
替换为 jnp
之外,代码块是相同的,结果也相同。正如我们所见,JAX 数组通常可以直接代替 NumPy 数组使用,例如用于绘图。
数组本身被实现为不同的 Python 类型
type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib.xla_extension.ArrayImpl
Python 的 鸭子类型 允许 JAX 数组和 NumPy 数组在许多地方互换使用。
然而,JAX 数组和 NumPy 数组之间有一个重要的区别:JAX 数组是不可变的,这意味着一旦创建,其内容就无法更改。
这是一个在 NumPy 中修改数组的示例
# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10 1 2 3 4 5 6 7 8 9]
JAX 中的等效操作会导致错误,因为 JAX 数组是不可变的
%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.net.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html
对于更新单个元素,JAX 提供了一个 索引更新语法,它返回更新后的副本
y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10 1 2 3 4 5 6 7 8 9]
NumPy、lax 和 XLA:JAX API 分层#
关键概念
jax.numpy
是一个高级包装器,提供了一个熟悉的接口。jax.lax
是一个更低级别的 API,它更严格,通常更强大。所有 JAX 操作都是使用 XLA – 加速线性代数编译器中的操作实现的。
如果您查看 jax.numpy
的源代码,您会看到所有操作最终都用 jax.lax
中定义的函数表示。您可以将 jax.lax
视为一个更严格,但通常更强大的 API,用于处理多维数组。
例如,虽然 jax.numpy
会隐式提升参数以允许混合数据类型之间的操作,但 jax.lax
不会
import jax.numpy as jnp
jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0) # jax.lax API requires explicit type promotion.
TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).
如果直接使用 jax.lax
,您将不得不在这种情况下显式地进行类型提升
lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)
除了这种严格性之外,jax.lax
还为某些比 NumPy 支持的更通用的操作提供了高效的 API。
例如,考虑一个 1D 卷积,它可以用 NumPy 这样表示
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
在底层,此 NumPy 操作被转换为由 lax.conv_general_dilated
实现的更通用的卷积
from jax import lax
result = lax.conv_general_dilated(
x.reshape(1, 1, 3).astype(float), # note: explicit promotion
y.reshape(1, 1, 10),
window_strides=(1,),
padding=[(len(y) - 1, len(y) - 1)]) # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
这是一个批量卷积操作,旨在有效地处理深度神经网络中常用的卷积类型。它需要更多的样板代码,但比 NumPy 提供的卷积更灵活和可扩展(有关 JAX 卷积的更多详细信息,请参阅 JAX 中的卷积)。
在其核心,所有 jax.lax
操作都是 XLA 中操作的 Python 包装器;例如,这里的卷积实现由 XLA:ConvWithGeneralPadding 提供。每个 JAX 操作最终都用这些基本的 XLA 操作表示,这正是实现即时 (JIT) 编译的原因。
是否使用 JIT#
关键概念
默认情况下,JAX 一次执行一个操作,按顺序执行。
使用即时 (JIT) 编译装饰器,可以将操作序列一起优化并一次运行。
并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状是静态的并且在编译时已知。
所有 JAX 操作都用 XLA 表示这一事实使得 JAX 可以使用 XLA 编译器非常高效地执行代码块。
例如,考虑这个函数,它规范化 2D 矩阵的行,用 jax.numpy
操作表示
import jax.numpy as jnp
def norm(X):
X = X - X.mean(0)
return X / X.std(0)
可以使用 jax.jit
转换创建函数的即时编译版本
from jax import jit
norm_compiled = jit(norm)
此函数返回与原始函数相同的结果,达到标准浮点精度
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True
但是由于编译(包括操作融合、避免分配临时数组以及许多其他技巧),在 JIT 编译的情况下,执行时间可能会快几个数量级(请注意使用 block_until_ready()
来考虑 JAX 的 异步调度)
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
286 μs ± 16.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
218 μs ± 1.45 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
也就是说,jax.jit
确实有局限性:特别是,它要求所有数组都具有静态形状。这意味着某些 JAX 操作与 JIT 编译不兼容。
例如,此操作可以在逐操作模式下执行
def get_negatives(x):
return x[x < 0]
x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)
但是,如果您尝试在 jit 模式下执行它,则会返回错误
jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])
See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
这是因为该函数生成了一个数组,其形状在编译时未知:输出的大小取决于输入数组的值,因此它与 JIT 不兼容。
JIT 机制:追踪和静态变量#
关键概念
JIT 和其他 JAX 转换通过追踪一个函数来确定其对特定形状和类型的输入的影响。
您不想被追踪的变量可以标记为静态
为了有效地使用 jax.jit
,了解其工作原理很有用。让我们在 JIT 编译的函数中放置一些 print()
语句,然后调用该函数
@jit
def f(x, y):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
result = jnp.dot(x + 1, y + 1)
print(f" result = {result}")
return result
x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace>
y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace>
result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace>
Array([0.25773212, 5.3623195 , 5.403243 ], dtype=float32)
请注意,打印语句执行了,但是它打印的是tracer对象,而不是打印我们传递给函数的数据,这些对象代表了数据。
这些 tracer 对象是 jax.jit
用来提取函数指定的操作序列的方式。基本 tracer 是代表数组的形状和 dtype 的占位符,但与值无关。然后,可以将记录的计算序列在 XLA 中有效地应用于具有相同形状和 dtype 的新输入,而无需重新执行 Python 代码。
当我们再次在匹配的输入上调用编译后的函数时,不需要重新编译,也不会打印任何内容,因为结果是在编译后的 XLA 中而不是在 Python 中计算的
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)
提取的操作序列编码在 JAX 表达式中,或简称为 jaxpr。您可以使用 jax.make_jaxpr
转换查看 jaxpr
from jax import make_jaxpr
def f(x, y):
return jnp.dot(x + 1, y + 1)
make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
c:f32[3,4] = add a 1.0
d:f32[4] = add b 1.0
e:f32[3] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] c d
in (e,) }
请注意此操作的一个结果:由于 JIT 编译是在没有关于数组内容的信息的情况下完成的,因此函数中的控制流语句不能依赖于追踪的值。例如,这会失败
@jit
def f(x, neg):
return -x if neg else x
f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_2548/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
如果有一些变量您不想被追踪,则可以将它们标记为静态以用于 JIT 编译
from functools import partial
@partial(jit, static_argnums=(1,))
def f(x, neg):
return -x if neg else x
f(1, True)
Array(-1, dtype=int32, weak_type=True)
请注意,使用不同的静态参数调用 JIT 编译的函数会导致重新编译,因此该函数仍然按预期工作
f(1, False)
Array(1, dtype=int32, weak_type=True)
了解哪些值和操作将是静态的,哪些将被追踪是有效使用 jax.jit
的关键部分。
静态操作 vs 追踪操作#
关键概念
正如值可以是静态的或追踪的一样,操作也可以是静态的或追踪的。
静态操作在编译时在 Python 中求值;追踪操作在运行时在 XLA 中编译和求值。
对于您希望是静态的操作,请使用
numpy
;对于您希望被追踪的操作,请使用jax.numpy
。
静态值和追踪值之间的这种区别使得考虑如何保持静态值的静态性变得很重要。考虑这个函数
import jax.numpy as jnp
from jax import jit
@jit
def f(x):
return x.reshape(jnp.array(x.shape).prod())
x = jnp.ones((2, 3))
f(x)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_2548/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /tmp/ipykernel_2548/1983583872.py:6 (f)
此操作失败并显示错误,指定找到了 tracer 而不是整数类型的具体值的一维序列。让我们向函数添加一些打印语句以了解为什么会发生这种情况
@jit
def f(x):
print(f"x = {x}")
print(f"x.shape = {x.shape}")
print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
# comment this out to avoid the error:
# return x.reshape(jnp.array(x.shape).prod())
f(x)
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>
请注意,虽然 x
被追踪,但 x.shape
是一个静态值。但是,当我们在静态值上使用 jnp.array
和 jnp.prod
时,它会变成一个追踪值,此时它不能在像 reshape()
这样需要静态输入的函数中使用(回想一下:数组形状必须是静态的)。
一个有用的模式是对于应该静态的操作(即在编译时完成),使用 numpy
,对于应该追踪的操作(即在运行时编译和执行),使用 jax.numpy
。对于这个函数,它可能看起来像这样
from jax import jit
import jax.numpy as jnp
import numpy as np
@jit
def f(x):
return x.reshape((np.prod(x.shape),))
f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)
因此,JAX 程序中的一个标准约定是 import numpy as np
和 import jax.numpy as jnp
,以便两个接口都可用于更精细地控制操作是以静态方式(使用 numpy
,在编译时一次)还是以追踪方式(使用 jax.numpy
,在运行时优化)执行。