从源代码构建#

首先,获取 JAX 源代码

git clone https://github.com/jax-ml/jax
cd jax

构建 JAX 涉及两个步骤

  1. 构建或安装 jaxlib,它是 jax 的 C++ 支持库。

  2. 安装 jax Python 包。

构建或安装 jaxlib#

使用 pip 安装 jaxlib#

如果您只修改 JAX 的 Python 部分,我们建议使用 pip 从预构建的 wheel 安装 jaxlib

pip install jaxlib

有关 pip 安装的完整指南(例如,对于 GPU 和 TPU 支持),请参阅 JAX 自述文件

从源代码构建 jaxlib#

警告

虽然通常可以使用大多数现代编译器从源代码编译 jaxlib,但构建仅使用 clang 进行测试。欢迎提交拉取请求以改进对不同工具链的支持,但其他编译器不会得到积极支持。

要从源代码构建 jaxlib,您还必须安装一些先决条件

  • C++ 编译器

    如上框所述,最好使用最新版本的 clang(在撰写本文时,我们测试的版本是 18),但其他编译器(例如 g++ 或 MSVC)也可能有效。

    在 Ubuntu 或 Debian 上,您可以按照 LLVM 文档中的说明安装最新稳定版本的 clang。

    如果您在 Mac 上进行构建,请确保已安装 XCode 和 XCode 命令行工具。

    有关 Windows 构建说明,请参见下文。

  • Python:用于运行构建辅助脚本。请注意,无需在本地安装 Python 依赖项,因为您的系统 Python 将在构建期间被忽略;请查看 管理密封的 Python 获取详细信息。

要为 CPU 或 TPU 构建 jaxlib,您可以运行

python build/build.py build --wheels=jaxlib --verbose
pip install dist/*.whl  # installs jaxlib (includes XLA)

要为与当前系统安装的 Python 版本不同的 Python 版本构建 wheel,请将 --python_version 标志传递给 build 命令

python build/build.py build --wheels=jaxlib --python_version=3.12 --verbose

本文档的其余部分假设您正在为与当前系统安装匹配的 Python 版本进行构建。如果您需要为不同的版本进行构建,只需在每次调用 python build/build.py 时附加 --python_version=<py version> 标志。请注意,无论是否传递 --python_version 参数,Bazel 构建始终会使用密封的 Python 安装。

如果您想构建 jaxlib 和 CUDA 插件:请运行

python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt

生成三个 wheel(不带 cuda 的 jaxlib、jax-cuda-plugin 和 jax-cuda-pjrt)。默认情况下,所有 CUDA 编译步骤都由 NVCC 和 clang 执行,但可以通过 --build_cuda_with_clang 标志将其限制为 clang。

有关配置选项,请参阅 python build/build.py --help。这里的 python 应该是您的 Python 3 解释器的名称;在某些系统上,您可能需要使用 python3。尽管使用 python 调用脚本,Bazel 将始终使用其自己的密封 Python 解释器和依赖项,只有 build/build.py 脚本本身将由您的系统 Python 解释器处理。默认情况下,wheel 被写入当前目录的 dist/ 子目录。

  • 从 v.0.4.32 开始的 JAX 版本:您可以在配置选项中提供自定义的 CUDA 和 CUDNN 版本。Bazel 将下载它们并用作目标依赖项。

    要下载特定版本的 CUDA/CUDNN 再分发包,您可以使用 --cuda_version--cudnn_version 标志

    python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 \
    --cudnn_version=9.1.1
    

    python build/build.py build --wheels=jax-cuda-pjrt --cuda_version=12.3.2 \
    --cudnn_version=9.1.1
    

    请注意,这些参数是可选的:默认情况下,Bazel 将下载环境变量 HERMETIC_CUDA_VERSIONHERMETIC_CUDNN_VERSION.bazelrc 中提供的 CUDA 和 CUDNN 再分发版本。

    要指向本地文件系统上的 CUDA/CUDNN/NCCL 再分发包,您可以使用以下命令

    python build/build.py build --wheels=jax-cuda-plugin \
    --bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \
    --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \
    --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl"
    

    请参阅 XLA 文档中的完整说明列表。

  • v.0.4.32 之前的 JAX 版本:您必须安装 CUDA 和 CUDNN,并使用配置选项提供它们的路径。

从修改后的 XLA 存储库构建 jaxlib。#

JAX 依赖于 XLA,其源代码位于 XLA GitHub 存储库中。默认情况下,JAX 使用 XLA 存储库的固定副本,但我们在处理 JAX 时经常希望使用本地修改的 XLA 副本。有两种方法可以做到这一点

  • 使用 Bazel 的 override_repository 功能,您可以将其作为命令行标志传递给 build.py,如下所示

    python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
    
  • 修改 JAX 源代码树根目录中的 WORKSPACE 文件以指向不同的 XLA 树。

要将更改贡献回 XLA,请向 XLA 存储库发送 PR。

JAX 固定的 XLA 版本会定期更新,但会在每次 jaxlib 发布之前特别更新。

在 Windows 上从源代码构建 jaxlib 的其他注意事项#

注意:JAX 不支持 Windows 上的 CUDA;请使用 WSL2 获取 CUDA 支持。

在 Windows 上,请按照 安装 Visual Studio 设置 C++ 工具链。需要 Visual Studio 2019 版本 16.5 或更高版本。

JAX 构建使用符号链接,这需要您激活开发人员模式

您可以使用其 Windows 安装程序安装 Python,或者如果您愿意,可以使用 AnacondaMiniconda 设置 Python 环境。

Bazel 的某些目标使用 bash 实用程序进行脚本编写,因此需要 MSYS2。有关详细信息,请参阅在 Windows 上安装 Bazel。安装以下软件包

pacman -S patch coreutils

安装 coreutils 后,realpath 命令应存在于您的 shell 路径中。

安装完所有内容后。打开 PowerShell,并确保 MSYS2 位于当前会话的路径中。确保 bazelpatchrealpath 可访问。激活 conda 环境。

python .\build\build.py build --wheels=jaxlib

要使用调试信息进行构建,请添加标志 --bazel_options='--copt=/Z7'

为 AMD GPU 构建 ROCM jaxlib 的其他注意事项#

有关构建具有 ROCm 支持的 jaxlib 的详细说明,请参阅官方指南:从源代码构建 ROCm JAX

管理密封的 Python#

为了确保 JAX 的构建是可重现的,在支持的平台(Linux、Windows、MacOS)上表现一致,并且与本地系统的具体细节正确隔离,我们依赖于密封的 Python(由 rules_python 提供,有关详细信息,请参阅 工具链注册)用于通过 Bazel 执行的所有构建和测试命令。这意味着您的系统 Python 安装将在构建期间被忽略,并且 Python 解释器本身以及所有 Python 依赖项将由 bazel 直接管理。

指定 Python 版本#

当您运行 build/build.py 工具时,密封的 Python 版本会自动设置为与您用于运行 build/build.py 脚本的 Python 版本匹配。要显式选择特定版本,您可以将 --python_version 参数传递给该工具

python build/build.py build --python_version=3.12

在底层,密封的 Python 版本由 HERMETIC_PYTHON_VERSION 环境变量控制,该变量在您运行 build/build.py 时会自动设置。如果您直接运行 bazel,您可能需要通过以下方式之一显式设置该变量

# Either add an entry to your `.bazelrc` file
build --repo_env=HERMETIC_PYTHON_VERSION=3.12

# OR pass it directly to your specific build command
bazel build <target> --repo_env=HERMETIC_PYTHON_VERSION=3.12

# OR set the environment variable globally in your shell:
export HERMETIC_PYTHON_VERSION=3.12

您可以通过在运行之间简单地切换 --python_version 的值,在同一台计算机上针对不同版本的 Python 顺序运行构建和测试。上一次构建中缓存的所有与 python 无关的部分将被保留并用于后续构建。

指定 Python 依赖项#

在 Bazel 构建期间,所有 JAX 的 Python 依赖项都会被固定到它们的特定版本。这对于确保构建的可重现性是必要的。JAX 依赖项的完整传递闭包的固定版本及其对应的哈希值在 build/requirements_lock_<python version>.txt 文件中指定(例如,build/requirements_lock_3_12.txt 用于 Python 3.12)。

要更新锁定文件,请确保 build/requirements.in 包含所需的直接依赖项列表,然后执行以下命令(这将在后台调用 pip-compile

python build/build.py requirements_update --python_version=3.12

或者,如果您需要更多控制,可以直接运行 Bazel 命令(这两个命令是等效的)

bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12

其中 3.12 是您希望更新的 Python 版本。

请注意,由于底层仍然使用 pippip-compile 工具,因此这些工具支持的大多数命令行参数和功能也将被 Bazel 依赖项更新器命令识别。例如,如果您希望更新器考虑预发布版本,只需将 --pre 参数传递给 Bazel 命令

bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12 -- --pre

指定本地 wheel 的依赖项#

默认情况下,构建会在存储库根目录中的 dist 目录中扫描任何本地 .whl 文件,以便将其包含在依赖项列表中。如果 wheel 是 Python 版本特定的,则只会包含与所选 Python 版本匹配的 wheel。

整体本地 wheel 搜索和选择逻辑由 python_init_repositories() 宏(直接从 WORKSPACE 文件中调用)的参数控制。您可以使用 local_wheel_dist_folder 来更改包含本地 wheel 的文件夹的位置。使用 local_wheel_inclusion_listlocal_wheel_exclusion_list 参数来指定应在搜索中包含和/或排除哪些 wheel(它支持基本通配符匹配)。

如果需要,您还可以手动依赖本地 .whl 文件,绕过自动本地 wheel 搜索机制。例如,要依赖您新构建的 jaxlib wheel,您可以在 build/requirements.in 中添加 wheel 的路径,并为所选 Python 版本重新运行依赖项更新器命令。例如

echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in
python build/build.py requirements_update --python_version=3.12

指定 nightly wheel 的依赖项#

为了针对最新、可能不稳定的 Python 依赖项集进行构建和测试,我们提供了依赖项更新器命令的特殊版本,如下所示

python build/build.py requirements_update --python_version=3.12 --nightly_update

或者,如果您直接运行 bazel(这两个命令是等效的)

bazel run //build:requirements_nightly.update --repo_env=HERMETIC_PYTHON_VERSION=3.12

此更新器与常规更新器的区别在于,默认情况下它会接受预发布版本、dev 和 nightly 包,它还会搜索 https://pypi.anaconda.org/scientific-python-nightly-wheels/simple 作为额外的索引 URL,并且不会在生成的依赖项锁定文件中放入哈希值。

自定义 hermetic Python(高级用法)#

我们开箱即用地支持所有当前版本的 Python,因此除非您的工作流程有非常特殊的要求(例如能够使用您自己的自定义 Python 解释器),否则您可以安全地完全跳过此部分。

简而言之,如果您依赖于非标准的 Python 工作流程,您仍然可以在 hermetic Python 设置中获得很高的灵活性。从概念上讲,与非 hermetic 情况相比,只会有一个区别:您需要从文件的角度进行思考,而不是安装(即,思考您的构建实际上依赖于哪些文件,而不是需要在您的系统上安装哪些文件),其余的几乎相同。

因此,在实践中,要完全控制您的 Python 环境(hermetic 或非 hermetic),您需要能够执行以下三件事

  1. 指定要使用的 python 解释器(即,选择实际的 pythonpython3 二进制文件以及与该文件位于同一文件夹中的库)。

  2. 指定 Python 依赖项列表(例如 numpy)及其实际版本。

  3. 能够轻松地在列表中添加/删除/更新依赖项。每个依赖项本身也可以是自定义的(例如,自行构建的)。

您已经知道如何在非 hermetic Python 环境中执行上述所有步骤,以下是如何在 hermetic 环境中执行相同操作(从文件的角度而不是安装的角度来处理)

  1. 不要安装 Python,而是获取 tarzip 文件中的 Python 解释器。根据您的情况,您可以简单地拉取许多现有的解释器(例如 python-build-standalone),或者构建您自己的解释器并将其打包到存档中(遵循官方的 构建说明即可)。例如,在 Linux 上,它看起来像以下内容

    ./configure --prefix python
    make -j12
    make altinstall
    tar -czpf my_python.tgz python
    

    准备好 tarball 后,将 HERMETIC_PYTHON_URL 环境变量指向存档(本地存档或来自互联网的存档),将其插入到构建中

    --repo_env=HERMETIC_PYTHON_URL="file:///local/path/to/my_python.tgz"
    --repo_env=HERMETIC_PYTHON_SHA256=<file's_sha256_sum>
    
    # OR
    --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz"
    --repo_env=HERMETIC_PYTHON_SHA256=<file's_sha256_sum>
    
    # We assume that top-level folder in the tarbal is called "python", if it is
    # something different just pass additional HERMETIC_PYTHON_PREFIX parameter
    --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz"
    --repo_env=HERMETIC_PYTHON_SHA256=<file's_sha256_sum>
    --repo_env=HERMETIC_PYTHON_PREFIX="my_python/install"
    
  2. 不要执行 pip install,而是创建 requirements_lock.txt 文件,其中包含您依赖项的完整传递闭包。您还可以依赖此存储库中已签入的现有依赖项(只要它们与您的自定义 Python 版本一起使用)。关于如何执行此操作,没有特殊说明,您可以按照此文档中的 指定 Python 依赖项 中建议的步骤进行操作,只需直接调用 pip-compile 即可(请注意,锁定文件必须是 hermetic 的,但如果您愿意,始终可以从非 hermetic python 生成它)或者甚至手动创建它(请注意,哈希在锁定文件中是可选的)。

  3. 如果您需要更新或自定义依赖项列表,您可以再次按照 指定 Python 依赖项 说明来更新 requirements_lock.txt,直接调用 pip-compile 或手动修改它。如果您有一个要使用的自定义包,只需从您的锁定中直接指向其 .whl 文件(请记住,从文件的角度而不是安装的角度来处理)(请注意,requirements.txtrequirements_lock.txt 文件支持本地 wheel 引用)。如果您的 requirements_lock.txt 已在 WORKSPACE 文件中指定为 python_init_repositories() 的依赖项,则您无需执行其他任何操作。否则,您可以按如下方式指向您的自定义文件

    --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/custom_requirements_lock.txt"
    

    另请注意,如果您使用 HERMETIC_REQUIREMENTS_LOCK,则它完全控制您的依赖项列表,并且 指定本地 wheel 的依赖项 中描述的自动本地 wheel 解析逻辑将被禁用,以避免干扰它。

就是这样。总结一下:如果您有一个包含 Python 解释器的存档以及一个包含依赖项的完整传递闭包的 requirements_lock.txt 文件,那么您就可以完全控制您的 Python 环境。

自定义 hermetic Python 示例#

请注意,对于以下所有示例,您还可以全局设置环境变量(即,在您的 shell 中使用 export,而不是在您的命令中使用 --repo_env 参数),这样通过 build/build.py 调用 bazel 即可正常工作。

使用来自互联网的自定义 Python 3.13 进行构建,使用已在此存储库中签入的默认 requirements_lock_3_13.txt(即,自定义解释器但默认依赖项)

bazel build <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13
  --repo_env=HERMETIC_PYTHON_URL="https://github.com/indygreg/python-build-standalone/releases/download/20241016/cpython-3.13.0+20241016-x86_64-unknown-linux-gnu-install_only.tar.gz"
  --repo_env=HERMETIC_PYTHON_SHA256="2c8cb15c6a2caadaa98af51df6fe78a8155b8471cb3dd7b9836038e0d3657fb4"

使用来自本地文件系统的自定义 Python 3.13 和自定义锁定文件进行构建(假设在运行命令之前,锁定文件已放置在此存储库的 jax/build 文件夹中)

bazel test <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13
  --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz"
  --repo_env=HERMETIC_PYTHON_PREFIX="prefix/to/strip/in/cython/tar/gz/archive"
  --repo_env=HERMETIC_PYTHON_SHA256=<sha256_sum>
  --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/build:custom_requirements_lock.txt"

如果默认 python 解释器对您来说足够好,而您只需要一组自定义的依赖项

bazel test <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13
  --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/build:custom_requirements_lock.txt"

请注意,您可以有多个与同一 Python 版本对应的不同的 requirement_lock.txt 文件来支持不同的场景。您可以通过指定 HERMETIC_PYTHON_VERSION 来控制选择哪个文件。例如,在 WORKSPACE 文件中

requirements = {
  "3.10": "//build:requirements_lock_3_10.txt",
  "3.11": "//build:requirements_lock_3_11.txt",
  "3.12": "//build:requirements_lock_3_12.txt",
  "3.13": "//build:requirements_lock_3_13.txt",
  "3.13-scenario1": "//build:scenario1_requirements_lock_3_13.txt",
  "3.13-scenario2": "//build:scenario2_requirements_lock_3_13.txt",
},

然后,您可以在不更改环境中任何内容的情况下,构建和测试不同的内容组合

# To build with scenario1 dependendencies:
bazel test <target> --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1

# To build with scenario2 dependendencies:
bazel test <target> --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario2

# To build with default dependendencies:
bazel test <target> --repo_env=HERMETIC_PYTHON_VERSION=3.13

# To build with scenario1 dependendencies and custom Python 3.13 interpreter:
bazel test <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1
  --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz"
  --repo_env=HERMETIC_PYTHON_SHA256=<sha256_sum>

安装 jax#

安装 jaxlib 后,您可以通过运行以下命令来安装 jax

pip install -e .  # installs jax

要从 GitHub 升级到最新版本,只需从 JAX 存储库根目录运行 git pull,并通过运行 build.py 或在必要时升级 jaxlib 来重新构建。您不应该重新安装 jax,因为 pip install -e 会将符号链接从 site-packages 设置到存储库中。

运行测试#

有两种受支持的机制来运行 JAX 测试,可以使用 Bazel 或使用 pytest。

使用 Bazel#

首先,使用 --configure_only 标志配置 JAX 构建。对于 CPU 测试,传递 --wheel_list=jaxlib,对于 GPU 测试,传递 CUDA/ROCM

python build/build.py build --wheels=jaxlib --configure_only
python build/build.py build --wheels=jax-cuda-plugin --configure_only
python build/build.py build --wheels=jax-rocm-plugin --configure_only

您可以将其他选项传递给 build.py 以配置构建;有关详细信息,请参阅 jaxlib 构建文档。

默认情况下,Bazel 构建使用从源代码构建的 jaxlib 运行 JAX 测试。要运行 JAX 测试,请运行

bazel test //tests:cpu_tests //tests:backend_independent_tests

如果您有必要的硬件,也可以使用 //tests:gpu_tests//tests:tpu_tests

要使用预安装的 jaxlib 而不是构建它,您首先需要使其在 hermetic Python 中可用。要在 hermetic Python 中安装特定版本的 jaxlib,请运行(以 jaxlib >= 0.4.26 为例)

echo -e "\njaxlib >= 0.4.26" >> build/requirements.in
python build/build.py requirements_update

或者,要从本地 wheel 安装 jaxlib(假设为 Python 3.12)

echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in
python build/build.py requirements_update --python_version=3.12

一旦你以密封方式安装了 jaxlib,请运行

bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests

可以使用环境变量控制许多测试行为(见下文)。可以使用 Bazel 的 --test_env=FLAG=value 标志将环境变量传递给 JAX 测试。

一些 JAX 测试是针对多个加速器(即 GPU、TPU)的。当 JAX 已经安装时,你可以像这样运行 GPU 测试

bazel test //tests:gpu_tests --local_test_jobs=4 --test_tag_filters=multiaccelerator --//jax:build_jaxlib=false --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform

你可以通过在多个加速器上并行运行单加速器测试来加速它们。这也会触发每个加速器的多个并发测试。对于 GPU,你可以这样做

NB_GPUS=2
JOBS_PER_ACC=4
J=$((NB_GPUS * JOBS_PER_ACC))
MULTI_GPU="--run_under $PWD/build/parallel_accelerator_execute.sh --test_env=JAX_ACCELERATOR_COUNT=${NB_GPUS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_ACC} --local_test_jobs=$J"
bazel test //tests:gpu_tests //tests:backend_independent_tests --test_env=XLA_PYTHON_CLIENT_PREALLOCATE=false --test_tag_filters=-multiaccelerator $MULTI_GPU

使用 pytest#

首先,通过运行 pip install -r build/test-requirements.txt 来安装依赖项。

要使用 pytest 运行所有 JAX 测试,我们建议使用 pytest-xdist,它可以并行运行测试。它作为 pip install -r build/test-requirements.txt 命令的一部分安装。

从仓库根目录运行

pytest -n auto tests

控制测试行为#

JAX 以组合方式生成测试用例,你可以使用 JAX_NUM_GENERATED_CASES 环境变量来控制为每个测试生成和检查的用例数量(默认为 10)。自动化测试当前默认使用 25。

例如,可以这样写

# Bazel
bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25`

# pytest
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests

自动化测试还使用默认的 64 位浮点数和整数 (JAX_ENABLE_X64) 运行测试

JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests

你可以使用 pytest 的内置选择机制运行更具体的测试集,或者你可以直接运行特定的测试文件以查看有关正在运行的用例的更详细信息

JAX_NUM_GENERATED_CASES=5 python tests/lax_numpy_test.py

你可以通过传递环境变量 JAX_SKIP_SLOW_TESTS=1 来跳过一些已知较慢的测试。

要从测试文件中指定要运行的特定测试集,你可以通过 --test_targets 标志传递字符串或正则表达式。例如,你可以使用以下命令运行 jax.numpy.pad 的所有测试

python tests/lax_numpy_test.py --test_targets="testPad"

Colab 笔记本会作为文档构建的一部分进行错误测试。

假设检验#

一些测试使用 hypothesis。通常,hypothesis 将使用多个示例输入进行测试,并且在测试失败时,它会尝试找到一个仍然导致失败的较小示例:在测试失败中查找如下所示的行,并添加消息中提到的装饰器

You can reproduce this example by temporarily adding @reproduce_failure('6.97.4', b'AXicY2DAAAAAEwAB') as a decorator on your test case

对于交互式开发,你可以设置环境变量 JAX_HYPOTHESIS_PROFILE=interactive (或等效的标志 --jax_hypothesis_profile=interactive)以将示例数量设置为 1,并跳过示例最小化阶段。

文档测试#

JAX 在文档测试模式下使用 pytest 来测试文档中的代码示例。你可以在 ci-build.yaml 中找到运行文档测试的最新命令。例如,你可以运行

JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst

此外,JAX 在 doctest-modules 模式下运行 pytest,以确保函数文档字符串中的代码示例可以正确运行。你可以在本地运行此命令,例如

JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest --doctest-modules jax/_src/numpy/lax_numpy.py

类型检查#

我们使用 mypy 来检查类型提示。要使用与 github CI 检查相同的配置运行 mypy,你可以使用 pre-commit 框架

pip install pre-commit
pre-commit run mypy --all-files

由于 mypy 在检查所有文件时可能会有些慢,因此仅检查你修改过的文件可能会更方便。为此,首先暂存更改(即 git add 更改的文件),然后在提交更改之前运行此命令

pre-commit run mypy

代码检查#

JAX 使用 ruff 代码检查器来确保代码质量。要使用与 github CI 检查相同的配置运行 ruff,你可以使用 pre-commit 框架

pip install pre-commit
pre-commit run ruff --all-files

更新文档#

要重建文档,请安装几个软件包

pip install -r docs/requirements.txt

然后运行

sphinx-build -b html docs docs/build/html -j auto

这可能需要很长时间,因为它会执行文档源中的许多笔记本;如果你希望在不执行笔记本的情况下构建文档,可以运行

sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto

然后,你可以在 docs/build/html/index.html 中查看生成的文档。

-j auto 选项控制构建的并行性。你可以使用一个数字代替 auto 来控制要使用的 CPU 核心数。

更新笔记本#

我们使用 jupytext 来维护 docs/notebooks 中笔记本的两个同步副本:一个采用 ipynb 格式,另一个采用 md 格式。前者的优点是可以在 Colab 中直接打开和执行;后者的优点是它可以更容易地跟踪版本控制中的差异。

编辑 ipynb#

对于对代码和输出进行实质性修改的重大更改,最容易在 Jupyter 或 Colab 中编辑笔记本。要在 Colab 界面中编辑笔记本,请打开 http://colab.research.google.com 并从本地仓库 Upload。根据需要更新它,Run all cells 然后 Download ipynb。你可能需要使用上面解释的 sphinx-build 来测试它是否正确执行。

编辑 md#

对于对笔记本的文本内容进行较小的更改,最容易使用文本编辑器编辑 .md 版本。

同步笔记本#

在编辑笔记本的 ipynb 或 md 版本后,你可以通过在更新的笔记本上运行 jupytext --sync 来同步这两个版本;例如

pip install jupytext==1.16.4
jupytext --sync docs/notebooks/thinking_in_jax.ipynb

jupytext 版本应与 .pre-commit-config.yaml 中指定的版本匹配。

要检查 markdown 和 ipynb 文件是否正确同步,你可以使用 pre-commit 框架执行 github CI 使用的相同检查

pip install pre-commit
pre-commit run jupytext --all-files

创建新笔记本#

如果你要向文档添加新的笔记本,并且想使用此处讨论的 jupytext --sync 命令,你可以使用以下命令为 jupytext 设置笔记本

jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb

这通过向笔记本文件添加一个 "jupytext" 元数据字段来实现,该字段指定所需的格式,并且 jupytext --sync 命令在调用时会识别该字段。

Sphinx 构建中的笔记本#

一些笔记本作为预提交检查的一部分和 Read the docs 构建的一部分自动构建。如果单元格引发错误,则构建将失败。如果错误是故意的,你可以捕获它们,或者使用 raises-exceptions 元数据标记单元格(示例 PR)。你必须在 .ipynb 文件中手动添加此元数据。当其他人重新保存笔记本时,它将被保留。

我们从构建中排除了一些笔记本,例如,因为它们包含长时间的计算。请参阅 conf.py 中的 exclude_patterns

readthedocs.io 上构建文档#

JAX 的自动生成文档位于 https://jax.net.cn/

整个项目的文档构建由 readthedocs JAX 设置控制。当前设置会在代码推送到 GitHub 的 main 分支后立即触发文档构建。对于每个代码版本,构建过程由 .readthedocs.ymldocs/conf.py 配置文件驱动。

对于每次自动文档构建,您可以查看文档构建日志

如果您想在 Readthedocs 上测试文档生成,您可以将代码推送到 test-docs 分支。该分支也会自动构建,您可以在此处查看生成的文档。如果文档构建失败,您可能需要清除 test-docs 的构建环境

对于本地测试,我可以在一个全新的目录中,通过重放我在 Readthedocs 日志中看到的命令来实现。

mkvirtualenv jax-docs  # A new virtualenv
mkdir jax-docs  # A new directory
cd jax-docs
git clone --no-single-branch --depth 50 https://github.com/jax-ml/jax
cd jax
git checkout --force origin/test-docs
git clean -d -f -f
workon jax-docs

python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html