Jax 和 Jaxlib 版本管理#
为什么 jax 和 jaxlib 是单独的包?#
我们发布 JAX 为两个独立的 Python wheel 包,即 jax(纯 Python wheel)和 jaxlib(主要是 C++ wheel,包含以下库
XLA,
XLA 使用的一些 LLVM 组件,
MLIR 基础架构,例如 StableHLO Python 绑定。
JAX 特有的 C++ 库,用于快速 JIT 和 PyTree 操作。
我们之所以分发独立的 jax 和 jaxlib 包,是因为这使得在不构建 C++ 代码或甚至不需要安装 C++ 工具链的情况下,也能轻松地处理 JAX 的 Python 部分。 jaxlib 是一个庞大的库,许多用户难以构建,但 JAX 的大多数更改只涉及 Python 代码。通过允许 Python 部分独立于 C++ 部分进行更新,我们可以提高 Python 更改的开发速度。
此外,jaxlib 的构建成本很高,但我们希望能够在 CPU 资源不多的环境中(例如 Github Actions 或笔记本电脑)迭代和运行 JAX 测试。我们许多 CI 构建都只是使用预先构建的 jaxlib,而无需在每个 PR 上重新构建 JAX 的 C++ 部分。
正如我们将看到的,单独分发 jax 和 jaxlib 会带来一些成本,即它要求对 jaxlib 的更改必须保持向后兼容的 API。然而,我们认为总体而言,优先考虑简化 Python 更改,即使这样做会使 C++ 更改稍微复杂一些,也是值得的。
如何对 jax 和 jaxlib 进行版本管理?#
总结:在 JAX 源代码树中,jax 和 jaxlib 共享相同的版本号,但它们作为独立的 Python 包发布。安装时,jax 包的版本必须大于或等于 jaxlib 的版本,并且 jaxlib 的版本必须大于或等于 jax 指定的最低 jaxlib 版本。
jax 和 jaxlib 的发布版本都遵循 x.y.z 格式,其中 x 是主版本号,y 是次版本号,z 是可选的补丁版本。版本号必须遵循 PEP 440。版本号的比较是基于整数元组的字典序比较。
每个 jax 发布版本都有一个关联的最低 jaxlib 版本 mx.my.mz。对于 jax 版本 x.y.z,其最低 jaxlib 版本必须不大于 x.y.z。
为了使 jax 版本 x.y.z 和 jaxlib 版本 lx.ly.lz 兼容,必须满足以下条件:
jaxlib 版本 (
lx.ly.lz) 必须大于或等于最低 jaxlib 版本 (mx.my.mz)。jax 版本 (
x.y.z) 必须大于或等于 jaxlib 版本 (lx.ly.lz)。
这些约束条件意味着以下发布规则:
jax可以随时单独发布,而无需更新jaxlib。如果发布了新的
jaxlib,必须同时发布jax。
这些 版本约束 目前由 jax 在导入时检查,而不是表示为 Python 包的版本约束。 jax 在运行时检查 jaxlib 版本,而不是使用 pip 包版本约束,因为我们 为各种硬件和软件版本(例如 GPU、TPU 等)提供了独立的 jaxlib wheel。由于我们不知道哪种选择适合任何特定用户,因此我们不希望 pip 为我们自动安装 jaxlib 包。
未来,我们希望将 jaxlib 中与硬件相关的部分分离为单独的插件,届时最低版本可以表示为 Python 包依赖。目前,我们提供特定于平台的额外要求,用于安装兼容的 jaxlib 版本,例如 jax[cuda]。
如何安全地修改 jaxlib 的 API?#
jax可以在任何时候放弃对旧jaxlib版本的兼容性,只要将最低jaxlib版本提高到一个兼容的版本即可。但是,请注意,即使是未发布的jax版本,其最低jaxlib版本也必须是一个已发布的版本!这使我们能够在 CI 构建中使用已发布的jaxlibwheel,并允许 Python 开发者在 HEAD 版本上处理jax,而无需重新构建jaxlib。例如,要删除
jaxPython 代码中旧的向后兼容路径,只需将最低 jaxlib 版本提高,然后删除兼容性路径即可。jaxlib可以放弃对低于其自身发布版本号的旧jax版本的兼容性。由jax强制执行的版本约束将禁止使用不兼容的jaxlib。例如,为了让
jaxlib放弃一个被旧jax版本使用的 Python 绑定 API,必须递增jaxlib的次版本号或主版本号。如果可能,对
jaxlib的更改应以向后兼容的方式进行。总的来说,
jaxlib可以自由更改其 API,只要遵循关于jax与所有不低于最低版本(含)的jaxlib版本兼容的规则即可。这意味着jax必须始终兼容至少两个版本的jaxlib,即最后一个已发布版本和 HEAD 版本(实际上是下一个已发布版本)。如果保持兼容性,这会更容易做到,尽管可以使用来自jax的版本测试来进行不兼容的更改;见下文。例如,通常可以安全地向
jaxlib添加一个新函数,但如果当前jax仍在使用的现有函数被删除或更改其签名,则是不安全的。对jax的更改必须在不低于最低版本且不超过 HEAD 的所有jaxlib版本上正常工作或优雅降级。
请注意,这里的兼容性规则仅适用于 jax 和 jaxlib 的已发布版本。它们不适用于未发布的版本;也就是说,如果一个 API 从未发布,或者没有已发布的 jax 版本使用该 API,那么在 jaxlib 中引入然后删除该 API 是可以的。
jaxlib 的源码是如何组织的?#
jaxlib 分布在两个主要的代码库中,即主 JAX 仓库中的 和 XLA 源代码树(位于 XLA 仓库内)。XLA 中 JAX 特有的组件主要位于jaxlib/ 子目录xla/python 子目录。
JAX 的 C++ 组件(例如 Python 绑定和运行时组件)位于 XLA 树中的原因部分是历史原因,部分是技术原因。
历史原因在于,最初 xla/python 绑定被设想为可以与其他框架共享的通用 Python 绑定。实际上,这种情况越来越少,xla/python 包含了一些 JAX 特有的组件,并且可能还会包含更多。因此,最好简单地将 xla/python 视为 JAX 的一部分。
技术原因在于 XLA C++ API 不稳定。将 XLA:Python 绑定保留在 XLA 树中,可以使其 C++ 实现与 XLA 的 C++ API 一起原子地更新。维护 Python API 的向后和向前兼容性比 C++ API 更容易,因此 xla/python 会公开 Python API,并负责在 Python 层面维护向后兼容性。
jaxlib 是使用 Bazel 从 jax 仓库构建的。来自 XLA 仓库的 jaxlib 组件被作为 Bazel 子模块集成到构建中。要更新构建过程中使用的 XLA 版本,必须在 Bazel WORKSPACE 文件中更新固定的版本。这是按需手动完成的,但可以按次构建进行覆盖。
在发布新版本之间,我们如何在 jax 和 jaxlib 的边界处进行更改?#
jaxlib 版本是一个粗略的工具:它只能让我们推断发布版本。
然而,由于 jax 和 jaxlib 的代码分布在无法通过一次更改原子地更新的代码库中,我们需要管理比我们的发布周期更精细的兼容性。为了管理精细兼容性,我们有独立于 jaxlib 发布版本号的额外版本控制。
我们在XLA 仓库中的 中维护一个额外的版本号(xla_client.py_version)。这个版本号定义在 xla/python 中,并与 JAX 的 C++ 部分一起,也可以作为 jax._src.lib.jaxlib_extension_version 供 JAX Python 访问,并且每次对 XLA/Python 代码进行对 jax 具有向后兼容性影响的更改时,都必须递增。然后,JAX Python 代码可以使用此版本号来维护向后兼容性,例如:
from jax._src.lib import jaxlib_extension_version
# 123 is the new version number for _version in xla_client.py
if jaxlib_extension_version >= 123:
# Use new code path
...
else:
# Use old code path.
请注意,此版本号是除了已发布版本号约束之外的。也就是说,此版本号是为了帮助管理未发布代码的开发期间的兼容性。发布还必须遵循上述兼容性规则。