从源代码构建#

首先,获取 JAX 源代码

git clone https://github.com/jax-ml/jax
cd jax

构建 JAX 涉及两个步骤

  1. 构建或安装 jaxlib,它是 jax 的 C++ 支持库。

  2. 安装 jax Python 包。

构建或安装 jaxlib#

使用 pip 安装 jaxlib#

如果您只修改 JAX 的 Python 部分,我们建议使用 pip 从预构建的 wheel 包安装 jaxlib

pip install jaxlib

有关 pip 安装的完整指南(例如,对 GPU 和 TPU 的支持),请参阅 JAX 自述文件

从源代码构建 jaxlib#

警告

尽管通常应该可以使用大多数现代编译器从源代码编译 jaxlib,但构建仅使用 clang 进行测试。欢迎提交拉取请求以改进对不同工具链的支持,但我们不主动支持其他编译器。

要从源代码构建 jaxlib,您还必须安装一些先决条件

  • 一个 C++ 编译器

    如上框所示,最好使用最新版本的 clang(撰写本文时,我们测试的版本是 18),但其他编译器(例如 g++ 或 MSVC)也可能适用。

    在 Ubuntu 或 Debian 上,您可以按照 LLVM 文档中的说明安装最新稳定版 clang。

    如果您在 Mac 上构建,请确保已安装 XCode 和 XCode 命令行工具。

    Windows 构建说明请参见下文。

  • Python:用于运行构建辅助脚本。请注意,无需在本地安装 Python 依赖项,因为您的系统 Python 在构建期间将被忽略;详情请查看管理独立 Python

要为 CPU 或 TPU 构建 jaxlib,您可以运行

python build/build.py build --wheels=jaxlib --verbose
pip install dist/*.whl  # installs jaxlib (includes XLA)

要为与当前系统安装的 Python 版本不同的 Python 版本构建 wheel,请将 --python_version 标志传递给构建命令

python build/build.py build --wheels=jaxlib --python_version=3.12 --verbose

本文档的其余部分假设您正在为与当前系统安装的 Python 版本匹配的版本进行构建。如果您需要为不同版本构建,只需每次调用 python build/build.py 时附加 --python_version=<py version> 标志即可。请注意,无论是否传递 --python_version 参数,Bazel 构建都将始终使用独立 Python 安装。

如果您想构建 jaxlib 和 CUDA 插件:运行

python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt

生成三个 wheel 包(不带 cuda 的 jaxlib、jax-cuda-plugin 和 jax-cuda-pjrt)。默认情况下,所有 CUDA 编译步骤都由 NVCC 和 clang 执行,但可以通过 --build_cuda_with_clang 标志将其限制为 clang。

有关配置选项,请参阅 python build/build.py --help。这里的 python 应该是您的 Python 3 解释器的名称;在某些系统上,您可能需要使用 python3。尽管使用 python 调用脚本,但 Bazel 将始终使用其自己的独立 Python 解释器和依赖项,只有 build/build.py 脚本本身将由您的系统 Python 解释器处理。默认情况下,wheel 包会写入当前目录的 dist/ 子目录。

  • JAX v.0.4.32 及更高版本:您可以在配置选项中提供自定义 CUDA 和 CUDNN 版本。Bazel 将下载它们并用作目标依赖项。

    要下载特定版本的 CUDA/CUDNN 分发包,您可以使用 --cuda_version--cudnn_version 标志

    python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 \
    --cudnn_version=9.1.1
    

    python build/build.py build --wheels=jax-cuda-pjrt --cuda_version=12.3.2 \
    --cudnn_version=9.1.1
    

    请注意,这些参数是可选的:默认情况下,Bazel 将下载 .bazelrc 中提供的 CUDA 和 CUDNN 分发版本,分别在环境变量 HERMETIC_CUDA_VERSIONHERMETIC_CUDNN_VERSION 中。

    要指向本地文件系统上的 CUDA/CUDNN/NCCL 分发包,您可以使用以下命令

    python build/build.py build --wheels=jax-cuda-plugin \
    --bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \
    --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \
    --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl"
    

    请参阅 XLA 文档中的完整说明列表。

  • JAX v.0.4.32 以前的版本:您必须安装 CUDA 和 CUDNN,并使用配置选项提供它们的路径。

从源代码构建带有修改版 XLA 仓库的 jaxlib。#

JAX 依赖于 XLA,其源代码位于 XLA GitHub 仓库中。默认情况下,JAX 使用 XLA 仓库的一个固定副本,但我们通常希望在处理 JAX 时使用 XLA 的本地修改副本。有两种方法可以做到这一点

  • 使用 Bazel 的 override_repository 功能,您可以将其作为命令行标志传递给 build.py,如下所示

    python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
    
  • 修改 JAX 源代码根目录中的 WORKSPACE 文件以指向不同的 XLA 树。

要将更改贡献回 XLA,请向 XLA 仓库发送 PR。

JAX 锁定的 XLA 版本会定期更新,但特别是在每个 jaxlib 版本发布之前会进行更新。

在 Windows 上从源代码构建 jaxlib 的补充说明#

注意:JAX 不支持 Windows 上的 CUDA;请使用 WSL2 获取 CUDA 支持。

在 Windows 上,请按照 安装 Visual Studio 设置 C++ 工具链。需要 Visual Studio 2019 版本 16.5 或更高版本。

JAX 构建使用符号链接,这需要您激活开发人员模式

您可以使用其Windows 安装程序安装 Python,或者,如果您愿意,可以使用 AnacondaMiniconda 设置 Python 环境。

Bazel 的某些目标使用 bash 工具进行脚本编写,因此需要 MSYS2。有关更多详细信息,请参阅 在 Windows 上安装 Bazel。安装以下软件包

pacman -S patch coreutils

安装 coreutils 后,realpath 命令应该存在于您的 shell 路径中。

一切安装完成后。打开 PowerShell,并确保 MSYS2 在当前会话的路径中。确保可以访问 bazelpatchrealpath。激活 conda 环境。

python .\build\build.py build --wheels=jaxlib

要使用调试信息进行构建,请添加标志 --bazel_options='--copt=/Z7'

为 AMD GPU 构建 ROCM jaxlib 的补充说明#

有关构建支持 ROCm 的 jaxlib 的详细说明,请参阅官方指南:从源代码构建 ROCm JAX

管理独立 Python#

为确保 JAX 的构建可重现、在所有受支持的平台(Linux、Windows、MacOS)上行为一致,并与本地系统的具体情况适当隔离,我们依赖于独立 Python(由 rules_python 提供,详见 Toolchain Registration)来执行所有通过 Bazel 执行的构建和测试命令。这意味着您的系统 Python 安装在构建期间将被忽略,Python 解释器本身以及所有 Python 依赖项将直接由 bazel 管理。

指定 Python 版本#

当您运行 build/build.py 工具时,独立 Python 的版本会自动设置为与您用于运行 build/build.py 脚本的 Python 版本匹配。要显式选择特定版本,您可以将 --python_version 参数传递给该工具

python build/build.py build --python_version=3.12

在底层,独立 Python 版本由 HERMETIC_PYTHON_VERSION 环境变量控制,当您运行 build/build.py 时会自动设置该变量。如果您直接运行 bazel,您可能需要通过以下方法之一显式设置该变量

# Either add an entry to your `.bazelrc` file
build --repo_env=HERMETIC_PYTHON_VERSION=3.12

# OR pass it directly to your specific build command
bazel build <target> --repo_env=HERMETIC_PYTHON_VERSION=3.12

# OR set the environment variable globally in your shell:
export HERMETIC_PYTHON_VERSION=3.12

您可以通过在运行之间简单地切换 --python_version 的值,在同一台机器上按顺序针对不同版本的 Python 运行构建和测试。来自先前构建的所有与 python 无关的构建缓存将保留并重用于后续构建。

指定 Python 依赖项#

在 Bazel 构建期间,所有 JAX 的 Python 依赖项都锁定到其特定版本。这是确保构建可重现性所必需的。JAX 依赖项的完整传递闭包及其相应哈希的锁定版本在 build/requirements_lock_<python version>.txt 文件中指定(例如,build/requirements_lock_3_12.txt 对应 Python 3.12)。

要更新锁定文件,请确保 build/requirements.in 包含所需的直接依赖项列表,然后执行以下命令(它将在底层调用 pip-compile

python build/build.py requirements_update --python_version=3.12

或者,如果您需要更多控制,可以直接运行 bazel 命令(这两个命令是等效的)

bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12

其中 3.12 是您希望更新的 Python 版本。

请注意,由于底层仍使用 pippip-compile 工具,因此这些工具支持的大多数命令行参数和功能也将被 Bazel 依赖项更新命令识别。例如,如果您希望更新程序考虑预发布版本,只需将 --pre 参数传递给 bazel 命令

bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12 -- --pre

指定本地 wheel 依赖项#

默认情况下,构建会扫描仓库根目录中的 dist 目录,以查找要包含在依赖项列表中的任何本地 .whl 文件。如果 wheel 是特定于 Python 版本的,则只会包含与所选 Python 版本匹配的 wheel。

整体本地 wheel 搜索和选择逻辑由 python_init_repositories() 宏(直接从 WORKSPACE 文件调用)的参数控制。您可以使用 local_wheel_dist_folder 更改本地 wheel 文件夹的位置。使用 local_wheel_inclusion_listlocal_wheel_exclusion_list 参数来指定应包含和/或排除哪些 wheel(它支持基本的通配符匹配)。

如有必要,您也可以手动依赖本地 .whl 文件,绕过自动本地 wheel 搜索机制。例如,要依赖您新构建的 jaxlib wheel,您可以在 build/requirements.in 中添加 wheel 的路径,然后为所选的 Python 版本重新运行依赖项更新命令。例如

echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux_2_27_x86_64.whl)" >> build/requirements.in
python build/build.py requirements_update --python_version=3.12

指定每晚构建的 wheel 依赖项#

为了针对最新、可能不稳定的 Python 依赖项集进行构建和测试,我们提供了特殊版本的依赖项更新命令,如下所示

python build/build.py requirements_update --python_version=3.12 --nightly_update

或者,如果您直接运行 bazel(这两个命令是等效的)

bazel run //build:requirements_nightly.update --repo_env=HERMETIC_PYTHON_VERSION=3.12

此更新程序与常规更新程序的区别在于,默认情况下它将接受预发布、开发和每晚构建的软件包,它还将搜索 https://pypi.anaconda.org/scientific-python-nightly-wheels/simple 作为额外的索引 URL,并且不会将哈希放入生成的 requirements 锁定文件中。

自定义独立 Python(高级用法)#

我们开箱即用地支持所有当前版本的 Python,因此除非您的工作流程有非常特殊的要求(例如能够使用您自己的自定义 Python 解释器),否则您可以完全跳过本节。

简而言之,如果您依赖非标准的 Python 工作流程,您仍然可以在独立 Python 设置中实现高度的灵活性。从概念上讲,与非独立情况相比,只有一个区别:您需要以文件的形式思考,而不是安装(即,思考您的构建实际依赖哪些文件,而不是您的系统需要安装哪些文件),其余部分大致相同。

因此,实际上,要完全控制您的 Python 环境(无论是独立还是非独立),您需要能够做以下三件事

  1. 指定要使用的 python 解释器(即选择实际的 pythonpython3 二进制文件以及同一文件夹中随附的库)。

  2. 指定 Python 依赖项列表(例如 numpy)及其实际版本。

  3. 能够轻松添加/删除/更新列表中的依赖项。每个依赖项本身也可以是自定义的(例如,自构建的)。

您已经知道如何在非独立 Python 环境中执行上述所有步骤,以下是在独立环境中执行相同操作的方法(通过以文件的形式思考,而不是安装)

  1. 不是安装 Python,而是以 tarzip 文件的形式获取 Python 解释器。根据您的情况,您可以简单地拉取许多现有解释器之一(例如 python-build-standalone),或者构建您自己的并将其打包到存档中(遵循官方的构建说明即可)。例如,在 Linux 上,它看起来像以下内容

    ./configure --prefix python
    make -j12
    make altinstall
    tar -czpf my_python.tgz python
    

    一旦 tarball 准备好,通过将 HERMETIC_PYTHON_URL 环境变量指向存档(无论是本地存档还是来自互联网)来将其插入构建中

    --repo_env=HERMETIC_PYTHON_URL="file:///local/path/to/my_python.tgz"
    --repo_env=HERMETIC_PYTHON_SHA256=<file's_sha256_sum>
    
    # OR
    --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz"
    --repo_env=HERMETIC_PYTHON_SHA256=<file's_sha256_sum>
    
    # We assume that top-level folder in the tarball is called "python", if it is
    # something different just pass additional HERMETIC_PYTHON_PREFIX parameter
    --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz"
    --repo_env=HERMETIC_PYTHON_SHA256=<file's_sha256_sum>
    --repo_env=HERMETIC_PYTHON_PREFIX="my_python/install"
    
  2. 不是执行 pip install,而是创建 requirements_lock.txt 文件,其中包含您的依赖项的完整传递闭包。您也可以依赖此仓库中已检查的现有依赖项(只要它们与您的自定义 Python 版本兼容)。关于如何操作没有特殊说明,您可以按照本文档中指定 Python 依赖项中推荐的步骤操作,只需直接调用 pip-compile(注意,锁定文件必须是独立的,但如果您愿意,可以始终从非独立 Python 生成它)甚至手动创建它(注意,锁定文件中的哈希是可选的)。

  3. 如果您需要更新或自定义您的依赖项列表,您可以再次按照指定 Python 依赖项的说明更新 requirements_lock.txt,直接调用 pip-compile 或手动修改它。如果您有要使用的自定义包,只需在您的锁定文件中直接指向其 .whl 文件(记住,以文件的形式工作,而不是安装)(注意,requirements.txtrequirements_lock.txt 文件支持本地 wheel 引用)。如果您的 requirements_lock.txt 已在 WORKSPACE 文件中指定为 python_init_repositories() 的依赖项,则无需执行其他任何操作。否则,您可以按如下方式指向您的自定义文件

    --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/custom_requirements_lock.txt"
    

    另请注意,如果您使用 HERMETIC_REQUIREMENTS_LOCK,那么它将完全控制您的依赖项列表,并且指定本地 wheel 依赖项中描述的自动本地 wheel 解析逻辑将被禁用,以免干扰它。

就是这样。总结一下:如果您有一个包含 Python 解释器的存档和一个包含您的依赖项的完整传递闭包的 requirements_lock.txt 文件,那么您就完全控制了您的 Python 环境。

自定义独立 Python 示例#

请注意,对于以下所有示例,您也可以全局设置环境变量(即在 shell 中 export,而不是在命令中使用 --repo_env 参数),这样通过 build/build.py 调用 bazel 就能正常工作。

使用来自互联网的自定义 Python 3.13 进行构建,使用此仓库中已检查的默认 requirements_lock_3_13.txt(即自定义解释器但默认依赖项)

bazel build <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13
  --repo_env=HERMETIC_PYTHON_URL="https://github.com/indygreg/python-build-standalone/releases/download/20241016/cpython-3.13.0+20241016-x86_64-unknown-linux-gnu-install_only.tar.gz"
  --repo_env=HERMETIC_PYTHON_SHA256="2c8cb15c6a2caadaa98af51df6fe78a8155b8471cb3dd7b9836038e0d3657fb4"

使用本地文件系统中的自定义 Python 3.13 和自定义锁定文件进行构建(假设锁定文件在运行命令前已放入此仓库的 jax/build 文件夹中)

bazel test <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13
  --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz"
  --repo_env=HERMETIC_PYTHON_PREFIX="prefix/to/strip/in/cython/tar/gz/archive"
  --repo_env=HERMETIC_PYTHON_SHA256=<sha256_sum>
  --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/build:custom_requirements_lock.txt"

如果默认 python 解释器对您来说足够好,而您只需要一组自定义的依赖项

bazel test <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13
  --repo_env=HERMETIC_REQUIREMENTS_LOCK="/absolute/path/to/build:custom_requirements_lock.txt"

请注意,您可以有多个不同的 requirement_lock.txt 文件,对应于相同的 Python 版本以支持不同的场景。您可以通过指定 HERMETIC_PYTHON_VERSION 来控制选择哪个文件。例如,在 WORKSPACE 文件中

requirements = {
  "3.11": "//build:requirements_lock_3_11.txt",
  "3.12": "//build:requirements_lock_3_12.txt",
  "3.13": "//build:requirements_lock_3_13.txt",
  "3.13-scenario1": "//build:scenario1_requirements_lock_3_13.txt",
  "3.13-scenario2": "//build:scenario2_requirements_lock_3_13.txt",
},

然后您可以在不更改任何环境的情况下构建和测试不同组合的内容

# To build with scenario1 dependencies:
bazel test <target> --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1

# To build with scenario2 dependencies:
bazel test <target> --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario2

# To build with default dependencies:
bazel test <target> --repo_env=HERMETIC_PYTHON_VERSION=3.13

# To build with scenario1 dependencies and custom Python 3.13 interpreter:
bazel test <target>
  --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1
  --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz"
  --repo_env=HERMETIC_PYTHON_SHA256=<sha256_sum>

安装 jax#

安装 jaxlib 后,您可以通过运行以下命令安装 jax

pip install -e .  # installs jax

要升级到 GitHub 上的最新版本,只需从 JAX 仓库根目录运行 git pull,并根据需要通过运行 build.py 或升级 jaxlib 来重新构建。您无需重新安装 jax,因为 pip install -e 会从 site-packages 到仓库设置符号链接。

运行测试#

有两种支持的机制可以运行 JAX 测试,要么使用 Bazel,要么使用 pytest。

使用 Bazel#

首先,使用 --configure_only 标志配置 JAX 构建。对于 CPU 测试,传递 --wheel_list=jaxlib;对于 GPU 测试,传递 CUDA/ROCM

python build/build.py build --wheels=jaxlib --configure_only
python build/build.py build --wheels=jax-cuda-plugin --configure_only
python build/build.py build --wheels=jax-rocm-plugin --configure_only

您可以将其他选项传递给 build.py 以配置构建;详细信息请参阅 jaxlib 构建文档。

默认情况下,Bazel 构建使用从源代码构建的 jaxlib 运行 JAX 测试。要运行 JAX 测试,请运行

bazel test //tests:cpu_tests //tests:backend_independent_tests

如果您有必要的硬件,//tests:gpu_tests//tests:tpu_tests 也可用。

您需要配置 cuda 才能运行 gpu 测试

python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only

要使用预安装的 jaxlib 而不是构建它,您首先需要使其在独立 Python 中可用。要在独立 Python 中安装特定版本的 jaxlib,请运行(以 jaxlib >= 0.4.26 为例)

echo -e "\njaxlib >= 0.4.26" >> build/requirements.in
python build/build.py requirements_update

或者,要从本地 wheel 安装 jaxlib(假设 Python 3.12)

echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux_2_27_x86_64.whl)" >> build/requirements.in
python build/build.py requirements_update --python_version=3.12

一旦您独立安装了 jaxlib,请运行

bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests

许多测试行为可以使用环境变量进行控制(见下文)。环境变量可以通过 Bazel 的 --test_env=FLAG=value 标志传递给 JAX 测试。

JAX 的一些测试是针对多个加速器(即 GPU、TPU)的。当 JAX 已安装时,您可以这样运行 GPU 测试

bazel test //tests:gpu_tests --local_test_jobs=4 --test_tag_filters=multiaccelerator --//jax:build_jaxlib=false --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform

您可以通过在多个加速器上并行运行来加速单个加速器测试。这也会触发每个加速器的多个并发测试。对于 GPU,您可以这样做

NB_GPUS=2
JOBS_PER_ACC=4
J=$((NB_GPUS * JOBS_PER_ACC))
MULTI_GPU="--run_under $PWD/build/parallel_accelerator_execute.sh --test_env=JAX_ACCELERATOR_COUNT=${NB_GPUS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_ACC} --local_test_jobs=$J"
bazel test //tests:gpu_tests //tests:backend_independent_tests --test_env=XLA_PYTHON_CLIENT_PREALLOCATE=false --test_tag_filters=-multiaccelerator $MULTI_GPU

使用 pytest#

首先,通过运行 pip install -r build/test-requirements.txt 安装依赖项。

要使用 pytest 运行所有 JAX 测试,我们建议使用 pytest-xdist,它可以在并行运行测试。它是作为 pip install -r build/test-requirements.txt 命令的一部分安装的。

从仓库根目录运行

pytest -n auto tests

控制测试行为#

JAX 组合地生成测试用例,您可以使用 JAX_NUM_GENERATED_CASES 环境变量控制为每个测试生成和检查的用例数量(默认值为 10)。自动化测试目前默认使用 25。

例如,可以这样写

# Bazel
bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25`

# pytest
JAX_NUM_GENERATED_CASES=25 pytest -n auto tests

自动化测试还会运行默认使用 64 位浮点数和整数的测试(JAX_ENABLE_X64

JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests

您可以使用 pytest 的内置选择机制运行更具体的测试集,或者,您可以直接运行特定的测试文件以查看有关正在运行的用例的更详细信息

JAX_NUM_GENERATED_CASES=5 python tests/lax_numpy_test.py

您可以跳过一些已知较慢的测试,通过传递环境变量 JAX_SKIP_SLOW_TESTS=1。

要从测试文件中指定一组特定的测试来运行,您可以通过 --test_targets 标志传递字符串或正则表达式。例如,您可以使用以下命令运行 jax.numpy.pad 的所有测试

python tests/lax_numpy_test.py --test_targets="testPad"

Colab Notebook 作为文档构建的一部分进行错误测试。

Hypothesis 测试#

一些测试使用 hypothesis。通常,hypothesis 会使用多个示例输入进行测试,并且在测试失败时,它会尝试找到一个更小的、仍然导致失败的示例:查看测试失败信息中类似下面一行,并添加消息中提到的装饰器

You can reproduce this example by temporarily adding @reproduce_failure('6.97.4', b'AXicY2DAAAAAEwAB') as a decorator on your test case

对于交互式开发,您可以设置环境变量 JAX_HYPOTHESIS_PROFILE=interactive(或等效的标志 --jax_hypothesis_profile=interactive),以便将示例数量设置为 1,并跳过示例最小化阶段。

Doctest 测试#

JAX 使用 pytest 在 doctest 模式下测试文档中的代码示例。您可以在 ci-build.yaml 中找到运行 doctest 的最新命令。例如,您可以运行

JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md

此外,JAX 在 doctest-modules 模式下运行 pytest,以确保函数文档字符串中的代码示例能正确运行。您可以在本地运行此命令,例如

JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest --doctest-modules jax/_src/numpy/lax_numpy.py

类型检查#

我们使用 mypy 检查类型提示。要以与 github CI 检查相同的配置运行 mypy,您可以使用 pre-commit 框架

pip install pre-commit
pre-commit run mypy --all-files

由于 mypy 在检查所有文件时可能会很慢,因此只检查您修改过的文件可能会很方便。为此,首先暂存更改(即 git add 更改的文件),然后在提交更改之前运行此命令

pre-commit run mypy

代码风格检查#

JAX 使用 ruff linter 来确保代码质量。要以与 github CI 检查相同的配置运行 ruff,您可以使用 pre-commit 框架

pip install pre-commit
pre-commit run ruff --all-files

更新文档#

要重新构建文档,请安装几个软件包

pip install -r docs/requirements.txt

然后运行

sphinx-build -b html docs docs/build/html -j auto

这可能需要很长时间,因为它会执行文档源中的许多 Notebook;如果您希望在不执行 Notebook 的情况下构建文档,可以运行

sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto

然后您可以在 docs/build/html/index.html 中查看生成的文档。

-j auto 选项控制构建的并行度。您可以用一个数字替换 auto 来控制使用多少个 CPU 核心。

更新 Notebook#

我们使用 jupytextdocs/notebooks 中维护 Notebook 的两个同步副本:一个采用 ipynb 格式,另一个采用 md 格式。前者的优点是它可以直接在 Colab 中打开和执行;后者的优点是它使版本控制中的差异跟踪变得容易得多。

编辑 ipynb 文件#

对于进行大幅度修改代码和输出的更改,在 Jupyter 或 Colab 中编辑 Notebook 最容易。要在 Colab 界面中编辑 Notebook,请打开 http://colab.research.google.com 并从本地仓库 Upload。根据需要更新,Run all cells,然后 Download ipynb。您可能需要测试它是否正确执行,如上文所述使用 sphinx-build

编辑 md 文件#

对于对 Notebook 文本内容进行较小的更改,使用文本编辑器编辑 .md 版本最容易。

同步 Notebook#

编辑完 Notebook 的 ipynb 或 md 版本后,您可以使用 jupytext 同步这两个版本,方法是在更新的 Notebook 上运行 jupytext --sync;例如

pip install jupytext==1.16.4
jupytext --sync docs/notebooks/thinking_in_jax.ipynb

jupytext 版本应与 .pre-commit-config.yaml 中指定的版本匹配。

要检查 Markdown 和 ipynb 文件是否正确同步,您可以使用 pre-commit 框架执行与 github CI 相同的检查

pip install pre-commit
pre-commit run jupytext --all-files

创建新 Notebook#

如果您要向文档添加新的 Notebook,并且希望使用此处讨论的 jupytext --sync 命令,您可以使用以下命令为 Notebook 设置 jupytext

jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb

此操作通过向 Notebook 文件添加一个 "jupytext" 元数据字段来实现,该字段指定了所需的格式,并且在调用时 jupytext --sync 命令会识别此字段。

Sphinx 构建中的 Notebook#

一些 Notebook 作为预提交检查和 Read the docs 构建的一部分自动构建。如果单元格引发错误,构建将失败。如果错误是故意的,您可以捕获它们,或者使用 raises-exceptions 元数据标记单元格(示例 PR)。您必须在 .ipynb 文件中手动添加此元数据。当其他人重新保存 Notebook 时,它将被保留。

我们排除了某些 Notebook 的构建,例如,因为它们包含长时间的计算。请参阅 conf.py 中的 exclude_patterns

readthedocs.io 上的文档构建#

JAX 自动生成的文档位于 https://jax.net.cn/

文档构建由 readthedocs JAX 设置控制整个项目。当前设置在代码推送到 GitHub main 分支后立即触发文档构建。对于每个代码版本,构建过程由 .readthedocs.ymldocs/conf.py 配置文件驱动。

对于每次自动化文档构建,您都可以查看文档构建日志

如果您想在 Readthedocs 上测试文档生成,可以将代码推送到 test-docs 分支。该分支也会自动构建,您可以在此处查看生成的文档。如果文档构建失败,您可能需要清除 test-docs 的构建环境

为了进行本地测试,我可以在一个新目录中通过重放我在 Readthedocs 日志中看到的命令来完成

mkvirtualenv jax-docs  # A new virtualenv
mkdir jax-docs  # A new directory
cd jax-docs
git clone --no-single-branch --depth 50 https://github.com/jax-ml/jax
cd jax
git checkout --force origin/test-docs
git clean -d -f -f
workon jax-docs

python -m pip install --upgrade --no-cache-dir pip
python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1'
python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html