jax.extend
:一个用于扩展的模块#
@froystig, @sharadmv, @jakevdp, @yashk2810
2023 年 5 月
import jax.extend as jex
几个项目依赖于 JAX 的代码库内部结构,通常是为了使用其核心机制(例如,为 JAX 的 IR 编写一个转换),或扩展它(例如,定义新的原语)。这些依赖面临两个挑战:(a) 我们的内部结构并非都牢固地设计用于外部使用,以及 (b) 规避 JAX 的公共 API 是不受支持的。换句话说,我们的内部结构经常被当作库来使用,但它们既没有按照库的结构组织,也没有像库一样进行更新。
本提案考虑引入一个 jax.extend
模块,该模块定义了 JAX 某些内部组件的库视图。我们将把这视为一个二级 API,仍然基本不保证兼容性策略,但希望在发生变化时更容易发现。
jax.extend
的受众包括 JAX 周边的 Python 库,如 Oryx、jax-triton 以及许多其他库,以及那些正在试验函数转换、自动微分系统、数值编程编译器前端等项目。
本说明概述了 jax.extend
现在和将来可能的样子。它没有详细说明,而是建议我们开始迭代开发该模块。
请注意,jax.extend
不同于 jax.experimental
,后者是正在开发的新功能和新想法的试验场。通常,jax.experimental
中的工作最终会进入另一个 JAX 模块或完全被移除。
无兼容性策略#
为了保持较低的开发开销,jax.extend
将不遵循公共的API 兼容性策略。它不承诺弃用窗口,也不保证版本之间的向后兼容性。每个版本都可能破坏现有调用方,而没有简单的补救措施(例如,没有标志来重新引入先前的行为)。我们将依赖更改日志来指出此类更改。
需要与 JAX 版本同步定期升级代码的 jax.extend
调用方可能会发现,在版本之间固定 JAX 版本作为中间步骤会很有用。这是目前依赖 JAX 内部结构的项目中的一个常见习惯。不同之处在于,现在它将得到更改日志公告的帮助,并且在库设计和命名方面有更好的意图。
迭代式开发#
没有兼容性策略使得开始实施变得更容易:在第一天,我们可以从诸如 jax._src
以及当前的 jax.core
和 jax.interpreters
等内部包中移出少量符号。然后我们可以在此基础上进行迭代改进。
可能的模块概述#
我们可以想象 jax.extend
最终将包含以下模块:
core
– 原语、Jaxpr IR 等。interpreters
– 核心转换(例如自动微分、批处理)和降低。random
– 随机位生成、密钥拆分与折叠、密钥数组。sharding
– 分布式数组的额外功能。
最初,模块中可能还会包含其他符号,例如 jex.api_util
,随着我们努力将其移除或替换。其他符号将适时决定。例如,jex.lib
可以提供一个通往 jaxlib 的入口点(短期内会如此),但我们是否会长期保留它尚不清楚。
下面是关于这些模块可能包含内容的初步思考。
jax.extend.core
#
这应该至少允许调用方定义新的 JAX 原语并处理 Jaxpr IR(jax.make_jaxpr(...)
的输出)。支持这一点可能涉及提供:
访问现有的核心系统原语,例如当前的
jax._src.lax.add_p
。访问 IR 类型,例如当前的
jax._src.core.ShapedArray
。用于检查和美化打印 jaxpr 的函数。
用于显式构建 jaxpr 的函数,而不是通过
jax.make_jaxpr
暂存 Python 函数(或者不暂存!)。
初始化时,该模块将包含比定义原语和规则所需的更多符号,包括用于设置“最终样式转换”的各种名称,例如当前的 jax._src.core.Trace
和 Tracer
类。我们可以重新考虑 jex.core
是否也应该支持最终样式扩展以及初始样式方法,以及它是否可以通过比完全暴露 Trace
和 Tracer
更窄的 API 来实现这一点。Oryx 可能会帮助指导这些决定。
我们还可以考虑将 make_jaxpr
本身重新定位到 jex.core
。
jax.extend.interpreters
#
该模块将提供一种注册原语的各种转换规则的方法——定义它们在自动微分 (AD)、批处理、降低等操作下的行为。
它最初将反映 jax._src.interpreters
,提供 ad
、batching
、partial_eval
(用于将 Python 暂存到 Jaxpr,以及在 AD 中进行线性化)、mlir
、pxla
和 xla
模块。前三个模块可能可以被 jex.core
中的单个原语扩展 API 所取代。后三个用于降低的模块,也许可以简化为一个模块。
目前,要编写转换规则,例如用于 AD 和批处理的规则,调用方可能需要与跟踪器(tracer)相关的符号,例如 JVPTracer
和 BatchTracer
。这在以后可能可以避免,并允许我们从 jex
中移除跟踪器类型。
该模块加上 jex.core
应该足以重现当前的自定义原语教程(例如我们的和dfm 的)。例如,定义一个原语及其在 jax.jit
下的行为将可能如下所示(短期内):
from jax.extend import core # Previously: from jax import core
from jax.extend.interpreters import mlir # ... and similarly
mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)
@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
return core.ShapedArray(x_sa.shape, x_sa.dtype)
def mul_add_mlir(ctx, xc, yc, zc):
add = mlir.hlo.AddOp
mul = mlir.hlo.MulOp
return add(mul(xc, yc), zc).results
mlir.register_lowering(mul_add_p, mul_add_mlir)
import jax
print(mul_add_p.bind(2, 3, 4)) # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32)
jax.extend.random
#
该模块可以公开我们定义新 RNG 实现的机制,以及用于处理 PRNG 密钥内部结构(参见问题 #9263)的函数,例如当前的 jax._src.prng.random_wrap
和 random_unwrap
。
它还可以公开构成内置 RNG 实现基础的带密钥哈希函数,例如 jax._src.prng.threefry_2x32
。
jax.extend.sharding
#
该模块可以公开用于分片分布式数组的低级实用程序。
目前我们只考虑了一项。XLA 编译器的数组分片格式比JAX 提供的更具表现力。我们可以将其提供为 jex.sharding.XlaOpShardingProto
,它对应于当前的内部 jax._src.lib.xla_client.OpSharding
。