JEP 28661:支持 `__jax_array__` 协议#

@jakevdp, 2025 年 5 月

偶尔会有用户请求能够定义与 jax API 兼容的类数组对象。JAX 目前通过自定义对象上定义的 `__jax_array__` 方法对此机制有一个部分实现。这从未打算成为一个承重型公共 API(请参阅 #4725 的讨论),但它已成为 Keras 和 flax 等包的必需品,这些包明确记录了能够将其自定义数组对象与 jax 函数一起使用。本 JEP 提议了一个完全支持并文档化的 `__jax_array__` 协议的设计。

数组可扩展性的级别#

JAX 数组的可扩展性请求有几种形式

级别 1 可扩展性:多态输入#

我将称之为“级别 1”的可扩展性是希望 JAX API 接受多态输入。也就是说,用户希望行为如下:

class CustomArray:
  data: numpy.ndarray
  ...

x = CustomArray(np.arange(5))
result = jnp.sin(x)  # Converts `x` to JAX array and returns a JAX array

在此可扩展性模型下,JAX 函数将接受 CustomArray 对象作为输入,并隐式将其转换为 `jax.Array` 对象以进行计算。这类似于 NumPy 通过 `__array__` 方法提供的功能,以及在 JAX 中(在许多但并非所有情况下)通过 `__jax_array__` 方法提供的功能。

这是 `flax.nnx` 等的维护者所要求的可扩展性模式。当前实现也由 JAX 内部用于符号维度的情况。

级别 2 可扩展性:多态输出#

我将称之为“级别 2”的可扩展性是希望 JAX API 不仅接受多态输入,还包装输出以匹配输入的类。也就是说,用户希望行为如下:

class CustomArray:
  data: numpy.ndarray
  ...

x = CustomArray(np.arange(5))
result = jnp.sin(x)  # returns a new CustomArray

在此可扩展性模型下,JAX 函数不仅会接受自定义对象作为输入,还会有一个协议来确定如何正确地用相同的类重新包装输出。在 NumPy 中,这种功能由特殊的 `__array_ufunc__`、`__array_wrap__` 和 `__array_function__` 协议在不同程度上提供,这些协议允许用户定义的对象自定义 NumPy API 函数如何操作任意输入并将输入类型映射到输出。JAX 目前没有等同于 NumPy 中这些接口的内容。

这是 `keras` 等的维护者所要求的可扩展性模式。

级别 3 可扩展性:子类化 `Array`#

我将称之为“级别 3”的可扩展性是希望 JAX 数组对象本身可以被子类化。NumPy 提供了一些允许这样做的 API(请参阅 Subclassing ndarray),但由于需要通过跟踪抽象地表示数组对象,这种方法在 JAX 中需要一些额外的考虑。

这种可扩展性模式偶尔会被用户请求,他们希望向 JAX 数组添加特殊元数据,例如测量单位。

摘要#

为了本提案的目的,我们将坚持最简单的级别 1 可扩展性模型。所提出的接口是当前许多 JAX API 非统一支持的接口,即 `__jax_array__` 方法。其用法如下:

import jax
import jax.numpy as jnp
import numpy as np

class CustomArray:
  data: np.ndarray

  def __init__(self, data: np.ndarray):
    self.data = data

  def __jax_array__(self) -> jax.Array:
    return jnp.asarray(self.data)

arr = CustomArray(np.arange(5))
result = jnp.multiply(arr, 2)
print(repr(result))
# Array([0, 2, 4, 6, 8], dtype=int32)

我们可能会在未来重新考虑其他可扩展性级别。

设计挑战#

JAX 为这种可扩展性带来了一些有趣的设计挑战,这些挑战以前并未得到充分探索。我们将在此逐一讨论。

`__jax_array__` 与 PyTree 展平的优先级#

JAX 已经有一个支持的注册自定义对象的机制,即 pytree 注册(请参阅 Extending pytrees)。如果我们还支持 **jax_array**,哪一个应该优先?

更具体地说,这段代码的结果应该是什么?

@jax.jit
def f(x):
  print("is JAX array:", isinstance(x, jax.Array))

f(CustomArray(...))

如果我们选择在 JIT 边界处优先考虑 `__jax_array__`,那么该函数的结果将是:

is JAX array: True

也就是说,在 JIT 边界处,`CustomArray` 对象将被转换为 `__jax_array__`,其形状和 dtype 将用于为函数构建标准的 JAX 追踪器。

如果我们选择在 JIT 边界处优先考虑 pytree 展平,那么该函数的结果将是:

type(x)=CustomArray

也就是说,在 JIT 边界处,`CustomArray` 对象被展平,然后在传递给 JIT 编译的函数进行追踪之前被取消展平。如果 `CustomArray` 已注册为 pytree,它通常会将其属性作为追踪数组,并且当 x 被传递到任何支持 `__jax_array__` 的 JAX API 时,这些追踪属性将根据方法中指定的逻辑转换为单个追踪数组。

这里有更深层次的后果,关于其他转换(如 vmap 和 grad)在遇到自定义对象时如何工作:例如,如果我们优先考虑 pytree 展平,vmap 将在自定义对象的展平内容维度上操作,而如果我们优先考虑 `__jax_array__`,vmap 将在转换后的数组维度上操作。

这也对 JIT 不变性有影响:考虑这样的函数:

def f(x):
  if isinstance(x, CustomArray):
    return x.custom_method()
  else:
    # do something else
    ...

result1 = f(x)
result2 = jax.jit(f)(x)

如果 `jit` 通过 pytree 展平消耗 `x`,那么对于一个定义良好的展平规则,结果应该是一致的。如果 `jit` 通过 `__jax_array__` 消耗 `x`,结果将不同,因为在函数的 JIT 编译版本中 `x` 不再是 CustomArray。

摘要#

截至 JAX v0.6.0,转换在 `__jax_array__` 可用时优先处理它。这种现状可能导致对 JIT 不变性缺乏的困惑,并且当前的实现实际上会导致自动微分中的细微错误,其中前向和后向传递不对输入进行一致处理。

由于 pytree 可扩展性机制已经存在用于自定义转换的情况,因此转换仅通过此机制起作用似乎最直接:也就是说,**我们建议在抽象化期间移除 `__jax_array__` 解析。** 这种方法将通过转换保留对象身份,并为用户提供最大的灵活性。如果用户希望选择加入数组转换语义,则始终可以通过显式使用 jnp.asarray 进行输入转换来实现,这将触发 `__jax_array__` 协议。

哪些 API 应支持 `__jax_array__`?#

JAX 具有不同级别的 API,从显式原始绑定(例如 `jax.lax.add_p.bind(x, y)`)到 `jax.lax` API(例如 `jax.lax.add(x, y)`)再到 `jax.numpy` API(例如 `jax.numpy.add(x, y)`)。这些 API 类别中的哪一个应通过 `__jax_array__` 处理隐式转换?

为了限制更改的范围和所需的测试,我建议 `__jax_array__` 仅在 `jax.numpy` API 中明确支持:毕竟,它受到 NumPy 包支持的 `__array__` 协议的启发。如果将来需要,我们可以将其扩展到 `jax.lax` API。

这与包的当前状态一致,其中 `__jax_array__` 处理主要在 `jax.numpy` API 使用的输入验证实用程序中。

实施#

考虑到这些设计选择,我们计划如下实施:

  • 向 `jax.numpy` 添加运行时支持:这可能是最容易的部分,因为大多数 `jax.numpy` 函数都使用一个通用的内部实用程序(`ensure_arraylike`)来验证输入并将其转换为数组。此实用程序已支持 `__jax_array__`,因此大多数 jax.numpy API 已符合要求。

  • 添加测试覆盖:为了确保跨 API 的合规性,我们应该添加一个新的测试框架,该框架使用自定义输入调用每个 `jax.numpy` API,并验证其行为是否正确。

  • 在抽象化期间弃用 `__jax_array__`:目前 JAX 的抽象化过程,用于 `jit` 和其他转换,确实会解析 `__jax_array__` 协议,但这并非我们长期想要的。我们需要弃用此行为,并确保依赖它的下游包能够迁移到 pytree 注册或显式数组转换(如果需要)。

  • 添加类型注释:jax.numpy 函数的类型接口位于 `jax/numpy/__init__.pyi` 中,我们需要将每个输入类型从 `ArrayLike` 更改为 `ArrayLike | SupportsJAXArray`,其中后者是一个带有 `__jax_array__` 方法的协议。我们不能将其直接添加到 `ArrayLike` 定义中,因为 `ArrayLike` 用于 `__jax_array__` 不应被支持的上下文。

  • 文档:一旦添加了上述支持,我们就应该添加一个关于数组可扩展性的文档部分,其中准确概述了有关 `__jax_array__` 协议的预期,并提供了有关如何结合 pytree 注册使用它的示例,以便有效地与用户定义的类型一起工作。