Jax 和 Jaxlib 版本控制#
为什么 jax
和 jaxlib
是独立的包?#
我们以两个独立的 Python wheel 包发布 JAX,即纯 Python wheel 包 jax
,以及主要由 C++ 构成的 wheel 包 jaxlib
,后者包含以下库:
XLA,
XLA 使用的部分 LLVM,
MLIR 基础设施,例如 StableHLO Python 绑定。
用于快速 JIT 和 PyTree 操作的 JAX 专用 C++ 库。
我们分发独立的 jax
和 jaxlib
包,因为它使得在 JAX 的 Python 部分上工作变得容易,而无需构建 C++ 代码甚至无需安装 C++ 工具链。jaxlib
是一个大型库,许多用户不容易构建,但 JAX 的大多数更改只涉及 Python 代码。通过允许 Python 部分独立于 C++ 部分进行更新,我们提高了 Python 更改的开发速度。
此外,jaxlib
的构建成本不菲,但我们希望能够在没有大量 CPU 的环境中(例如在 Github Actions 或笔记本电脑上)迭代并运行 JAX 测试。我们的许多 CI 构建简单地使用预构建的 jaxlib
,而不是在每个 PR 上都重新构建 JAX 的 C++ 部分。
正如我们将看到的,单独分发 jax
和 jaxlib
会带来一定的代价,因为它要求对 jaxlib
的更改保持 API 的向后兼容性。然而,我们认为总的来说,让 Python 更改变得容易更可取,即使这会使 C++ 更改稍微困难。
jax
和 jaxlib
如何进行版本控制?#
摘要:jax
和 jaxlib
在 JAX 源代码树中共享相同的版本号,但作为独立的 Python 包发布。安装时,jax
包版本必须大于或等于 jaxlib
的版本,并且 jaxlib
的版本必须大于或等于 jax
指定的最低 jaxlib
版本。
jax
和 jaxlib
的发布版本号均为 x.y.z
,其中 x
是主版本号,y
是次版本号,z
是可选的补丁版本。版本号必须遵循 PEP 440。版本号比较是对整数元组进行的字典序比较。
每个 jax
发布版本都有一个关联的最低 jaxlib
版本 mx.my.mz
。jax
版本 x.y.z
的最低 jaxlib
版本不得大于 x.y.z
。
为了使 jax
版本 x.y.z
和 jaxlib
版本 lx.ly.lz
兼容,必须满足以下条件:
jaxlib 版本 (
lx.ly.lz
) 必须大于或等于最低 jaxlib 版本 (mx.my.mz
)。jax 版本 (
x.y.z
) 必须大于或等于 jaxlib 版本 (lx.ly.lz
)。
这些约束意味着以下发布规则:
jax
可以随时独立发布,无需更新jaxlib
。如果发布了新的
jaxlib
,则必须同时发布jax
。
这些版本约束当前由 jax
在导入时检查,而不是表示为 Python 包版本约束。jax
在运行时检查 jaxlib
版本,而不是使用 pip
包版本约束,因为我们提供独立的 jaxlib
wheel 包,以支持各种硬件和软件版本(例如,GPU、TPU 等)。由于我们不知道哪种选择适合任何给定用户,我们不希望 pip
自动为我们安装 jaxlib
包。
未来,我们希望将 jaxlib
的硬件特定部分分离成独立的插件,届时最低版本可以表示为 Python 包依赖。目前,我们确实提供了平台特定的额外要求,用于安装兼容的 jaxlib 版本,例如 jax[cuda]
。
如何安全地修改 jaxlib
的 API?#
jax
可以在任何时候放弃对旧jaxlib
版本的兼容性,只要最低jaxlib
版本增加到兼容版本即可。但是,请注意,最低jaxlib
版本,即使对于jax
的未发布版本,也必须是已发布版本!这使得我们可以在 CI 构建中使用已发布的jaxlib
wheel 包,并允许 Python 开发者在 HEAD 上开发jax
而无需构建jaxlib
。例如,要移除
jax
Python 代码中旧的向后兼容路径,只需提高最低 jaxlib 版本,然后删除兼容路径即可。jaxlib
可以放弃对其自身发布版本号更低的旧jax
版本的兼容性。jax
强制执行的版本约束将禁止使用不兼容的jaxlib
。例如,如果
jaxlib
要放弃旧jax
版本使用的 Python 绑定 API,则必须递增jaxlib
的次版本号或主版本号。如果可能,对
jaxlib
的更改应以向后兼容的方式进行。通常情况下,
jaxlib
可以自由更改其 API,只要遵循jax
兼容所有至少与最低版本一样新的jaxlib
的规则即可。这意味着jax
必须始终兼容至少两个版本的jaxlib
,即上一个发布版本和树尖版本(实际上是下一个发布版本)。如果保持兼容性,这样做会更容易,尽管也可以使用jax
的版本测试进行不兼容的更改;详见下文。例如,向
jaxlib
添加新函数通常是安全的,但如果当前的jax
仍在 F使用该函数,则移除现有函数或更改其签名是不安全的。对jax
的更改必须适用于所有大于最低版本直到 HEAD 的jaxlib
发布版本,或者在此类版本中优雅降级。
请注意,这里的兼容性规则仅适用于 *已发布* 版本的 jax
和 jaxlib
。它们不适用于未发布版本;也就是说,如果 jaxlib
中的某个 API 从未发布,或者没有已发布的 jax
版本使用该 API,则可以引入然后移除该 API。
jaxlib
的源代码是如何布局的?#
jaxlib
分布在两个主要仓库中,即主 JAX 仓库中的 jaxlib/
子目录和 XLA 仓库中的 XLA 源代码树。XLA 中 JAX 特定的部分主要在 xla/python
子目录中。
JAX 的 C++ 部分(例如 Python 绑定和运行时组件)位于 XLA 树中的原因,部分是历史原因,部分是技术原因。
历史原因是,最初 xla/python
绑定被设想为通用的 Python 绑定,可能会与其他框架共享。实际上,这种情况越来越少见,xla/python
包含了许多 JAX 特定的部分,并且可能会包含更多。因此,最好简单地将 xla/python
视为 JAX 的一部分。
技术原因是 XLA C++ API 不稳定。通过将 XLA:Python 绑定保留在 XLA 树中,它们的 C++ 实现可以与 XLA 的 C++ API 原子更新。Python API 比 C++ API 更容易维护向后和向前兼容性,因此 xla/python
暴露 Python API 并负责维护 Python 级别的向后兼容性。
jaxlib
是使用 Bazel 从 jax
仓库中构建的。XLA 仓库中的 jaxlib
部分作为 Bazel 子模块被纳入构建。要更新构建过程中使用的 XLA 版本,必须更新 Bazel WORKSPACE
中的固定版本。这是根据需要手动完成的,但可以在每次构建时覆盖。
在发布之间,我们如何跨 jax
和 jaxlib
边界进行更改?#
jaxlib 版本是一个粗略的工具:它只允许我们推断 *发布版本*。
然而,由于 jax
和 jaxlib
代码分散在无法通过一次更改进行原子更新的仓库中,我们需要以比发布周期更精细的粒度管理兼容性。为了管理细粒度兼容性,我们有独立于 jaxlib
发布版本号的额外版本控制。
我们在XLA 仓库中的 xla_client.py
中维护了一个额外的版本号(_version
)。其思想是,此版本号在 xla/python
中与 JAX 的 C++ 部分一起定义,并且 JAX Python 也可以通过 jax._src.lib.jaxlib_extension_version
访问,并且每次对 XLA/Python 代码进行具有向后兼容性影响的更改时都必须递增。JAX Python 代码随后可以使用此版本号来维护向后兼容性,例如:
from jax._src.lib import jaxlib_extension_version
# 123 is the new version number for _version in xla_client.py
if jaxlib_extension_version >= 123:
# Use new code path
...
else:
# Use old code path.
请注意,此版本号是除了对已发布版本号的约束 *之外* 的。也就是说,此版本号旨在帮助管理未发布代码在开发过程中的兼容性。发布版本也必须遵循上述兼容性规则。