安装#

使用 JAX 需要安装两个包:jax,它是纯 Python 且跨平台,以及 jaxlib,它包含编译好的二进制文件,并需要针对不同的操作系统和加速器进行不同的构建。

总结: 对于大多数用户,典型的 JAX 安装可能如下所示

  • 仅 CPU (Linux/macOS/Windows)

    pip install -U jax
    
  • GPU (NVIDIA, CUDA 13)

    pip install -U "jax[cuda13]"
    
  • TPU (Google Cloud TPU VM)

    pip install -U "jax[tpu]"
    

支持的平台#

下表显示了所有支持的平台和安装选项。检查您的设置是否受支持;如果显示“是”“实验性”,请点击相应的链接以了解如何更详细地安装 JAX。

Linux, x86_64

Linux, aarch64

Mac, aarch64

Windows, x86_64

Windows WSL2, x86_64

CPU

NVIDIA GPU

不适用

实验性

Google Cloud TPU

不适用

不适用

不适用

不适用

AMD GPU

不适用

实验性

Apple GPU

不适用

实验性

不适用

不适用

Intel GPU

实验性

不适用

不适用

CPU#

pip 安装:CPU#

目前,JAX 团队为以下操作系统和架构发布 jaxlib wheels

  • Linux, x86_64

  • Linux, aarch64

  • macOS, Apple ARM 架构

  • Windows, x86_64 (实验性)

要安装仅 CPU 版本的 JAX,这可能有助于在笔记本电脑上进行本地开发,您可以运行

pip install --upgrade pip
pip install --upgrade jax

在 Windows 上,如果您的计算机上尚未安装,您可能还需要安装 Microsoft Visual Studio 2019 可再发行组件

其他操作系统和架构需要从源代码构建。尝试在其他操作系统和架构上使用 pip 安装可能会导致 jaxlib 未与 jax 一起安装,尽管 jax 可能会成功安装(但在运行时失败)。

NVIDIA GPU#

在 CUDA 12 上,JAX 支持 SM 版本 5.2 (Maxwell) 或更高版本的 NVIDIA GPU。请注意,由于 NVIDIA 已在软件中停止支持 Kepler GPU,JAX 不再支持 Kepler 系列 GPU。在 CUDA 13 上,JAX 支持 SM 版本 7.5 或更高版本的 NVIDIA GPU。NVIDIA 在 CUDA 13 中停止了对先前 GPU 的支持。

您必须首先安装 NVIDIA 驱动程序。建议您安装 NVIDIA 提供的最新驱动程序,但驱动程序版本必须是 Linux 上 CUDA 12 的 >= 525,以及 Linux 上 CUDA 13 的 >= 580。

如果您需要使用较新的 CUDA 工具包但驱动程序较旧,例如在无法轻松更新 NVIDIA 驱动程序的集群上,您可以使用 NVIDIA 提供的 CUDA 前向兼容包

pip 安装:NVIDIA GPU (CUDA,通过 pip 安装,更简单)#

有两种方法可以安装支持 NVIDIA GPU 的 JAX

  • 使用从 pip wheels 安装的 NVIDIA CUDA 和 cuDNN

  • 使用自安装的 CUDA/cuDNN

JAX 团队强烈建议使用 pip wheels 安装 CUDA 和 cuDNN,因为这要容易得多!

NVIDIA 已发布仅适用于 x86_64 和 aarch64 的 CUDA 包。

pip install --upgrade pip

# NVIDIA CUDA 13 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda13]"

# Alternatively, for CUDA 12, use
# pip install --upgrade "jax[cuda12]"

我们建议迁移到 CUDA 13 wheels;将来某个时候我们将停止支持 CUDA 12。

如果 JAX 检测到错误版本的 NVIDIA CUDA 库,您需要检查以下几项

  • 确保 LD_LIBRARY_PATH 未设置,因为 LD_LIBRARY_PATH 会覆盖 NVIDIA CUDA 库。

  • 确保安装的 NVIDIA CUDA 库是 JAX 所需的。重新运行上面的安装命令应该可以正常工作。

pip 安装:NVIDIA GPU (CUDA,本地安装,更复杂)#

如果您希望使用预先安装的 NVIDIA CUDA 副本,您必须先安装 NVIDIA CUDAcuDNN

JAX 提供了适用于仅 Linux x86_64 和 Linux aarch64 的预构建 CUDA 兼容 wheels。其他操作系统和架构的组合也是可能的,但需要从源代码构建(有关详细信息,请参阅 从源代码构建)。

您应该使用与您的 NVIDIA CUDA 工具包相应驱动程序版本 至少一样新的 NVIDIA 驱动程序版本。如果您需要使用较新的 CUDA 工具包但驱动程序较旧,例如在无法轻松更新 NVIDIA 驱动程序的集群上,您可以使用 NVIDIA 提供的 CUDA 前向兼容包

JAX 目前提供两个 CUDA wheel 变体:CUDA 12 和 CUDA 13

CUDA 12 wheel 是

构建于

兼容于

CUDA 12.3

CUDA >=12.1

CUDNN 9.8

CUDNN >=9.8, <10.0

NCCL 2.19

NCCL >=2.18

CUDA 13 wheel 是

构建于

兼容于

CUDA 13.0

CUDA >=13.0

CUDNN 9.8

CUDNN >=9.8, <10.0

NCCL 2.19

NCCL >=2.18

JAX 会检查您库的版本,如果版本不够新,会报错。设置 JAX_SKIP_CUDA_CONSTRAINTS_CHECK 环境变量将禁用检查,但使用旧版本的 CUDA 可能会导致错误或不正确的结果。

NCCL 是一个可选依赖项,仅在您执行多 GPU 计算时需要。

要安装,请运行

pip install --upgrade pip


# Installs the wheel compatible with NVIDIA CUDA 13 and cuDNN 9.8 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda13-local]"

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.8 or newer.
# Note: wheels only available on linux.
# pip install --upgrade "jax[cuda12-local]"

这些 pip 安装不适用于 Windows,并且可能会静默失败;请参阅上方的表格。

您可以使用以下命令查找您的 CUDA 版本

nvcc --version

JAX 使用 LD_LIBRARY_PATH 来查找 CUDA 库,并使用 PATH 来查找二进制文件 (ptxas, nvlink)。请确保这些路径指向正确的 CUDA 安装。

JAX 需要 libdevice10.bc,它通常来自 cuda-nvvm 包。请确保它存在于您的 CUDA 安装中。

如果您在使用预构建 wheels 时遇到任何错误或问题,请在 GitHub issue tracker 上告知 JAX 团队。

NVIDIA GPU Docker 容器#

NVIDIA 提供了 JAX Toolbox 容器,这些容器是包含 JAX 的 nightly 版本和一些模型/框架的最新容器。

Google Cloud TPU#

pip 安装:Google Cloud TPU#

JAX 提供了适用于 Google Cloud TPU 的预构建 wheels。要将 JAX 与相应版本的 jaxliblibtpu 一起安装,您可以在您的 cloud TPU VM 中运行以下命令

pip install "jax[tpu]"

对于 Colab (https://colab.research.google.com/) 用户,请确保您使用的是TPU v2 而不是旧的、已弃用的 TPU 运行时。

Mac GPU#

pip 安装#

Apple 提供了一个实验性的 Metal 插件。有关详细信息,请参阅 Apple 的 JAX on Metal 文档

注意: Metal 插件有几个需要注意的地方

  • Metal 插件是新的且处于实验阶段,并且有许多已知问题。请在 JAX issue tracker 上报告任何问题。

  • Metal 插件目前需要非常特定版本的 jaxjaxlib。随着插件 API 的成熟,这种限制将随着时间的推移而放宽。

AMD GPU (Linux)#

AMD GPU 支持由 AMD 支持的 ROCm JAX 插件提供。

有几种方法可以在 AMDGPU 设备上使用 JAX。请参阅 AMD 的说明以获取详细信息。

注意:Windows WSL2 上的 ROCm 支持处于实验阶段。对于 WSL 安装,您可能需要

  1. 按照 AMD 的官方指南安装 ROCm for WSL

  2. 在您的 WSL 环境中遵循标准的 Linux ROCm JAX 安装步骤

  3. 请注意,性能和稳定性可能与原生 Linux 安装有所不同

Intel GPU#

Intel 提供了一个实验性的 OneAPI 插件:intel-extension-for-openxla,用于 Intel GPU 硬件。有关更多详细信息和安装说明,请参阅以下两种方法之一

  1. Pip 安装:JAX 在 Intel GPU 上的加速

  2. 使用 Intel 的 XLA Docker 容器

请报告任何与以下相关的 issues

Conda (社区支持)#

Conda 安装#

有一个社区支持的 jax Conda 构建。要使用 conda 安装它,只需运行

conda install jax -c conda-forge

如果您在带有 NVIDIA GPU 的机器上运行此命令,它应该会安装 jaxlib 的 CUDA 兼容包。

要确保您安装的 jax 版本确实启用了 CUDA,请运行

conda install "jaxlib=*=*cuda*" jax -c conda-forge

如果您想覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 构建,请按照 conda-forge 网站的技巧与窍门部分中的说明进行操作。

前往 conda-forgejaxlibjax 存储库以获取更多详细信息。

JAX 日构建安装#

日构建版本反映了 JAX 主存储库在构建时的状态,可能无法通过完整的测试套件。

与安装 JAX 发行版的说明不同,在这里我们在命令行中明确命名 JAX 的所有包,因此 pip 会在有新版本可用时升级它们。

JAX 将 nightly 版本、候选发布版 (RC) 和发布版发布到几个非 pypi 的 PEP 503 索引。

所有 JAX 包都可以从索引 https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ 以及 PyPI 镜像包中找到。这个额外的镜像使得 nightly 安装可以使用 –index (-i) 作为 pip 的安装方法。

注意: 即使在 nightly 版本重建之前使用了 --pre,统一索引也可能返回 RC 或发布版本作为最新版本。如果自动化或测试必须针对 nightly 版本进行,或者您不能使用我们的完整索引,请使用额外的索引 https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/,它只包含 nightly 伪像。

  • 仅 CPU

pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
  • Google Cloud TPU

pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  • NVIDIA GPU (CUDA 13)

pip install -U --pre jax jaxlib "jax-cuda13-plugin[with-cuda]" jax-cuda13-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
  • NVIDIA GPU (CUDA 12)

pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

从源代码构建 JAX#

请参阅 从源代码构建

安装旧版 jaxlib wheels#

由于 Python 包索引的存储限制,JAX 团队会定期从 https://pypi.ac.cn/project/jax 上的发布版本中移除旧的 jaxlib wheels。这些仍然可以直接通过这里的 URL 安装。例如

# Install jaxlib on CPU via the wheel archive
pip install "jax[cpu]==0.3.25" -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

对于特定的旧 GPU wheels,请务必使用 jax_cuda_releases.html URL;例如

pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html