关键概念#
本节简要介绍 JAX 软件包的一些关键概念。
变换 (Transformations)#
除了用于操作数组的函数外,JAX 还包含许多作用于 JAX 函数的变换。这些变换包括:
jax.vmap():向量化变换;请参阅 自动向量化jax.grad():梯度变换;请参阅 自动微分
以及其他一些变换。变换接受一个函数作为参数,并返回一个新的变换后的函数。例如,以下是如何对简单的 SELU 函数进行 JIT 编译:
import jax
import jax.numpy as jnp
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05
通常为了方便起见,你会看到使用 Python 的装饰器语法来应用这些变换:
@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
Tracing#
变换背后的核心机制是 Tracer(追踪器)的概念。Tracer 是数组对象的抽象替代品,它们被传递给 JAX 函数,以便提取该函数所编码的操作序列。
你可以通过在变换后的 JAX 代码中打印任何数组值来观察这一点;例如:
@jax.jit
def f(x):
print(x)
return x + 1
x = jnp.arange(5)
result = f(x)
JitTracer(int32[5])
打印出的值不是数组 x,而是一个代表 x 基本属性(如其 shape 和 dtype)的 Tracer 实例。通过使用追踪值执行函数,JAX 可以在这些操作实际执行之前确定该函数编码的操作序列:诸如 jit()、vmap() 和 grad() 等变换,随后可以将此输入操作序列映射为变换后的操作序列。
静态操作与追踪操作:正如值可以是静态的或经过追踪的,操作也可以是静态的或经过追踪的。静态操作在 Python 编译时进行求值;追踪操作在 XLA 中进行编译并在运行时进行求值。
有关更多详细信息,请参阅 追踪 (Tracing)。
Jaxprs#
JAX 有其自己的操作序列中间表示形式,称为 jaxpr。jaxpr(JAX exPRession 的缩写)是函数式程序的一种简单表示形式,由一系列基本操作组成。
例如,考虑我们上面定义的 selu 函数:
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
我们可以使用 jax.make_jaxpr() 工具,根据特定的输入将此函数转换为 jaxpr:
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
b:bool[5] = gt a 0.0:f32[]
c:f32[5] = exp a
d:f32[5] = mul 1.6699999570846558:f32[] c
e:f32[5] = sub d 1.6699999570846558:f32[]
f:f32[5] = jit[
name=_where
jaxpr={ lambda ; b:bool[5] a:f32[5] e:f32[5]. let
f:f32[5] = select_n b e a
in (f,) }
] b a e
g:f32[5] = mul 1.0499999523162842:f32[] f
in (g,) }
将其与 Python 函数定义进行比较,我们可以看到它精确地编码了该函数所代表的操作序列。我们稍后将在 JAX 内部机制:jaxpr 语言 中更深入地探讨 jaxpr。
Pytrees#
JAX 函数和变换从根本上是在数组上进行操作的,但在实践中,编写能够处理数组集合的代码通常更方便:例如,神经网络可能会将其参数组织在一个带有有意义键的数组字典中。JAX 并不逐个处理此类结构,而是依赖 pytree 抽象来以统一的方式处理这些集合。
以下是一些可以被视为 pytree 的对象示例:
# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
[1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple
class Params(NamedTuple):
a: int
b: float
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]
JAX 拥有许多用于处理 PyTree 的通用工具;例如,函数 jax.tree.map() 可用于将函数映射到树中的每个叶子节点,而 jax.tree.reduce() 可用于跨树中的叶子节点执行归约操作。
你可以通过 Pytrees 教程了解更多信息。
JAX API 分层:NumPy、lax 和 XLA#
所有 JAX 操作均通过 XLA(加速线性代数编译器)中的操作来实现。如果你查看 jax.numpy 的源代码,你会发现所有操作最终都会被转换为 jax.lax 中定义的函数。虽然 jax.numpy 是一个提供熟悉接口的高级封装器,但你可以将 jax.lax 看作是一个用于处理多维数组的更严格、但通常更强大的低级 API。
例如,虽然 jax.numpy 会隐式提升参数以允许混合数据类型之间的操作,但 jax.lax 不会:
import jax.numpy as jnp
jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0) # jax.lax API requires explicit type promotion.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[10], line 2
1 from jax import lax
----> 2 lax.add(1, 1.0) # jax.lax API requires explicit type promotion.
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/lax/lax.py:1164, in add(x, y)
1144 r"""Elementwise addition: :math:`x + y`.
1145
1146 This function lowers directly to the `stablehlo.add`_ operation.
(...) 1161 .. _stablehlo.add: https://openxla.cn/stablehlo/spec#add
1162 """
1163 x, y = core.standard_insert_pvary(x, y)
-> 1164 return add_p.bind(x, y)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:658, in Primitive.bind(self, *args, **params)
656 trace_ctx.set_trace(eval_trace)
657 try:
--> 658 return self.bind_with_trace(prev_trace, args, avals, params)
659 finally:
660 trace_ctx.set_trace(prev_trace)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:666, in Primitive.bind_with_trace(self, trace, args, avals, params)
664 with set_current_trace(trace):
665 return self.to_lojax(*args, **params) # pyrefly: ignore[not-callable]
--> 666 return trace.process_primitive(self, args, params)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1224, in EvalTrace.process_primitive(self, primitive, args, params)
1222 args = map(full_lower, args)
1223 check_eval_args(args)
-> 1224 return primitive.impl(*args, **params)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/dispatch.py:90, in apply_primitive(prim, *args, **params)
88 prev = config.disable_jit.swap_local(False)
89 try:
---> 90 outs = fun(*args)
91 finally:
92 config.disable_jit.set_local(prev)
[... skipping hidden 15 frame]
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/lax/lax.py:8956, in check_same_dtypes(name, *avals)
8954 equiv = _JNP_FUNCTION_EQUIVALENTS[name]
8955 msg += f" (Tip: jnp.{equiv} is a similar function that does automatic type promotion on inputs)."
-> 8956 raise TypeError(msg.format(name, ", ".join(str(a.dtype) for a in avals)))
TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).
如果直接使用 jax.lax,在这种情况下你必须显式进行类型提升:
lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)
除了这种严格性之外,jax.lax 还为一些比 NumPy 支持的更通用的操作提供了高效的 API。
例如,考虑一个一维卷积,可以用 NumPy 按如下方式表示:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
在底层,该 NumPy 操作被转换为由 lax.conv_general_dilated 实现的更通用的卷积操作。
from jax import lax
result = lax.conv_general_dilated(
x.reshape(1, 1, 3).astype(float), # note: explicit promotion
y.reshape(1, 1, 10),
window_strides=(1,),
padding=[(len(y) - 1, len(y) - 1)]) # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
这是一个批量卷积操作,旨在针对深度神经网络中常用的卷积类型进行优化。它需要更多的样板代码,但比 NumPy 提供的卷积更加灵活和可扩展(有关 JAX 卷积的更多详细信息,请参阅 JAX 中的卷积)。
归根结底,所有 jax.lax 操作都是 XLA 操作的 Python 封装器;例如,此处的卷积实现由 XLA:ConvWithGeneralPadding 提供。每个 JAX 操作最终都表示为这些基本的 XLA 操作,这正是实现即时 (JIT) 编译的基础。