安装#
使用 JAX 需要安装两个软件包:jax
(纯 Python 且跨平台) 和 jaxlib
(包含编译好的二进制文件,需要针对不同的操作系统和加速器进行不同构建)。
总结: 对于大多数用户而言,典型的 JAX 安装可能如下所示
仅限 CPU (Linux/macOS/Windows)
pip install -U jax
GPU (NVIDIA, CUDA 12)
pip install -U "jax[cuda12]"
TPU (Google Cloud TPU 虚拟机)
pip install -U "jax[tpu]"
支持的平台#
下表列出了所有支持的平台和安装选项。请检查您的设置是否受支持;如果显示“是”或“实验性”,请点击相应的链接以了解 JAX 的详细安装方法。
Linux, x86_64 |
Linux, aarch64 |
Mac, aarch64 |
Windows, x86_64 |
Windows WSL2, x86_64 |
|
---|---|---|---|---|---|
CPU |
|||||
NVIDIA GPU |
不适用 |
否 |
|||
Google Cloud TPU |
不适用 |
不适用 |
不适用 |
不适用 |
|
AMD GPU |
否 |
不适用 |
否 |
否 |
|
Apple GPU |
不适用 |
否 |
不适用 |
不适用 |
|
Intel GPU |
不适用 |
不适用 |
否 |
否 |
CPU#
pip 安装: CPU#
目前,JAX 团队为以下操作系统和架构发布 jaxlib
wheels
Linux, x86_64
Linux, aarch64
macOS, Apple 基于 ARM
Windows, x86_64 (实验性)
要安装仅限 CPU 版本的 JAX (这对于在笔记本电脑上进行本地开发可能很有用),您可以运行
pip install --upgrade pip
pip install --upgrade jax
在 Windows 上,如果您的机器尚未安装 Microsoft Visual Studio 2019 Redistributable,您可能还需要安装它。
其他操作系统和架构需要从源代码构建。尝试在其他操作系统和架构上使用 pip 安装可能导致 jaxlib
未与 jax
一起安装,尽管 jax
可能成功安装 (但在运行时会失败)。
NVIDIA GPU#
JAX 支持 SM 版本 5.2 (Maxwell) 或更高版本的 NVIDIA GPU。请注意,自 NVIDIA 在其软件中停止支持 Kepler GPU 以来,JAX 已不再支持 Kepler 系列 GPU。
您必须首先安装 NVIDIA 驱动程序。建议您安装 NVIDIA 提供的最新驱动程序,但对于 Linux 上的 CUDA 12,驱动程序版本必须 >= 525.60.13。
如果您需要使用较新的 CUDA 工具包和较旧的驱动程序 (例如在无法轻松更新 NVIDIA 驱动程序的集群上),您可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容性软件包。
pip 安装: NVIDIA GPU (CUDA,通过 pip 安装,更简单)#
安装支持 NVIDIA GPU 的 JAX 有两种方式
使用从 pip wheels 安装的 NVIDIA CUDA 和 cuDNN
使用自行安装的 CUDA/cuDNN
JAX 团队强烈建议使用 pip wheels 安装 CUDA 和 cuDNN,因为这要简单得多!
NVIDIA 仅为 x86_64 和 aarch64 发布了 CUDA pip 软件包;在其他平台上,您必须使用本地安装的 CUDA。
pip install --upgrade pip
# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"
如果 JAX 检测到 NVIDIA CUDA 库的版本错误,您需要检查以下几点
确保未设置
LD_LIBRARY_PATH
,因为LD_LIBRARY_PATH
可能会覆盖 NVIDIA CUDA 库。确保安装的 NVIDIA CUDA 库是 JAX 所需的版本。重新运行上述安装命令应该可以解决问题。
pip 安装: NVIDIA GPU (CUDA,本地安装,更复杂)#
如果您更喜欢使用预安装的 NVIDIA CUDA,则必须首先安装 NVIDIA CUDA 和 cuDNN。
JAX 仅为 Linux x86_64 和 Linux aarch64 提供预构建的 CUDA 兼容 wheels。其他操作系统和架构组合也是可能的,但需要从源代码构建 (请参阅 从源代码构建 以了解更多信息)。
您应该使用至少与您的 NVIDIA CUDA 工具包对应驱动版本一样新的 NVIDIA 驱动程序。如果您需要使用较新的 CUDA 工具包和较旧的驱动程序 (例如在无法轻松更新 NVIDIA 驱动程序的集群上),您可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容性软件包。
JAX 目前提供一种 CUDA wheel 变体
构建版本 |
兼容版本 |
---|---|
CUDA 12.3 |
CUDA >=12.1 |
CUDNN 9.1 |
CUDNN >=9.1, <10.0 |
NCCL 2.19 |
NCCL >=2.18 |
JAX 会检查您的库版本,如果版本不够新,将报告错误。设置 JAX_SKIP_CUDA_CONSTRAINTS_CHECK
环境变量将禁用此检查,但使用旧版 CUDA 可能会导致错误或不正确的结果。
NCCL 是一个可选依赖项,仅当您执行多 GPU 计算时才需要。
要安装,请运行
pip install --upgrade pip
# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12-local]"
这些 pip
安装不适用于 Windows,并且可能静默失败;请参阅上方表格。
您可以使用以下命令查找您的 CUDA 版本
nvcc --version
JAX 使用 LD_LIBRARY_PATH
查找 CUDA 库,使用 PATH
查找二进制文件 (ptxas
, nvlink
)。请确保这些路径指向正确的 CUDA 安装。
JAX 需要 libdevice10.bc,它通常来自 cuda-nvvm 软件包。请确保它存在于您的 CUDA 安装中。
如果您在使用预构建 wheels 时遇到任何错误或问题,请在 GitHub issue tracker 上告知 JAX 团队。
NVIDIA GPU Docker 容器#
NVIDIA 提供了 JAX Toolbox 容器,这些是最新的容器,包含 JAX 的夜间版本以及一些模型/框架。
Google Cloud TPU#
pip 安装: Google Cloud TPU#
JAX 为 Google Cloud TPU 提供预构建的 wheels。要在您的云 TPU 虚拟机中安装 JAX 以及相应版本的 jaxlib
和 libtpu
,您可以运行以下命令
pip install "jax[tpu]"
对于 Colab (https://colab.research.google.com/) 用户,请确保您使用的是 TPU v2,而不是旧的、已弃用的 TPU 运行时。
Mac GPU#
pip 安装#
Apple 提供了一个实验性的 Metal 插件。有关详细信息,请参阅 Apple 的 JAX on Metal 文档。
注意: Metal 插件有几个注意事项
Metal 插件是新的实验性功能,并且存在许多已知问题。请在 JAX issue tracker 上报告任何问题。
Metal 插件目前需要非常特定版本的
jax
和jaxlib
。随着插件 API 的成熟,此限制将逐渐放宽。
AMD GPU (Linux)#
AMD GPU 支持由 AMD 支持的 ROCm JAX 插件提供。
在 AMDGPU 设备上使用 JAX 有多种方法。请参阅 AMD 的说明 以获取详细信息。
Intel GPU#
Intel 提供了一个实验性的 OneAPI 插件:intel-extension-for-openxla,用于 Intel GPU 硬件。有关更多详细信息和安装说明,请参考以下两种方法之一
Pip 安装: 在 Intel GPU 上加速 JAX。
请报告任何与以下内容相关的问题
JAX: JAX issue tracker。
Intel 的 OpenXLA 插件: Intel-extension-for-openxla issue tracker。
Conda (社区支持)#
Conda 安装#
有一个社区支持的 jax
Conda 构建版本。要使用 conda
安装它,只需运行
conda install jax -c conda-forge
如果您在配备 NVIDIA GPU 的机器上运行此命令,这将安装一个支持 CUDA 的 jaxlib
软件包。
为确保您安装的 JAX 版本确实支持 CUDA,请运行
conda install "jaxlib=*=*cuda*" jax -c conda-forge
如果您想覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 构建版本,请遵循 conda-forge
网站 Tips & tricks 部分中的说明。
JAX 夜间版本安装#
夜间版本反映了构建时 JAX 主仓库的状态,并且可能无法通过完整的测试套件。
与安装 JAX 发布版本的说明不同,这里我们在命令行上明确指定所有 JAX 软件包的名称,以便 pip
在有新版本可用时进行升级。
JAX 将夜间版本、发布候选版本 (RC) 和发布版本发布到多个非 PyPI 的 PEP 503 索引。
所有 JAX 软件包都可以通过索引 https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
以及 PyPI 镜像包进行访问。这种额外的镜像功能使得夜间版本安装可以使用 pip 的 --index
(-i
) 作为安装方法。
注意: 即使在发布后、最新夜间版本重建之前,统一索引也可能在 --pre
选项下返回 RC 或发布版本作为最新版本。如果必须针对夜间版本进行自动化或测试,或者您无法使用我们的完整索引,请使用只包含夜间构建件的额外索引 https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
。
仅限 CPU
pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Google Cloud TPU
pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
NVIDIA GPU (CUDA 12)
pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
NVIDIA GPU (CUDA 12) 传统版本
以下用于历史性的整体式 CUDA jaxlibs 夜间版本。您很可能不需要这个;未来不会再构建整体式 CUDA jaxlibs,并且现有的将在 2024 年 9 月到期。请使用上面的“CUDA 12”选项。
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
从源代码构建 JAX#
请参阅 从源代码构建。
安装旧版 jaxlib
wheels#
由于 Python 包索引上的存储限制,JAX 团队会定期从 https://pypi.ac.cn/project/jax 上的发布版本中移除旧的 jaxlib
wheels。这些仍然可以通过此处的 URL 直接安装。例如
# Install jaxlib on CPU via the wheel archive
pip install "jax[cpu]==0.3.25" -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
对于特定的旧版 GPU wheels,请务必使用 jax_cuda_releases.html
URL;例如
pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html