JEP 28661: 支持 __jax_array__
协议#
@jakevdp, 2025 年 5 月
偶尔有用户请求定义可与 JAX API 配合使用的自定义类数组对象的能力。JAX 当前通过在自定义对象上定义 __jax_array__
方法来部分实现这一机制。这从未打算成为一个负载关键的公共 API(参见 #4725 上的讨论),但对于 Keras 和 Flax 等明确说明其自定义数组对象可与 JAX 函数一起使用的包来说,它已变得至关重要。本 JEP 提出了一个设计方案,旨在全面、正式地支持 __jax_array__
协议。
数组可扩展性级别#
对 JAX 数组可扩展性的请求有多种形式
级别 1 可扩展性:多态输入#
我称之为“级别 1”的可扩展性是指 JAX API 接受多态输入的愿望。也就是说,用户期望这样的行为
class CustomArray:
data: numpy.ndarray
...
x = CustomArray(np.arange(5))
result = jnp.sin(x) # Converts `x` to JAX array and returns a JAX array
在这种可扩展性模型下,JAX 函数将接受 CustomArray 对象作为输入,为了计算而隐式地将它们转换为 jax.Array
对象。这类似于 NumPy 通过 __array__
方法提供的功能,以及 JAX(在许多但并非所有情况下)通过 __jax_array__
方法提供的功能。
这是 flax.nnx
及其他维护者所请求的可扩展性模式。当前的实现也由 JAX 内部用于符号维度的情况。
级别 2 可扩展性:多态输出#
我称之为“级别 2”的可扩展性是指 JAX API 不仅应该接受多态输入,还应该封装输出以匹配输入的类。也就是说,用户期望这样的行为
class CustomArray:
data: numpy.ndarray
...
x = CustomArray(np.arange(5))
result = jnp.sin(x) # returns a new CustomArray
在这种可扩展性模型下,JAX 函数不仅会接受自定义对象作为输入,还会通过某种协议来确定如何正确地将输出重新封装为相同的类。在 NumPy 中,这种功能在不同程度上由特殊的 __array_ufunc__
、__array_wrap__
和 __array_function__
协议提供,这些协议允许用户定义的对象自定义 NumPy API 函数如何对任意输入进行操作,以及如何将输入类型映射到输出。JAX 目前没有与 NumPy 中的这些接口等效的功能。
这是 Keras 等维护者所请求的可扩展性模式。
级别 3 可扩展性:子类化 Array
#
我称之为“级别 3”的可扩展性是指 JAX 数组对象本身可以被子类化。NumPy 提供了一些允许这样做的 API(参见 子类化 ndarray),但由于需要通过跟踪抽象地表示数组对象,这种方法在 JAX 中需要额外考虑。
这种可扩展性模式偶尔有用户请求,他们希望向 JAX 数组添加特殊元数据,例如测量单位。
概述#
为了本提案的目的,我们将坚持最简单的级别 1 可扩展性模型。提议的接口是目前许多 JAX API 非统一支持的 __jax_array__
方法。它的用法如下所示
import jax
import jax.numpy as jnp
import numpy as np
class CustomArray:
data: np.ndarray
def __init__(self, data: np.ndarray):
self.data = data
def __jax_array__(self) -> jax.Array:
return jnp.asarray(self.data)
arr = CustomArray(np.arange(5))
result = jnp.multiply(arr, 2)
print(repr(result))
# Array([0, 2, 4, 6, 8], dtype=int32)
我们将来可能会重新审视其他可扩展性级别。
设计挑战#
JAX 提出了与这种可扩展性相关的一些有趣的设计挑战,之前尚未完全探讨过。我们将在此依次讨论它们
__jax_array__
与 PyTree 展平的优先级#
JAX 已经有一个支持的自定义对象注册机制,即 pytree 注册(参见 扩展 pytrees)。如果我们也支持 jax_array,那么哪一个应该优先?
更具体地说,这段代码的结果应该是什么?
@jax.jit
def f(x):
print("is JAX array:", isinstance(x, jax.Array))
f(CustomArray(...))
如果我们选择在 JIT 边界处优先处理 __jax_array__
,那么此函数的输出将是
is JAX array: True
也就是说,在 JIT 边界处,CustomArray
对象将被转换为 __jax_array__
,并且其形状和 dtype 将用于为函数构造一个标准的 JAX 追踪器。
如果我们选择在 JIT 边界处优先处理 pytree 展平,那么此函数的输出将是
type(x)=CustomArray
也就是说,在 JIT 边界处,CustomArray
对象被展平,然后在传递给 JIT 编译的函数进行追踪之前再反展平。如果 CustomArray
已注册为 pytree,它通常会将其追踪的数组作为属性,当 x 传递给任何支持 __jax_array__
的 JAX API 时,这些追踪的属性将根据方法中指定的逻辑转换为单个追踪数组。
对于 vmap 和 grad 等其他变换在遇到自定义对象时如何工作,这里有更深层次的后果:例如,如果我们优先处理 pytree 展平,vmap 将对自定义对象的展平内容的维度进行操作,而如果我们优先处理 __jax_array__
,vmap 将对转换后的数组维度进行操作。
这对于 JIT 不变性也有影响:考虑这样一个函数
def f(x):
if isinstance(x, CustomArray):
return x.custom_method()
else:
# do something else
...
result1 = f(x)
result2 = jax.jit(f)(x)
如果 jit
通过 pytree 展平来处理 x
,那么结果应该与明确指定的展平规则一致。如果 jit
通过 __jax_array__
来处理 x
,结果将不同,因为在 JIT 编译的函数版本中,x
不再是 CustomArray。
概述#
截至 JAX v0.6.0 版本,变换在 __jax_array__
可用时会优先处理它。这种现状可能导致对 JIT 不变性缺失的困惑,并且当前的实际实现导致自动微分中出现细微的错误,即前向和后向传播对输入的处理不一致。
因为 pytree 可扩展性机制已经存在用于定制变换,所以如果变换只通过此机制进行,则显得最直接:也就是说,**我们建议在抽象化过程中移除对 __jax_array__
的解析。** 这种方法将通过变换保留对象身份,并为用户提供最大的灵活性。如果用户希望选择数组转换语义,则始终可以通过 jnp.asarray 显式转换其输入来实现,这将触发 __jax_array__
协议。
哪些 API 应支持 __jax_array__
?#
JAX 有许多不同级别的 API,从显式原始绑定级别(例如 jax.lax.add_p.bind(x, y)
)到 jax.lax
API(例如 jax.lax.add(x, y)
),再到 jax.numpy
API(例如 jax.numpy.add(x, y)
)。这些 API 类别中哪一个应该处理通过 __jax_array__
的隐式转换?
为了限制更改的范围和所需的测试,我建议 __jax_array__
仅在 jax.numpy
API 中得到显式支持:毕竟,它的灵感来自于 NumPy 包支持的 __array__
协议。如果需要,我们将来可以将其扩展到 jax.lax
API。
这与包的当前状态一致,其中 __jax_array__
的处理主要在 jax.numpy
API 使用的输入验证实用程序中。
实现#
考虑到这些设计选择,我们计划按以下方式实现
向
jax.numpy
添加运行时支持:这可能是最简单的部分,因为大多数jax.numpy
函数都使用一个通用的内部实用程序(ensure_arraylike
)来验证输入并将其转换为数组。该实用程序已经支持__jax_array__
,因此大多数 jax.numpy API 都已符合要求。增加测试覆盖率:为了确保所有 API 的兼容性,我们应该添加一个新的测试框架,该框架使用自定义输入调用每个
jax.numpy
API 并验证其正确行为。在抽象化过程中弃用
__jax_array__
:当前 JAX 的抽象化过程(用于jit
和其他变换)确实会解析__jax_array__
协议,但这并非我们长期希望的行为。我们需要弃用此行为,并确保依赖它的下游包可以在必要时转向 pytree 注册或显式数组转换。添加类型注解:jax.numpy 函数的类型接口位于
jax/numpy/__init__.pyi
中,我们需要将每个输入类型从ArrayLike
更改为ArrayLike | SupportsJAXArray
,其中后者是包含__jax_array__
方法的协议。我们不能直接将其添加到ArrayLike
定义中,因为ArrayLike
用在不应支持__jax_array__
的上下文中。文档:一旦添加了上述支持,我们应该添加一个关于数组可扩展性的文档章节,其中精确地概述了对
__jax_array__
协议的期望,并提供示例说明如何将其与 pytree 注册结合使用,以便有效地处理用户定义类型。