Jax 和 Jaxlib 版本控制#

为什么 jaxjaxlib 是独立的包?#

我们以两个独立的 Python wheel 包发布 JAX,即纯 Python wheel 包 jax,以及主要由 C++ 构成的 wheel 包 jaxlib,后者包含以下库:

  • XLA,

  • XLA 使用的部分 LLVM,

  • MLIR 基础设施,例如 StableHLO Python 绑定。

  • 用于快速 JIT 和 PyTree 操作的 JAX 专用 C++ 库。

我们分发独立的 jaxjaxlib 包,因为它使得在 JAX 的 Python 部分上工作变得容易,而无需构建 C++ 代码甚至无需安装 C++ 工具链。jaxlib 是一个大型库,许多用户不容易构建,但 JAX 的大多数更改只涉及 Python 代码。通过允许 Python 部分独立于 C++ 部分进行更新,我们提高了 Python 更改的开发速度。

此外,jaxlib 的构建成本不菲,但我们希望能够在没有大量 CPU 的环境中(例如在 Github Actions 或笔记本电脑上)迭代并运行 JAX 测试。我们的许多 CI 构建简单地使用预构建的 jaxlib,而不是在每个 PR 上都重新构建 JAX 的 C++ 部分。

正如我们将看到的,单独分发 jaxjaxlib 会带来一定的代价,因为它要求对 jaxlib 的更改保持 API 的向后兼容性。然而,我们认为总的来说,让 Python 更改变得容易更可取,即使这会使 C++ 更改稍微困难。

jaxjaxlib 如何进行版本控制?#

摘要:jaxjaxlib 在 JAX 源代码树中共享相同的版本号,但作为独立的 Python 包发布。安装时,jax 包版本必须大于或等于 jaxlib 的版本,并且 jaxlib 的版本必须大于或等于 jax 指定的最低 jaxlib 版本。

jaxjaxlib 的发布版本号均为 x.y.z,其中 x 是主版本号,y 是次版本号,z 是可选的补丁版本。版本号必须遵循 PEP 440。版本号比较是对整数元组进行的字典序比较。

每个 jax 发布版本都有一个关联的最低 jaxlib 版本 mx.my.mzjax 版本 x.y.z 的最低 jaxlib 版本不得大于 x.y.z

为了使 jax 版本 x.y.zjaxlib 版本 lx.ly.lz 兼容,必须满足以下条件:

  • jaxlib 版本 (lx.ly.lz) 必须大于或等于最低 jaxlib 版本 (mx.my.mz)。

  • jax 版本 (x.y.z) 必须大于或等于 jaxlib 版本 (lx.ly.lz)。

这些约束意味着以下发布规则:

  • jax 可以随时独立发布,无需更新 jaxlib

  • 如果发布了新的 jaxlib,则必须同时发布 jax

这些版本约束当前由 jax 在导入时检查,而不是表示为 Python 包版本约束。jax 在运行时检查 jaxlib 版本,而不是使用 pip 包版本约束,因为我们提供独立的 jaxlib wheel 包,以支持各种硬件和软件版本(例如,GPU、TPU 等)。由于我们不知道哪种选择适合任何给定用户,我们不希望 pip 自动为我们安装 jaxlib 包。

未来,我们希望将 jaxlib 的硬件特定部分分离成独立的插件,届时最低版本可以表示为 Python 包依赖。目前,我们确实提供了平台特定的额外要求,用于安装兼容的 jaxlib 版本,例如 jax[cuda]

如何安全地修改 jaxlib 的 API?#

  • jax 可以在任何时候放弃对旧 jaxlib 版本的兼容性,只要最低 jaxlib 版本增加到兼容版本即可。但是,请注意,最低 jaxlib 版本,即使对于 jax 的未发布版本,也必须是已发布版本!这使得我们可以在 CI 构建中使用已发布的 jaxlib wheel 包,并允许 Python 开发者在 HEAD 上开发 jax 而无需构建 jaxlib

    例如,要移除 jax Python 代码中旧的向后兼容路径,只需提高最低 jaxlib 版本,然后删除兼容路径即可。

  • jaxlib 可以放弃对其自身发布版本号更低的旧 jax 版本的兼容性。jax 强制执行的版本约束将禁止使用不兼容的 jaxlib

    例如,如果 jaxlib 要放弃旧 jax 版本使用的 Python 绑定 API,则必须递增 jaxlib 的次版本号或主版本号。

  • 如果可能,对 jaxlib 的更改应以向后兼容的方式进行。

    通常情况下,jaxlib 可以自由更改其 API,只要遵循 jax 兼容所有至少与最低版本一样新的 jaxlib 的规则即可。这意味着 jax 必须始终兼容至少两个版本的 jaxlib,即上一个发布版本和树尖版本(实际上是下一个发布版本)。如果保持兼容性,这样做会更容易,尽管也可以使用 jax 的版本测试进行不兼容的更改;详见下文。

    例如,向 jaxlib 添加新函数通常是安全的,但如果当前的 jax 仍在 F使用该函数,则移除现有函数或更改其签名是不安全的。对 jax 的更改必须适用于所有大于最低版本直到 HEAD 的 jaxlib 发布版本,或者在此类版本中优雅降级。

请注意,这里的兼容性规则仅适用于 *已发布* 版本的 jaxjaxlib。它们不适用于未发布版本;也就是说,如果 jaxlib 中的某个 API 从未发布,或者没有已发布的 jax 版本使用该 API,则可以引入然后移除该 API。

jaxlib 的源代码是如何布局的?#

jaxlib 分布在两个主要仓库中,即主 JAX 仓库中的 jaxlib/ 子目录XLA 仓库中的 XLA 源代码树。XLA 中 JAX 特定的部分主要在 xla/python 子目录中。

JAX 的 C++ 部分(例如 Python 绑定和运行时组件)位于 XLA 树中的原因,部分是历史原因,部分是技术原因。

历史原因是,最初 xla/python 绑定被设想为通用的 Python 绑定,可能会与其他框架共享。实际上,这种情况越来越少见,xla/python 包含了许多 JAX 特定的部分,并且可能会包含更多。因此,最好简单地将 xla/python 视为 JAX 的一部分。

技术原因是 XLA C++ API 不稳定。通过将 XLA:Python 绑定保留在 XLA 树中,它们的 C++ 实现可以与 XLA 的 C++ API 原子更新。Python API 比 C++ API 更容易维护向后和向前兼容性,因此 xla/python 暴露 Python API 并负责维护 Python 级别的向后兼容性。

jaxlib 是使用 Bazel 从 jax 仓库中构建的。XLA 仓库中的 jaxlib 部分作为 Bazel 子模块被纳入构建。要更新构建过程中使用的 XLA 版本,必须更新 Bazel WORKSPACE 中的固定版本。这是根据需要手动完成的,但可以在每次构建时覆盖。

在发布之间,我们如何跨 jaxjaxlib 边界进行更改?#

jaxlib 版本是一个粗略的工具:它只允许我们推断 *发布版本*。

然而,由于 jaxjaxlib 代码分散在无法通过一次更改进行原子更新的仓库中,我们需要以比发布周期更精细的粒度管理兼容性。为了管理细粒度兼容性,我们有独立于 jaxlib 发布版本号的额外版本控制。

我们在XLA 仓库中的 xla_client.py 中维护了一个额外的版本号(_version)。其思想是,此版本号在 xla/python 中与 JAX 的 C++ 部分一起定义,并且 JAX Python 也可以通过 jax._src.lib.jaxlib_extension_version 访问,并且每次对 XLA/Python 代码进行具有向后兼容性影响的更改时都必须递增。JAX Python 代码随后可以使用此版本号来维护向后兼容性,例如:

from jax._src.lib import jaxlib_extension_version

# 123 is the new version number for _version in xla_client.py
if jaxlib_extension_version >= 123:
  # Use new code path
  ...
else:
  # Use old code path.

请注意,此版本号是除了对已发布版本号的约束 *之外* 的。也就是说,此版本号旨在帮助管理未发布代码在开发过程中的兼容性。发布版本也必须遵循上述兼容性规则。