JEP 28661: 支持 __jax_array__ 协议#

@jakevdp, 2025 年 5 月

偶尔有用户请求定义可与 JAX API 配合使用的自定义类数组对象的能力。JAX 当前通过在自定义对象上定义 __jax_array__ 方法来部分实现这一机制。这从未打算成为一个负载关键的公共 API(参见 #4725 上的讨论),但对于 Keras 和 Flax 等明确说明其自定义数组对象可与 JAX 函数一起使用的包来说,它已变得至关重要。本 JEP 提出了一个设计方案,旨在全面、正式地支持 __jax_array__ 协议。

数组可扩展性级别#

对 JAX 数组可扩展性的请求有多种形式

级别 1 可扩展性:多态输入#

我称之为“级别 1”的可扩展性是指 JAX API 接受多态输入的愿望。也就是说,用户期望这样的行为

class CustomArray:
  data: numpy.ndarray
  ...

x = CustomArray(np.arange(5))
result = jnp.sin(x)  # Converts `x` to JAX array and returns a JAX array

在这种可扩展性模型下,JAX 函数将接受 CustomArray 对象作为输入,为了计算而隐式地将它们转换为 jax.Array 对象。这类似于 NumPy 通过 __array__ 方法提供的功能,以及 JAX(在许多但并非所有情况下)通过 __jax_array__ 方法提供的功能。

这是 flax.nnx 及其他维护者所请求的可扩展性模式。当前的实现也由 JAX 内部用于符号维度的情况。

级别 2 可扩展性:多态输出#

我称之为“级别 2”的可扩展性是指 JAX API 不仅应该接受多态输入,还应该封装输出以匹配输入的类。也就是说,用户期望这样的行为

class CustomArray:
  data: numpy.ndarray
  ...

x = CustomArray(np.arange(5))
result = jnp.sin(x)  # returns a new CustomArray

在这种可扩展性模型下,JAX 函数不仅会接受自定义对象作为输入,还会通过某种协议来确定如何正确地将输出重新封装为相同的类。在 NumPy 中,这种功能在不同程度上由特殊的 __array_ufunc____array_wrap____array_function__ 协议提供,这些协议允许用户定义的对象自定义 NumPy API 函数如何对任意输入进行操作,以及如何将输入类型映射到输出。JAX 目前没有与 NumPy 中的这些接口等效的功能。

这是 Keras 等维护者所请求的可扩展性模式。

级别 3 可扩展性:子类化 Array#

我称之为“级别 3”的可扩展性是指 JAX 数组对象本身可以被子类化。NumPy 提供了一些允许这样做的 API(参见 子类化 ndarray),但由于需要通过跟踪抽象地表示数组对象,这种方法在 JAX 中需要额外考虑。

这种可扩展性模式偶尔有用户请求,他们希望向 JAX 数组添加特殊元数据,例如测量单位。

概述#

为了本提案的目的,我们将坚持最简单的级别 1 可扩展性模型。提议的接口是目前许多 JAX API 非统一支持的 __jax_array__ 方法。它的用法如下所示

import jax
import jax.numpy as jnp
import numpy as np

class CustomArray:
  data: np.ndarray

  def __init__(self, data: np.ndarray):
    self.data = data

  def __jax_array__(self) -> jax.Array:
    return jnp.asarray(self.data)

arr = CustomArray(np.arange(5))
result = jnp.multiply(arr, 2)
print(repr(result))
# Array([0, 2, 4, 6, 8], dtype=int32)

我们将来可能会重新审视其他可扩展性级别。

设计挑战#

JAX 提出了与这种可扩展性相关的一些有趣的设计挑战,之前尚未完全探讨过。我们将在此依次讨论它们

__jax_array__ 与 PyTree 展平的优先级#

JAX 已经有一个支持的自定义对象注册机制,即 pytree 注册(参见 扩展 pytrees)。如果我们也支持 jax_array,那么哪一个应该优先?

更具体地说,这段代码的结果应该是什么?

@jax.jit
def f(x):
  print("is JAX array:", isinstance(x, jax.Array))

f(CustomArray(...))

如果我们选择在 JIT 边界处优先处理 __jax_array__,那么此函数的输出将是

is JAX array: True

也就是说,在 JIT 边界处,CustomArray 对象将被转换为 __jax_array__,并且其形状和 dtype 将用于为函数构造一个标准的 JAX 追踪器。

如果我们选择在 JIT 边界处优先处理 pytree 展平,那么此函数的输出将是

type(x)=CustomArray

也就是说,在 JIT 边界处,CustomArray 对象被展平,然后在传递给 JIT 编译的函数进行追踪之前再反展平。如果 CustomArray 已注册为 pytree,它通常会将其追踪的数组作为属性,当 x 传递给任何支持 __jax_array__ 的 JAX API 时,这些追踪的属性将根据方法中指定的逻辑转换为单个追踪数组。

对于 vmap 和 grad 等其他变换在遇到自定义对象时如何工作,这里有更深层次的后果:例如,如果我们优先处理 pytree 展平,vmap 将对自定义对象的展平内容的维度进行操作,而如果我们优先处理 __jax_array__,vmap 将对转换后的数组维度进行操作。

这对于 JIT 不变性也有影响:考虑这样一个函数

def f(x):
  if isinstance(x, CustomArray):
    return x.custom_method()
  else:
    # do something else
    ...

result1 = f(x)
result2 = jax.jit(f)(x)

如果 jit 通过 pytree 展平来处理 x,那么结果应该与明确指定的展平规则一致。如果 jit 通过 __jax_array__ 来处理 x,结果将不同,因为在 JIT 编译的函数版本中,x 不再是 CustomArray。

概述#

截至 JAX v0.6.0 版本,变换在 __jax_array__ 可用时会优先处理它。这种现状可能导致对 JIT 不变性缺失的困惑,并且当前的实际实现导致自动微分中出现细微的错误,即前向和后向传播对输入的处理不一致。

因为 pytree 可扩展性机制已经存在用于定制变换,所以如果变换只通过此机制进行,则显得最直接:也就是说,**我们建议在抽象化过程中移除对 __jax_array__ 的解析。** 这种方法将通过变换保留对象身份,并为用户提供最大的灵活性。如果用户希望选择数组转换语义,则始终可以通过 jnp.asarray 显式转换其输入来实现,这将触发 __jax_array__ 协议。

哪些 API 应支持 __jax_array__#

JAX 有许多不同级别的 API,从显式原始绑定级别(例如 jax.lax.add_p.bind(x, y))到 jax.lax API(例如 jax.lax.add(x, y)),再到 jax.numpy API(例如 jax.numpy.add(x, y))。这些 API 类别中哪一个应该处理通过 __jax_array__ 的隐式转换?

为了限制更改的范围和所需的测试,我建议 __jax_array__ 仅在 jax.numpy API 中得到显式支持:毕竟,它的灵感来自于 NumPy 包支持的 __array__ 协议。如果需要,我们将来可以将其扩展到 jax.lax API。

这与包的当前状态一致,其中 __jax_array__ 的处理主要在 jax.numpy API 使用的输入验证实用程序中。

实现#

考虑到这些设计选择,我们计划按以下方式实现

  • jax.numpy 添加运行时支持:这可能是最简单的部分,因为大多数 jax.numpy 函数都使用一个通用的内部实用程序(ensure_arraylike)来验证输入并将其转换为数组。该实用程序已经支持 __jax_array__,因此大多数 jax.numpy API 都已符合要求。

  • 增加测试覆盖率:为了确保所有 API 的兼容性,我们应该添加一个新的测试框架,该框架使用自定义输入调用每个 jax.numpy API 并验证其正确行为。

  • 在抽象化过程中弃用 __jax_array__:当前 JAX 的抽象化过程(用于 jit 和其他变换)确实会解析 __jax_array__ 协议,但这并非我们长期希望的行为。我们需要弃用此行为,并确保依赖它的下游包可以在必要时转向 pytree 注册或显式数组转换。

  • 添加类型注解:jax.numpy 函数的类型接口位于 jax/numpy/__init__.pyi 中,我们需要将每个输入类型从 ArrayLike 更改为 ArrayLike | SupportsJAXArray,其中后者是包含 __jax_array__ 方法的协议。我们不能直接将其添加到 ArrayLike 定义中,因为 ArrayLike 用在不应支持 __jax_array__ 的上下文中。

  • 文档:一旦添加了上述支持,我们应该添加一个关于数组可扩展性的文档章节,其中精确地概述了对 __jax_array__ 协议的期望,并提供示例说明如何将其与 pytree 注册结合使用,以便有效地处理用户定义类型。