关键概念#

本节简要介绍 JAX 包的一些关键概念。

变换#

除了用于操作数组的函数外,JAX 还包含许多用于操作 JAX 函数的 变换。这些变换包括:

以及其他一些。变换接受一个函数作为参数,并返回一个新的变换后的函数。例如,你可以这样 JIT 编译一个简单的 SELU 函数:

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05

为了方便起见,你经常会看到使用 Python 的装饰器语法来应用变换。

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

追踪#

变换背后的魔法是 Tracer 的概念。Tracers 是数组对象的抽象占位符,它们被传递给 JAX 函数,以提取函数所编码的操作序列。

你可以通过打印变换后的 JAX 代码中的任何数组值来看到这一点;例如:

@jax.jit
def f(x):
  print(x)
  return x + 1

x = jnp.arange(5)
result = f(x)
JitTracer<int32[5]>

打印的值不是数组 x,而是一个 Tracer 实例,它代表了 x 的基本属性,例如它的 shapedtype。通过使用追踪值执行函数,JAX 可以在操作实际执行之前确定函数所编码的操作序列:然后,像 jit()vmap()grad() 这样的变换就可以将这个输入操作序列映射到一个变换后的操作序列。

静态 vs 追踪操作:就像值可以是静态的或追踪的,操作也可以是静态的或追踪的。静态操作在 Python 的编译时进行评估;追踪操作在 XLA 的运行时进行编译和评估。

有关更多详细信息,请参阅 追踪

Jaxprs#

JAX 有自己的操作序列中间表示,称为 jaxpr。Jaxpr(JAX exPRession 的缩写)是一个函数式程序的简单表示,由一系列 原始 操作组成。

例如,考虑我们上面定义的 selu 函数:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

给定一个特定的输入,我们可以使用 jax.make_jaxpr() 工具将此函数转换为 jaxpr:

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0:f32[]
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558:f32[] c
    e:f32[5] = sub d 1.6699999570846558:f32[]
    f:f32[5] = jit[
      name=_where
      jaxpr={ lambda ; b:bool[5] a:f32[5] e:f32[5]. let
          f:f32[5] = select_n b e a
        in (f,) }
    ] b a e
    g:f32[5] = mul 1.0499999523162842:f32[] f
  in (g,) }

将此与 Python 函数定义进行比较,我们可以看到它编码了函数所代表的操作的精确序列。我们将在 JAX 内部:jaxpr 语言 中更深入地探讨 jaxprs。

Pytrees#

JAX 函数和变换从根本上操作数组,但在实践中,编写处理数组集合的代码很方便:例如,神经网络可能会使用带有有意义键的数组字典来组织其参数。JAX 不会逐个处理这些结构,而是依赖 pytree 抽象来统一处理这些集合。

以下是一些可以被视为 pytrees 的对象示例:

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]

JAX 有许多通用的实用程序可用于处理 PyTrees;例如,函数 jax.tree.map() 可用于将函数映射到树中的每个叶节点,而 jax.tree.reduce() 可用于在树的叶节点之间应用规约。

你可以在 使用 pytrees 教程中了解更多信息。

JAX API 分层:NumPy, lax & XLA#

所有 JAX 操作都基于 XLA(加速线性代数编译器)中的操作实现。如果你查看 jax.numpy 的源代码,你会发现所有操作最终都表示为在 jax.lax 中定义的功能。虽然 jax.numpy 是一个提供熟悉接口的高级包装器,但你可以将 jax.lax 视为一个更严格但通常更强大的低级 API,用于处理多维数组。

例如,虽然 jax.numpy 会隐式提升参数以允许混合数据类型之间的操作,但 jax.lax 则不会:

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 2
      1 from jax import lax
----> 2 lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/lax/lax.py:1199, in add(x, y)
   1179 r"""Elementwise addition: :math:`x + y`.
   1180 
   1181 This function lowers directly to the `stablehlo.add`_ operation.
   (...)   1196 .. _stablehlo.add: https://openxla.org/stablehlo/spec#add
   1197 """
   1198 x, y = core.standard_insert_pvary(x, y)
-> 1199 return add_p.bind(x, y)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:632, in Primitive.bind(self, *args, **params)
    630 def bind(self, *args, **params):
    631   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 632   return self._true_bind(*args, **params)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:648, in Primitive._true_bind(self, *args, **params)
    646 trace_ctx.set_trace(eval_trace)
    647 try:
--> 648   return self.bind_with_trace(prev_trace, args, params)
    649 finally:
    650   trace_ctx.set_trace(prev_trace)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:660, in Primitive.bind_with_trace(self, trace, args, params)
    658     with set_current_trace(trace):
    659       return self.to_lojax(*args, **params)  # type: ignore
--> 660   return trace.process_primitive(self, args, params)
    661 trace.process_primitive(self, args, params)  # may raise lojax error
    662 raise Exception(f"couldn't apply typeof to args: {args}")

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1189, in EvalTrace.process_primitive(self, primitive, args, params)
   1187 args = map(full_lower, args)
   1188 check_eval_args(args)
-> 1189 return primitive.impl(*args, **params)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/dispatch.py:94, in apply_primitive(prim, *args, **params)
     92 prev = config.disable_jit.swap_local(False)
     93 try:
---> 94   outs = fun(*args)
     95 finally:
     96   config.disable_jit.set_local(prev)

    [... skipping hidden 26 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/lax/lax.py:8878, in check_same_dtypes(name, *avals)
   8876   equiv = _JNP_FUNCTION_EQUIVALENTS[name]
   8877   msg += f" (Tip: jnp.{equiv} is a similar function that does automatic type promotion on inputs)."
-> 8878 raise TypeError(msg.format(name, ", ".join(str(a.dtype) for a in avals)))

TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).

如果直接使用 jax.lax,在这种情况下你将不得不显式地进行类型提升。

lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)

除了这种严格性之外,jax.lax 还为一些比 NumPy 支持的更通用的操作提供了高效的 API。

例如,考虑一个一维卷积,可以用 NumPy 这样表示:

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

在底层,这个 NumPy 操作被翻译成一个由 lax.conv_general_dilated 实现的更通用的卷积:

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

这是一个批处理卷积操作,旨在高效处理深度神经网络中常用的卷积类型。它需要更多的样板代码,但比 NumPy 提供的卷积更灵活、更具可伸缩性(有关 JAX 中卷积的更多详细信息,请参阅 JAX 中的卷积)。

本质上,所有 jax.lax 操作都是 XLA 操作的 Python 包装器;例如,这里的卷积实现由 XLA:ConvWithGeneralPadding 提供。每个 JAX 操作最终都表示为这些基本的 XLA 操作,这就是实现即时 (JIT) 编译的原因。