核心概念#

本节简要介绍 JAX 包的一些核心概念。

转换#

除了操作数组的函数外,JAX 还包含许多对 JAX 函数进行操作的转换。其中包括:

以及其他一些。转换接受一个函数作为参数,并返回一个新的转换后的函数。例如,以下是 JIT 编译一个简单的 SELU 函数的方法:

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05

为了方便起见,您通常会看到使用 Python 的装饰器语法来应用转换:

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

追踪#

转换背后的魔法是Tracer的概念。Tracer 是数组对象的抽象替身,并被传递给 JAX 函数以提取函数编码的操作序列。

您可以通过在转换后的 JAX 代码中打印任何数组值来看到这一点;例如:

@jax.jit
def f(x):
  print(x)
  return x + 1

x = jnp.arange(5)
result = f(x)
JitTracer<int32[5]>

打印出的值不是数组 x,而是表示 x 基本属性(例如其 shapedtype)的 Tracer 实例。通过使用追踪值执行函数,JAX 可以在操作实际执行之前确定函数编码的操作序列:然后像 jit()vmap()grad() 这样的转换可以将此输入操作序列映射到转换后的操作序列。

静态操作与追踪操作#

就像值可以是静态的或追踪的,操作也可以是静态的或追踪的。静态操作在 Python 中编译时求值;追踪操作在 XLA 中运行时编译和求值。

静态值和追踪值之间的这种区别使得思考如何保持静态值静态变得很重要。考虑以下函数:

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 9
      6   return x.reshape(jnp.array(x.shape).prod())
      8 x = jnp.ones((2, 3))
----> 9 f(x)

    [... skipping hidden 13 frame]

Cell In[4], line 6, in f(x)
      4 @jit
      5 def f(x):
----> 6   return x.reshape(jnp.array(x.shape).prod())

    [... skipping hidden 2 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:457, in _compute_newshape(arr, newshape)
    455 except:
    456   newshape = [newshape]
--> 457 newshape = core.canonicalize_shape(newshape)  # type: ignore[arg-type]
    458 neg1s = [i for i, d in enumerate(newshape) if type(d) is int and d == -1]
    459 if len(neg1s) > 1:

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1924, in canonicalize_shape(shape, context)
   1922 except TypeError:
   1923   pass
-> 1924 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [JitTracer<int32[]>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_1892/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[] = reduce_prod[axes=(0,)] b
    from line /tmp/ipykernel_1892/1983583872.py:6:19 (f)

这将失败,并出现错误,指出找到了一个追踪器,而不是一个具体整数值的一维序列。让我们在函数中添加一些 print 语句来理解为什么会发生这种情况:

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)
x = JitTracer<float32[2,3]>
x.shape = (2, 3)
jnp.array(x.shape).prod() = JitTracer<int32[]>

请注意,尽管 x 是追踪的,但 x.shape 是一个静态值。然而,当我们在这个静态值上使用 jnp.arrayjnp.prod 时,它变成了一个追踪值,此时它不能用于像 reshape() 这样需要静态输入的函数(回想一下:数组的 shape 必须是静态的)。

一个有用的模式是使用 numpy 进行应该是静态的操作(即在编译时完成),并使用 jax.numpy 进行应该是追踪的操作(即在运行时编译和执行)。对于这个函数,它可能看起来像这样:

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

因此,JAX 程序中的标准约定是 import numpy as npimport jax.numpy as jnp,以便两个接口都可用于更精细地控制操作是以静态方式(使用 numpy,在编译时执行一次)还是以追踪方式(使用 jax.numpy,在运行时优化)执行。

Jaxprs#

JAX 有自己的操作序列中间表示,称为jaxpr。Jaxpr(JAX exPRession 的缩写)是函数式程序的简单表示,由一系列原始操作组成。

例如,考虑我们上面定义的 selu 函数:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

我们可以使用 jax.make_jaxpr() 实用程序,根据特定输入将此函数转换为 jaxpr:

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0:f32[]
    c:f32[5] = exp a
    d:f32[5] = mul 1.67:f32[] c
    e:f32[5] = sub d 1.67:f32[]
    f:f32[5] = jit[
      name=_where
      jaxpr={ lambda ; b:bool[5] a:f32[5] e:f32[5]. let
          f:f32[5] = select_n b e a
        in (f,) }
    ] b a e
    g:f32[5] = mul 1.05:f32[] f
  in (g,) }

与 Python 函数定义相比,我们看到它编码了函数表示的精确操作序列。我们将在JAX 内部:jaxpr 语言中更深入地探讨 jaxprs。

Pytrees#

JAX 函数和转换主要操作数组,但在实践中,使用数组集合编写代码很方便:例如,神经网络可能会将其参数组织成一个带有有意义键的数组字典。JAX 不会逐个处理这些结构,而是依赖于pytree抽象来统一处理这些集合。

以下是一些可以视为 pytree 的对象示例:

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]

JAX 有许多用于处理 PyTrees 的通用实用程序;例如,函数 jax.tree.map() 可用于将函数映射到树中的每个叶子,而 jax.tree.reduce() 可用于对树中的叶子应用归约。

您可以在使用 pytrees教程中了解更多信息。

JAX API 分层:NumPy、lax 和 XLA#

所有 JAX 操作都是根据 XLA(加速线性代数编译器)中的操作实现的。如果您查看 jax.numpy 的源代码,您会发现所有操作最终都以 jax.lax 中定义的函数形式表达。虽然 jax.numpy 是一个提供熟悉接口的高级包装器,但您可以将 jax.lax 视为一个更严格但通常更强大的低级 API,用于处理多维数组。

例如,jax.numpy 会隐式提升参数以允许混合数据类型之间的操作,而 jax.lax 不会:

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[13], line 2
      1 from jax import lax
----> 2 lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/lax/lax.py:1194, in add(x, y)
   1174 r"""Elementwise addition: :math:`x + y`.
   1175 
   1176 This function lowers directly to the `stablehlo.add`_ operation.
   (...)   1191 .. _stablehlo.add: https://openxla.org/stablehlo/spec#add
   1192 """
   1193 x, y = core.standard_insert_pvary(x, y)
-> 1194 return add_p.bind(x, y)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:576, in Primitive.bind(self, *args, **params)
    574 def bind(self, *args, **params):
    575   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 576   return self._true_bind(*args, **params)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:592, in Primitive._true_bind(self, *args, **params)
    590 trace_ctx.set_trace(eval_trace)
    591 try:
--> 592   return self.bind_with_trace(prev_trace, args, params)
    593 finally:
    594   trace_ctx.set_trace(prev_trace)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:602, in Primitive.bind_with_trace(self, trace, args, params)
    599   with set_current_trace(trace):
    600     return self.to_lojax(*args, **params)  # type: ignore
--> 602 return trace.process_primitive(self, args, params)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/core.py:1126, in EvalTrace.process_primitive(self, primitive, args, params)
   1124 args = map(full_lower, args)
   1125 check_eval_args(args)
-> 1126 return primitive.impl(*args, **params)

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/dispatch.py:91, in apply_primitive(prim, *args, **params)
     89 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     90 try:
---> 91   outs = fun(*args)
     92 finally:
     93   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 26 frame]

File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/lax/lax.py:8747, in check_same_dtypes(name, *avals)
   8745   equiv = _JNP_FUNCTION_EQUIVALENTS[name]
   8746   msg += f" (Tip: jnp.{equiv} is a similar function that does automatic type promotion on inputs)."
-> 8747 raise TypeError(msg.format(name, ", ".join(str(a.dtype) for a in avals)))

TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).

如果直接使用 jax.lax,在这种情况下您必须显式地进行类型提升:

lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)

除了这种严格性之外,jax.lax 还提供了一些比 NumPy 支持的更通用操作的高效 API。

例如,考虑一个 1D 卷积,它可以在 NumPy 中这样表示:

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

在底层,这个 NumPy 操作被转换为一个由 lax.conv_general_dilated 实现的更通用的卷积:

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

这是一个批处理卷积操作,旨在对深度神经网络中常用的卷积类型高效。它需要更多的样板代码,但比 NumPy 提供的卷积更灵活和可扩展(有关 JAX 卷积的更多详细信息,请参阅JAX 中的卷积)。

从本质上讲,所有 jax.lax 操作都是 XLA 中操作的 Python 包装器;例如,这里的卷积实现由 XLA:ConvWithGeneralPadding 提供。每个 JAX 操作最终都以这些基本的 XLA 操作的形式表达,这使得即时 (JIT) 编译成为可能。