JEP 28661:支持 __jax_array__ 协议#
@jakevdp, 2025年5月
用户偶尔会提出一项需求,即定义可与 JAX API 协作的自定义类数组对象。JAX 目前通过在自定义对象上定义 __jax_array__ 方法,部分实现了这一机制。这本意并非作为一种承载核心功能的公共 API(参见 #4725 的讨论),但它已成为 Keras 和 Flax 等软件包不可或缺的一部分,这些包明确记录了将其自定义数组对象与 JAX 函数配合使用的能力。本 JEP 旨在为 __jax_array__ 协议提供完整且正式的文档化支持。
数组可扩展性级别#
对 JAX 数组可扩展性的需求主要有几种类型:
一级可扩展性:多态输入#
我将“一级”可扩展性定义为 JAX API 接受多态输入的需求。即用户期望实现如下行为:
class CustomArray:
data: numpy.ndarray
...
x = CustomArray(np.arange(5))
result = jnp.sin(x) # Converts `x` to JAX array and returns a JAX array
在此可扩展性模型下,JAX 函数将接受 CustomArray 对象作为输入,并在计算时将其隐式转换为 jax.Array 对象。这类似于 NumPy 通过 __array__ 方法,以及 JAX(在多数但非全部情况下)通过 __jax_array__ 方法所提供的功能。
这是 flax.nnx 及其他项目维护者所请求的可扩展性模式。当前的实现也已在 JAX 内部用于处理符号维度的情况。
二级可扩展性:多态输出#
我将“二级”可扩展性定义为不仅要求 JAX API 接受多态输入,还要将输出包装回与输入相同的类。即用户期望实现如下行为:
class CustomArray:
data: numpy.ndarray
...
x = CustomArray(np.arange(5))
result = jnp.sin(x) # returns a new CustomArray
在此可扩展性模型下,JAX 函数不仅接受自定义对象作为输入,还拥有某种协议来确定如何将输出正确地重新包装为同一类。在 NumPy 中,这种功能通过特殊的 __array_ufunc__、__array_wrap__ 和 __array_function__ 协议在不同程度上实现,允许用户定义的对象自定义 NumPy API 函数处理任意输入并将输入类型映射到输出的方式。JAX 目前还没有等同于 NumPy 这些接口的功能。
这是 keras 等维护者所请求的可扩展性模式。
三级可扩展性:子类化 Array#
我将“三级”可扩展性定义为用户希望 JAX 数组对象本身可以被子类化。NumPy 提供了一些允许此操作的 API(参见 Subclassing ndarray),但考虑到 JAX 需要通过追踪(tracing)抽象地表示数组对象,这种方法在 JAX 中需要更深思熟虑。
这种可扩展性模式偶尔会被想要为 JAX 数组添加特殊元数据(如计量单位)的用户所要求。
概要#
为了本提案的目的,我们将坚持最简单的“一级”可扩展性模型。提议的接口是目前许多 JAX API 非统一支持的 __jax_array__ 方法。其用法如下所示:
import jax
import jax.numpy as jnp
import numpy as np
class CustomArray:
data: np.ndarray
def __init__(self, data: np.ndarray):
self.data = data
def __jax_array__(self) -> jax.Array:
return jnp.asarray(self.data)
arr = CustomArray(np.arange(5))
result = jnp.multiply(arr, 2)
print(repr(result))
# Array([0, 2, 4, 6, 8], dtype=int32)
我们可能会在未来重新探讨其他可扩展性级别。
设计挑战#
JAX 在此类可扩展性方面提出了一些有趣的设计挑战,这些挑战此前尚未得到充分探索。我们将在此逐一讨论。
__jax_array__ 与 PyTree 展平的优先级#
JAX 已经拥有一种支持注册自定义对象的机制,即 PyTree 注册(参见 自定义 PyTree 节点)。如果我们同时支持 jax_array,哪一个应该优先?
更具体地说,这段代码的结果应该是什么?
@jax.jit
def f(x):
print("is JAX array:", isinstance(x, jax.Array))
f(CustomArray(...))
如果我们选择在 JIT 边界优先处理 __jax_array__,那么此函数的输出将是:
is JAX array: True
也就是说,在 JIT 边界,CustomArray 对象会被转换为 __jax_array__,其形状和数据类型将用于为该函数构建一个标准的 JAX 追踪器(tracer)。
如果我们选择在 JIT 边界优先处理 PyTree 展平,那么此函数的输出将是:
type(x)=CustomArray
也就是说,在 JIT 边界,CustomArray 对象会被展平,然后在传递给 JIT 编译函数进行追踪之前重新展开。如果 CustomArray 已注册为 PyTree,它通常会包含被追踪的数组作为其属性;当 x 被传递给任何支持 __jax_array__ 的 JAX API 时,这些追踪属性将根据该方法中指定的逻辑转换为单个被追踪的数组。
对于 vmap 和 grad 等其他转换在遇到自定义对象时的工作方式,这里有更深远的影响:例如,如果我们优先考虑 PyTree 展平,vmap 将对自定义对象的展平内容维度进行操作;而如果我们优先考虑 __jax_array__,vmap 将对转换后的数组维度进行操作。
这在 JIT 不变性方面也会产生影响:考虑如下函数:
def f(x):
if isinstance(x, CustomArray):
return x.custom_method()
else:
# do something else
...
result1 = f(x)
result2 = jax.jit(f)(x)
如果 jit 通过 PyTree 展平来处理 x,那么对于明确定义的展平规则,结果应该是一致的。如果 jit 通过 __jax_array__ 处理 x,结果就会不同,因为在函数经过 JIT 编译的版本中,x 不再是 CustomArray。
概要#
自 JAX v0.6.0 起,变换在可用时会优先考虑 __jax_array__。这种现状可能导致围绕 JIT 不变性缺失的困惑;并且在实践中,当前的实现会在自动微分(AD)中导致微妙的 Bug,即前向和反向传递处理输入的方式不一致。
由于 PyTree 可扩展性机制已经存在于自定义变换中,最直接的做法是让变换仅通过该机制作用:即我们提议在抽象化(abstractification)过程中移除对 __jax_array__ 的解析。这种方法将保留转换过程中的对象标识,并给予用户最大的灵活性。如果用户想要选择数组转换语义,可以通过 jnp.asarray 显式转换输入来实现,这会触发 __jax_array__ 协议。
哪些 API 应该支持 __jax_array__?#
JAX 拥有多个不同级别的 API,从显式原语绑定(如 jax.lax.add_p.bind(x, y)),到 jax.lax API(如 jax.lax.add(x, y)),再到 jax.numpy API(如 jax.numpy.add(x, y))。这些 API 类别中,哪些应该处理通过 __jax_array__ 进行的隐式转换?
为了限制变更范围和测试需求,我建议 __jax_array__ 仅在 jax.numpy API 中显式支持:毕竟,它受到 NumPy 软件包所支持的 __array__ 协议的启发。如果需要,我们将来总是可以将其扩展到 jax.lax API。
这与软件包的当前状态一致,即 __jax_array__ 的处理主要位于 jax.numpy API 所使用的输入验证工具中。
实现#
基于这些设计选择,我们计划如下实施:
为
jax.numpy添加运行时支持:这可能是最容易的部分,因为大多数jax.numpy函数都使用一个通用的内部工具(ensure_arraylike)来验证输入并将其转换为数组。该工具已经支持__jax_array__,因此大多数jax.numpyAPI 已经合规。增加测试覆盖率:为了确保各 API 的合规性,我们应该添加一个新的测试框架,使用自定义输入调用每个
jax.numpyAPI,并验证其行为是否正确。废弃抽象化过程中的
__jax_array__:目前 JAX 的抽象化过程(用于jit和其他转换)确实会解析__jax_array__协议,但这并非我们长远想要的行为。我们需要废弃此行为,并确保依赖它的下游软件包能够在必要时转向 PyTree 注册或显式数组转换。添加类型注解:
jax.numpy函数的类型接口位于jax/numpy/__init__.pyi中,我们需要将每个输入类型从ArrayLike改为ArrayLike | SupportsJAXArray,其中后者是一个带有__jax_array__方法的协议。我们不能将其直接添加到ArrayLike的定义中,因为ArrayLike也被用于不应支持__jax_array__的上下文中。文档:一旦添加了上述支持,我们应该增加一个关于数组可扩展性的文档章节,概述有关
__jax_array__协议的准确预期,并举例说明如何将其与 PyTree 注册结合使用,从而有效地与用户定义类型协作。