如何在 JAX 中思考#

Open in Colab Open in Kaggle

JAX 提供了一个简单而强大的 API 用于编写加速数值代码,但有效地使用 JAX 有时需要额外的考虑。本文档旨在帮助构建对 JAX 操作方式的从头到尾的理解,以便您可以更有效地使用它。

JAX 与 NumPy#

关键概念

  • JAX 提供了一个受 NumPy 启发的接口,以方便使用。

  • 通过鸭子类型,JAX 数组通常可以用作 NumPy 数组的直接替换。

  • 与 NumPy 数组不同,JAX 数组始终是不可变的。

NumPy 提供了一个众所周知的强大 API 用于处理数值数据。为了方便起见,JAX 提供了 jax.numpy,它与 numpy API 非常相似,并提供了进入 JAX 的便捷途径。几乎所有可以使用 numpy 完成的操作都可以使用 jax.numpy 完成。

import matplotlib.pyplot as plt
import numpy as np

x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);
../_images/ed74117ce798d02f04559155709be03bef63cfa850e6af47b918884ed471961f.png
import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
../_images/7566bdaf92d2f9beda43e4d3ddee916a69d996ec12b671de077d49428fd54fd2.png

除了将 np 替换为 jnp 外,代码块是相同的,结果也是一样的。正如我们所看到的,JAX 数组通常可以直接替代 NumPy 数组用于绘图等操作。

数组本身是作为不同的 Python 类型实现的。

type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib.xla_extension.ArrayImpl

Python 的 鸭子类型 允许 JAX 数组和 NumPy 数组在许多地方互换使用。

然而,JAX 数组和 NumPy 数组之间存在一个重要的区别:JAX 数组是不可变的,这意味着一旦创建,其内容就无法更改。

以下是在 NumPy 中修改数组的示例

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10  1  2  3  4  5  6  7  8  9]

JAX 中的等效操作会导致错误,因为 JAX 数组是不可变的

%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.net.cn/en/latest/_autosummary/jax.numpy.ndarray.at.html

为了更新单个元素,JAX 提供了一种 索引更新语法,该语法会返回一个更新后的副本

y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]

NumPy、lax 和 XLA:JAX API 分层#

关键概念

  • jax.numpy 是一个高级包装器,提供了一个熟悉的接口。

  • jax.lax 是一个更低级的 API,它更加严格,通常功能更强大。

  • 所有 JAX 操作都是根据 XLA(加速线性代数编译器)中的操作实现的。

如果你查看 jax.numpy 的源代码,你会发现所有操作最终都用 jax.lax 中定义的函数来表示。你可以将 jax.lax 视为一个更严格但通常功能更强大的用于处理多维数组的 API。

例如,虽然 jax.numpy 会隐式提升参数以允许混合数据类型之间的运算,但 jax.lax 不会。

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.
ValueError: Cannot lower jaxpr with verifier errors:
	op requires the same element type for all operands and results
		at loc("jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_2935/3435837498.py":2:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at callsite("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":128:0) at callsite("_run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3130:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3075:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/zmqshell.py":549:0) at callsite("do_execute"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/ipkernel.py":449:0) at "execute_request"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/kernelbase.py":778:0))))))))))))
Define JAX_DUMP_IR_TO to dump the module.

如果直接使用 jax.lax,则必须在这些情况下显式进行类型提升

lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)

除了这种严格性之外,jax.lax 还为一些比 NumPy 支持的更通用的操作提供了高效的 API。

例如,考虑一个一维卷积,它可以用 NumPy 这样表示

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

在后台,此 NumPy 操作被转换为由 lax.conv_general_dilated 实现的更通用的卷积。

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

这是一个批处理卷积操作,旨在针对深度神经网络中经常使用的卷积类型进行优化。它需要更多的样板代码,但比 NumPy 提供的卷积更灵活和可扩展(有关 JAX 卷积的更多详细信息,请参阅 JAX 中的卷积)。

从本质上讲,所有 jax.lax 操作都是 XLA 中操作的 Python 包装器;例如,此处卷积实现由 XLA:ConvWithGeneralPadding 提供。每个 JAX 操作最终都用这些基本的 XLA 操作表示,这正是 JIT(即时)编译能够实现的原因。

是否使用 JIT#

关键概念

  • 默认情况下,JAX 按顺序一次执行一个操作。

  • 使用即时 (JIT) 编译装饰器,可以将操作序列一起优化并一次运行。

  • 并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状在编译时是静态的且已知的。

所有 JAX 操作都用 XLA 表示这一事实允许 JAX 使用 XLA 编译器非常高效地执行代码块。

例如,考虑以下函数,该函数使用 jax.numpy 操作表示对二维矩阵的行进行归一化

import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

可以使用 jax.jit 转换创建函数的即时编译版本

from jax import jit
norm_compiled = jit(norm)

此函数返回与原始函数相同的结果,精确到标准浮点精度

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True

但是由于编译(包括操作融合、避免分配临时数组以及其他许多技巧),在 JIT 编译的情况下,执行时间可能会快几个数量级(请注意使用 block_until_ready() 来考虑 JAX 的 异步调度)。

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
942 μs ± 6.93 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
512 μs ± 2.88 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

也就是说,jax.jit 确实有一些限制:特别是,它要求所有数组都具有静态形状。这意味着某些 JAX 操作与 JIT 编译不兼容。

例如,此操作可以在逐操作模式下执行

def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

但是,如果尝试在 jit 模式下执行它,则会返回错误

jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

这是因为该函数生成一个形状在编译时未知的数组:输出的大小取决于输入数组的值,因此它与 JIT 不兼容。

JIT 机制:跟踪和静态变量#

关键概念

  • JIT 和其他 JAX 转换通过跟踪函数来确定其对特定形状和类型的输入的影响。

  • 不想被跟踪的变量可以标记为静态

要有效地使用 jax.jit,了解其工作原理很有用。让我们在 JIT 编译的函数中添加一些 print() 语句,然后调用该函数

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32)

请注意,print 语句会执行,但它不会打印我们传递给函数的数据,而是打印跟踪器对象作为其替代。

这些跟踪器对象是 jax.jit 用于提取函数指定的操作序列的内容。基本跟踪器是占位符,用于编码数组的**形状**和**数据类型**,但与值无关。然后,可以在 XLA 中高效地将此记录的计算序列应用于具有相同形状和数据类型的新的输入,而无需重新执行 Python 代码。

当我们再次对匹配的输入调用已编译函数时,不需要重新编译,也不会打印任何内容,因为结果是在已编译的 XLA 中计算的,而不是在 Python 中计算的。

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)

提取的操作序列编码在 JAX 表达式或简称为jaxpr 中。可以使用 jax.make_jaxpr 转换查看 jaxpr

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

请注意由此产生的一点:因为 JIT 编译是在没有数组内容信息的情况下完成的,所以函数中的控制流语句不能依赖于跟踪的值。例如,这会失败

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_2935/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

如果有一些变量你不想被跟踪,可以将其标记为静态以用于 JIT 编译

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

请注意,使用不同的静态参数调用 JIT 编译的函数会导致重新编译,因此该函数仍按预期工作

f(1, False)
Array(1, dtype=int32, weak_type=True)

了解哪些值和操作将是静态的,哪些将被跟踪,是有效使用 jax.jit 的关键部分。

静态操作与跟踪操作#

关键概念

  • 就像值可以是静态的或被跟踪的一样,操作也可以是静态的或被跟踪的。

  • 静态操作在 Python 中编译时进行评估;跟踪操作在 XLA 中编译和运行时进行评估。

  • 对于希望保持静态的操作,请使用 numpy;对于希望被跟踪的操作,请使用 jax.numpy

静态值和跟踪值之间的这种区别使得思考如何保持静态值静态变得很重要。考虑以下函数

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_2935/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /tmp/ipykernel_2935/1983583872.py:6 (f)

这会失败并出现错误,指出找到了跟踪器而不是整数类型具体值的 1D 序列。让我们在函数中添加一些 print 语句来了解为什么会发生这种情况

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>

请注意,虽然 x 被跟踪,但 x.shape 是一个静态值。但是,当我们在此静态值上使用 jnp.arrayjnp.prod 时,它会变成一个跟踪值,此时它不能在像 reshape() 这样的需要静态输入的函数中使用(回想一下:数组形状必须是静态的)。

一个有用的模式是使用 numpy 进行应为静态的操作(即在编译时完成),并使用 jax.numpy 进行应被跟踪的操作(即在运行时编译和执行)。对于此函数,它可能如下所示

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

因此,JAX 程序中的一个标准约定是 import numpy as npimport jax.numpy as jnp,以便这两个接口都可用,以便更精细地控制操作是在静态方式(使用 numpy,在编译时一次)还是跟踪方式(使用 jax.numpy,在运行时优化)下执行。