jax.extend:一个用于扩展的模块#
@froystig, @sharadmv, @jakevdp, @yashk2810
2023 年 5 月
import jax.extend as jex
一些项目依赖于 JAX 的代码库内部,通常是为了使用其核心机制(例如,编写一个在其 IR 上的转换)或对其进行扩展(例如,定义新的原语)。这些依赖项面临两个挑战:(a) 我们的内部实现并非都为外部使用而精心设计,(b) 绕过 JAX 的公共 API 是不受支持的。换句话说,我们的内部实现经常被用作库,但其结构和更新方式都不是库的风格。
本提案考虑jax.extend 模块,该模块定义了 JAX 内部组件的库视图。我们将此视为一个二等 API,仍然保证基本上无兼容性策略,但希望在发生更改时更容易注意到。
jax.extend 的目标受众包括 JAX 相关的 Python 库,如 Oryx、jax-triton 等,以及那些正在试验函数转换、自动微分系统、数值编程的编译器前端等项目。
本文档概述了 jax.extend 的现在和未来的可能样子。它不详细阐述,而是提议我们开始迭代开发该模块。
请注意,jax.extend 与 jax.experimental 不同,后者是正在进行的新功能和想法的试验场。通常,jax.experimental 中的工作最终会进入 JAX 的另一个模块或被完全移除。
无兼容性策略#
为降低开发开销,jax.extend 将不遵循公共API 兼容性策略。它不承诺弃用窗口,也不保证发布之间的向后兼容性。每次发布都可能中断现有调用者,且没有简单的补救措施(例如,没有一个标志可以重新引入之前的行为)。我们将依赖更改日志来调用这些更改。
需要与 JAX 版本同步定期升级代码的 jax.extend 调用者,可能会发现将 JAX 版本固定作为发布之间的一个中间步骤很有用。这对于今天依赖 JAX 内部机制的项目来说是一个常见的习惯。不同之处在于,现在它将得到更改日志公告和在库设计和命名方面更好意图的帮助。
迭代开发#
没有兼容性策略使得实现更容易开始:第一天,我们可以将一些符号从内部包(如 jax._src 以及今天的 jax.core 和 jax.interpreters)转移过来。然后我们可以从中迭代改进。
可能的模块概览#
我们可以设想,最终 jax.extend 将包含以下模块:
core– 原语、Jaxpr IR 等。interpreters– 核心转换(例如,自动微分、批量处理)和降低。random– 随机比特生成、密钥拆分和合并、密钥数组。sharding– 分布式数组的额外功能。
最初,我们可能还会在模块中包含其他符号,例如 jex.api_util,因为我们致力于移除或替换它们。其他内容将稍后决定。例如,jex.lib 可以提供对 jaxlib 的入口点(并在短期内这样做),但不清楚我们是否想长期保留它。
以下是一些关于每个模块可能包含什么的初步想法。
jax.extend.core#
这至少应使调用者能够定义新的 JAX 原语并处理 Jaxpr IR(jax.make_jaxpr(...) 的输出)。支持这一点可能涉及提供:
对现有核心系统原语的访问,例如今天的
jax._src.lax.add_p。对 IR 类型的访问,例如当前的
jax._src.core.ShapedArray。用于检查和漂亮打印 jaxprs 的函数。
用于显式构建 jaxprs 的函数,而不是通过
jax.make_jaxpr对 Python 函数进行暂存(或者不这样做!)。
在初始化时,此模块将包含比定义原语和规则所需更多的符号,包括用于设置“最终样式转换”的各种名称,例如当前的 jax._src.core.Trace 和 Tracer 类。我们可以重新考虑 jex.core 是否也应支持最终样式扩展和初始样式方法,以及它是否可以通过比完全公开 Trace 和 Tracer 更窄的 API 来实现。 Oryx 可能会帮助指导这些决定。
我们还可以考虑将 make_jaxpr 本身移到 jex.core。
jax.extend.interpreters#
此模块将提供一种注册原语的各种转换规则的方法——定义它们在自动微分、批量处理、降低等方面的行为。
它最初将反映 jax._src.interpreters,提供 ad、batching、partial_eval(用于 Python 到 Jaxpr 的暂存,以及用于 AD 中的线性化)、mlir、pxla 和 xla 模块。前三个可能可以由 jex.core 中的单个原语扩展 API 替换。后三个用于降低,也许可以简化为一个模块。
今天,要编写转换规则(例如,用于自动微分和批量处理),调用者可能需要与追踪器相关的符号,例如 JVPTracer 和 BatchTracer。这以后可能会避免,并允许我们从 jex 中移除追踪器类型。
此模块加上 jex.core 应该足以复制今天的自定义原语教程(例如,我们的教程和dfm 的教程)。例如,在短期内,可以按照以下方式定义一个原语及其在 jax.jit 下的行为:
from jax.extend import core # Previously: from jax import core
from jax.extend.interpreters import mlir # ... and similarly
mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)
@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
return core.ShapedArray(x_sa.shape, x_sa.dtype)
def mul_add_mlir(ctx, xc, yc, zc):
add = mlir.hlo.AddOp
mul = mlir.hlo.MulOp
return add(mul(xc, yc), zc).results
mlir.register_lowering(mul_add_p, mul_add_mlir)
import jax
print(mul_add_p.bind(2, 3, 4)) # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32)
jax.extend.random#
此模块可以公开我们定义新 RNG 实现的机制,以及用于处理 PRNG 密钥内部的函数(参见问题 #9263),例如今天的 jax._src.prng.random_wrap 和 random_unwrap。
它还可以公开内置 RNG 实现所依据的键控哈希函数,例如 jax._src.prng.threefry_2x32。
jax.extend.sharding#
此模块可以公开用于分片分布式数组的低级实用程序。
目前我们只有一个要考虑的项目。XLA 编译器的数组分片格式比 JAX 提供的格式更具表现力。我们可以将其提供为 jex.sharding.XlaOpShardingProto,对应于内部今天的 jax._src.lib.xla_client.OpSharding。