安装#

使用 JAX 需要安装两个软件包:jax (纯 Python 且跨平台) 和 jaxlib (包含编译好的二进制文件,需要针对不同的操作系统和加速器进行不同构建)。

总结: 对于大多数用户而言,典型的 JAX 安装可能如下所示

  • 仅限 CPU (Linux/macOS/Windows)

    pip install -U jax
    
  • GPU (NVIDIA, CUDA 12)

    pip install -U "jax[cuda12]"
    
  • TPU (Google Cloud TPU 虚拟机)

    pip install -U "jax[tpu]"
    

支持的平台#

下表列出了所有支持的平台和安装选项。请检查您的设置是否受支持;如果显示“是”或“实验性”,请点击相应的链接以了解 JAX 的详细安装方法。

Linux, x86_64

Linux, aarch64

Mac, aarch64

Windows, x86_64

Windows WSL2, x86_64

CPU

NVIDIA GPU

不适用

实验性

Google Cloud TPU

不适用

不适用

不适用

不适用

AMD GPU

不适用

Apple GPU

不适用

实验性

不适用

不适用

Intel GPU

实验性

不适用

不适用

CPU#

pip 安装: CPU#

目前,JAX 团队为以下操作系统和架构发布 jaxlib wheels

  • Linux, x86_64

  • Linux, aarch64

  • macOS, Apple 基于 ARM

  • Windows, x86_64 (实验性)

要安装仅限 CPU 版本的 JAX (这对于在笔记本电脑上进行本地开发可能很有用),您可以运行

pip install --upgrade pip
pip install --upgrade jax

在 Windows 上,如果您的机器尚未安装 Microsoft Visual Studio 2019 Redistributable,您可能还需要安装它。

其他操作系统和架构需要从源代码构建。尝试在其他操作系统和架构上使用 pip 安装可能导致 jaxlib 未与 jax 一起安装,尽管 jax 可能成功安装 (但在运行时会失败)。

NVIDIA GPU#

JAX 支持 SM 版本 5.2 (Maxwell) 或更高版本的 NVIDIA GPU。请注意,自 NVIDIA 在其软件中停止支持 Kepler GPU 以来,JAX 已不再支持 Kepler 系列 GPU。

您必须首先安装 NVIDIA 驱动程序。建议您安装 NVIDIA 提供的最新驱动程序,但对于 Linux 上的 CUDA 12,驱动程序版本必须 >= 525.60.13。

如果您需要使用较新的 CUDA 工具包和较旧的驱动程序 (例如在无法轻松更新 NVIDIA 驱动程序的集群上),您可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容性软件包

pip 安装: NVIDIA GPU (CUDA,通过 pip 安装,更简单)#

安装支持 NVIDIA GPU 的 JAX 有两种方式

  • 使用从 pip wheels 安装的 NVIDIA CUDA 和 cuDNN

  • 使用自行安装的 CUDA/cuDNN

JAX 团队强烈建议使用 pip wheels 安装 CUDA 和 cuDNN,因为这要简单得多!

NVIDIA 仅为 x86_64 和 aarch64 发布了 CUDA pip 软件包;在其他平台上,您必须使用本地安装的 CUDA。

pip install --upgrade pip

# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"

如果 JAX 检测到 NVIDIA CUDA 库的版本错误,您需要检查以下几点

  • 确保未设置 LD_LIBRARY_PATH,因为 LD_LIBRARY_PATH 可能会覆盖 NVIDIA CUDA 库。

  • 确保安装的 NVIDIA CUDA 库是 JAX 所需的版本。重新运行上述安装命令应该可以解决问题。

pip 安装: NVIDIA GPU (CUDA,本地安装,更复杂)#

如果您更喜欢使用预安装的 NVIDIA CUDA,则必须首先安装 NVIDIA CUDAcuDNN

JAX 仅为 Linux x86_64 和 Linux aarch64 提供预构建的 CUDA 兼容 wheels。其他操作系统和架构组合也是可能的,但需要从源代码构建 (请参阅 从源代码构建 以了解更多信息)。

您应该使用至少与您的 NVIDIA CUDA 工具包对应驱动版本一样新的 NVIDIA 驱动程序。如果您需要使用较新的 CUDA 工具包和较旧的驱动程序 (例如在无法轻松更新 NVIDIA 驱动程序的集群上),您可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容性软件包

JAX 目前提供一种 CUDA wheel 变体

构建版本

兼容版本

CUDA 12.3

CUDA >=12.1

CUDNN 9.1

CUDNN >=9.1, <10.0

NCCL 2.19

NCCL >=2.18

JAX 会检查您的库版本,如果版本不够新,将报告错误。设置 JAX_SKIP_CUDA_CONSTRAINTS_CHECK 环境变量将禁用此检查,但使用旧版 CUDA 可能会导致错误或不正确的结果。

NCCL 是一个可选依赖项,仅当您执行多 GPU 计算时才需要。

要安装,请运行

pip install --upgrade pip

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12-local]"

这些 pip 安装不适用于 Windows,并且可能静默失败;请参阅上方表格。

您可以使用以下命令查找您的 CUDA 版本

nvcc --version

JAX 使用 LD_LIBRARY_PATH 查找 CUDA 库,使用 PATH 查找二进制文件 (ptxas, nvlink)。请确保这些路径指向正确的 CUDA 安装。

JAX 需要 libdevice10.bc,它通常来自 cuda-nvvm 软件包。请确保它存在于您的 CUDA 安装中。

如果您在使用预构建 wheels 时遇到任何错误或问题,请在 GitHub issue tracker 上告知 JAX 团队。

NVIDIA GPU Docker 容器#

NVIDIA 提供了 JAX Toolbox 容器,这些是最新的容器,包含 JAX 的夜间版本以及一些模型/框架。

Google Cloud TPU#

pip 安装: Google Cloud TPU#

JAX 为 Google Cloud TPU 提供预构建的 wheels。要在您的云 TPU 虚拟机中安装 JAX 以及相应版本的 jaxliblibtpu,您可以运行以下命令

pip install "jax[tpu]"

对于 Colab (https://colab.research.google.com/) 用户,请确保您使用的是 TPU v2,而不是旧的、已弃用的 TPU 运行时。

Mac GPU#

pip 安装#

Apple 提供了一个实验性的 Metal 插件。有关详细信息,请参阅 Apple 的 JAX on Metal 文档

注意: Metal 插件有几个注意事项

  • Metal 插件是新的实验性功能,并且存在许多已知问题。请在 JAX issue tracker 上报告任何问题。

  • Metal 插件目前需要非常特定版本的 jaxjaxlib。随着插件 API 的成熟,此限制将逐渐放宽。

AMD GPU (Linux)#

AMD GPU 支持由 AMD 支持的 ROCm JAX 插件提供。

在 AMDGPU 设备上使用 JAX 有多种方法。请参阅 AMD 的说明 以获取详细信息。

Intel GPU#

Intel 提供了一个实验性的 OneAPI 插件:intel-extension-for-openxla,用于 Intel GPU 硬件。有关更多详细信息和安装说明,请参考以下两种方法之一

  1. Pip 安装: 在 Intel GPU 上加速 JAX

  2. 使用 Intel 的 XLA Docker 容器

请报告任何与以下内容相关的问题

Conda (社区支持)#

Conda 安装#

有一个社区支持的 jax Conda 构建版本。要使用 conda 安装它,只需运行

conda install jax -c conda-forge

如果您在配备 NVIDIA GPU 的机器上运行此命令,这将安装一个支持 CUDA 的 jaxlib 软件包。

为确保您安装的 JAX 版本确实支持 CUDA,请运行

conda install "jaxlib=*=*cuda*" jax -c conda-forge

如果您想覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 构建版本,请遵循 conda-forge 网站 Tips & tricks 部分中的说明。

有关更多详细信息,请访问 conda-forgejaxlibjax 仓库。

JAX 夜间版本安装#

夜间版本反映了构建时 JAX 主仓库的状态,并且可能无法通过完整的测试套件。

与安装 JAX 发布版本的说明不同,这里我们在命令行上明确指定所有 JAX 软件包的名称,以便 pip 在有新版本可用时进行升级。

JAX 将夜间版本、发布候选版本 (RC) 和发布版本发布到多个非 PyPI 的 PEP 503 索引。

所有 JAX 软件包都可以通过索引 https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ 以及 PyPI 镜像包进行访问。这种额外的镜像功能使得夜间版本安装可以使用 pip 的 --index (-i) 作为安装方法。

注意: 即使在发布后、最新夜间版本重建之前,统一索引也可能在 --pre 选项下返回 RC 或发布版本作为最新版本。如果必须针对夜间版本进行自动化或测试,或者您无法使用我们的完整索引,请使用只包含夜间构建件的额外索引 https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/

  • 仅限 CPU

pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
  • Google Cloud TPU

pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  • NVIDIA GPU (CUDA 12)

pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
  • NVIDIA GPU (CUDA 12) 传统版本

以下用于历史性的整体式 CUDA jaxlibs 夜间版本。您很可能不需要这个;未来不会再构建整体式 CUDA jaxlibs,并且现有的将在 2024 年 9 月到期。请使用上面的“CUDA 12”选项。

pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html

从源代码构建 JAX#

请参阅 从源代码构建

安装旧版 jaxlib wheels#

由于 Python 包索引上的存储限制,JAX 团队会定期从 https://pypi.ac.cn/project/jax 上的发布版本中移除旧的 jaxlib wheels。这些仍然可以通过此处的 URL 直接安装。例如

# Install jaxlib on CPU via the wheel archive
pip install "jax[cpu]==0.3.25" -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

对于特定的旧版 GPU wheels,请务必使用 jax_cuda_releases.html URL;例如

pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html