配置选项

目录

配置选项#

JAX 提供了各种配置选项来自定义其行为。这些选项控制从数值精度到调试功能的方方面面。

如何使用配置选项#

JAX 配置选项可以通过多种方式设置:

  1. 环境变量(在运行程序前设置)

    export JAX_ENABLE_X64=True
    python my_program.py
    
  2. 运行时配置(在 Python 代码中)

    import jax
    jax.config.update("jax_enable_x64", True)
    
  3. 命令行标志(使用 Abseil)

    # In your code:
    import jax
    jax.config.parse_flags_with_absl()
    
    # When running:
    python my_program.py --jax_enable_x64=True
    

常用配置选项#

以下是一些最常用的配置选项:

  • jax_enable_x64 – 启用 64 位浮点精度

  • jax_disable_jit – 禁用 JIT 编译以进行调试

  • jax_debug_nans – 检查 NaN 并引发错误

  • jax_platforms – 控制 JAX 将初始化哪些后端(CPU/GPU/TPU)

  • jax_numpy_rank_promotion – 控制自动秩提升行为

  • jax_default_matmul_precision – 设置矩阵乘法运算的默认精度

所有配置选项#

以下是所有可用 JAX 配置选项的完整列表:

允许 Ragged Dot TPU 显式转置#

类型:

bool

默认值:

True

配置字符串:

'allow_ragged_dot_tpu_explicit_transpose'

环境变量:

ALLOW_RAGGED_DOT_TPU_EXPLICIT_TRANSPOSE

是否在 TPU 上为 ragged_dot 使用 Python 显式转置。

检查 Vma#

类型:

bool

默认值:

False

配置字符串:

'check_vma'

环境变量:

CHECK_VMA

shard_map 的内部实现细节,请勿使用

即时常量折叠#

类型:

bool

默认值:

False

配置字符串:

'eager_constant_folding'

环境变量:

EAGER_CONSTANT_FOLDING

尝试在暂存期间进行常量折叠。

Jax2Tf 关联扫描归约#

类型:

bool

默认值:

False

配置字符串:

'jax2tf_associative_scan_reductions'

环境变量:

JAX2TF_ASSOCIATIVE_SCAN_REDUCTIONS

JAX 对累积归约原语(cumsum, cumprod, cummax, cummin)有两种独立的降级规则。在 CPU 和 GPU 上使用 lax.associative_scan,而对于 TPU 则使用 HLO ReduceWindow。后者在 CPU 和 GPU 上实现较慢。默认情况下,jax2tf 使用 TPU 降级。将此标志设为 True 以使用关联扫描降级,且仅在它对您的应用程序有影响时使用。有关更多详细信息,请参阅 jax2tf 的 README.md。

Jax2Tf 默认原生序列化#

类型:

bool

默认值:

True

配置字符串:

'jax2tf_default_native_serialization'

环境变量:

JAX2TF_DEFAULT_NATIVE_SERIALIZATION

设置 jax2tf.convert 中 native_serialization 参数的默认值。建议直接使用参数而不是该标志,该标志将来可能会被移除。从 JAX 0.4.31 开始,非原生序列化已被弃用。

允许 F16 归约#

类型:

bool

默认值:

True

配置字符串:

'jax_allow_f16_reductions'

环境变量:

JAX_ALLOW_F16_REDUCTIONS

如果为 False,对 f16bf16 输入进行 reduce_sum 将引发错误。默认为 True。

数组垃圾回收防护#

类型:

枚举值: 'allow', 'log', 'fatal'

默认值:

配置字符串:

'jax_array_garbage_collection_guard'

环境变量:

JAX_ARRAY_GARBAGE_COLLECTION_GUARD

jax.Array 对象选择垃圾回收防护级别。

此选项可用于控制当 jax.Array 对象被垃圾回收时发生的情况。理想情况下,jax.Array 对象应由 Python 引用计数释放,而不是通过垃圾回收,以避免数组占用设备内存直到垃圾回收发生。

有效值为:

  • allow: 不记录 jax.Array 对象的垃圾回收。

  • log: 当 jax.Array 被垃圾回收时记录错误。

  • fatal: 如果 jax.Array 被垃圾回收,则引发致命错误。

默认为 allow。注意,并非所有循环引用都能被检测到。

后端目标#

类型:

str

默认值:

''

配置字符串:

'jax_backend_target'

环境变量:

JAX_BACKEND_TARGET

可以是 “local” 或 “rpc:address”,用于连接远程服务目标。

Bcoo Cusparse 降级#

类型:

bool

默认值:

False

配置字符串:

'jax_bcoo_cusparse_lowering'

环境变量:

JAX_BCOO_CUSPARSE_LOWERING

启用 BCOO 算子到 cuSparse 的降级。

捕获常量报告帧数#

类型:

int

默认值:

0

配置字符串:

'jax_captured_constants_report_frames'

环境变量:

JAX_CAPTURED_CONSTANTS_REPORT_FRAMES

报告每个捕获常量时显示的堆栈帧数,以指示常量在何处被捕获(文件及操作)。设为 -1 打印所有帧,设为 0 禁用。注意:仅当捕获常量的总数超过 jax_captured_constants_warn_bytes 时才会生成报告,因为生成报告开销较大。

捕获常量警告字节数#

类型:

int

默认值:

2000000000

配置字符串:

'jax_captured_constants_warn_bytes'

环境变量:

JAX_CAPTURED_CONSTANTS_WARN_BYTES

在发出警告之前,参数可以作为常量捕获的字节数。默认为大约 2GB。设为 -1 禁用警告。

检查代理环境变量#

类型:

bool

默认值:

True

配置字符串:

'jax_check_proxy_envs'

环境变量:

JAX_CHECK_PROXY_ENVS

检查用户环境中的代理变量并发出警告。

检查静态索引#

类型:

bool

默认值:

False

配置字符串:

'jax_check_static_indices'

环境变量:

JAX_CHECK_STATIC_INDICES

在数组索引运算期间开启静态索引的边界检查。这仅在索引模式为 PROMISE_IN_BOUNDS 时生效,这是 gather 类型操作的默认模式。

检查追踪器泄漏#

类型:

bool

默认值:

False

配置字符串:

'jax_check_tracer_leaks'

环境变量:

JAX_CHECK_TRACER_LEAKS

在跟踪完成时立即开始检查跟踪器泄漏。启用泄漏检查可能会影响性能:一些缓存将被禁用,并可能增加其他开销。此外,请注意,某些 Python 调试器可能导致误报,因此建议在启用泄漏检查时禁用任何调试器。

编译缓存检查内容#

类型:

bool

默认值:

False

配置字符串:

'jax_compilation_cache_check_contents'

环境变量:

JAX_COMPILATION_CACHE_CHECK_CONTENTS

启用编译缓存时,检查磁盘缓存中找到的值是否与全新编译的结果匹配。此检查仅在进程中首次遇到该键时执行。

编译缓存目录#

类型:

str

默认值:

配置字符串:

'jax_compilation_cache_dir'

环境变量:

JAX_COMPILATION_CACHE_DIR

缓存路径。优先级:1. 调用 compilation_cache.set_cache_dir()。2. 命令行设置的值或默认值。

编译缓存期望 PGLE#

类型:

bool

默认值:

False

配置字符串:

'jax_compilation_cache_expect_pgle'

环境变量:

JAX_COMPILATION_CACHE_EXPECT_PGLE

如果设为 True,即使当前未启用 PGLE,也会优先加载使用分析数据编译的编译缓存条目(即启用了 PGLE 并进行了指定次数的分析执行)。当未找到首选缓存条目时,将打印警告。

编译缓存键包含元数据#

类型:

bool

默认值:

False

配置字符串:

'jax_compilation_cache_include_metadata_in_key'

环境变量:

JAX_COMPILATION_CACHE_INCLUDE_METADATA_IN_KEY

在编译缓存键中包含元数据(如文件名和行号)。如果为 false,即使函数或文件被移动,缓存仍能命中。但这意味着从缓存加载的可执行文件可能具有陈旧的元数据,这可能会出现在分析(profiles)等中。

编译缓存最大尺寸#

类型:

int

默认值:

-1

配置字符串:

'jax_compilation_cache_max_size'

环境变量:

JAX_COMPILATION_CACHE_MAX_SIZE

持久编译缓存允许的最大字节数。设置后,当缓存目录总大小超过限制时,将删除最近最少访问的缓存条目。若设为 0,缓存将被禁用。特殊值 -1 表示无限制,缓存大小可无限增长。

编译器详细日志最小算子数#

类型:

int

默认值:

10

配置字符串:

'jax_compiler_detailed_logging_min_ops'

环境变量:

JAX_COMPILER_DETAILED_LOGGING_MIN_OPS

在 JAX 启用详细编译器日志记录之前,模块应有多大(以 MLIR 运算计)?此标志的目的是抑制小型/无趣计算的详细日志记录。

编译器启用重算(Remat)通道#

类型:

bool

默认值:

True

配置字符串:

'jax_compiler_enable_remat_pass'

环境变量:

JAX_COMPILER_ENABLE_REMAT_PASS

启用/禁用重算 HLO 通道的配置。在遇到 OOM 错误时,有助于让 XLA 自动权衡内存和计算。然而,您可能通过手动使用 jax.checkpoint 获得更好的结果。

CPU 集合通信实现#

类型:

枚举值: 'gloo', 'mpi', 'megascale'

默认值:

'gloo'

配置字符串:

'jax_cpu_collectives_implementation'

环境变量:

JAX_CPU_COLLECTIVES_IMPLEMENTATION

CPU 上使用的跨进程集合通信实现。必须是 (“gloo”, “mpi”) 之一。

CPU 启用异步调度#

类型:

bool

默认值:

True

配置字符串:

'jax_cpu_enable_async_dispatch'

环境变量:

JAX_CPU_ENABLE_ASYNC_DISPATCH

仅适用于非并行计算。如果为 False,则不使用异步调度,直接运行计算。

CPU 获取全局拓扑超时分钟数#

类型:

int

默认值:

5

配置字符串:

'jax_cpu_get_global_topology_timeout_minutes'

环境变量:

JAX_CPU_GET_GLOBAL_TOPOLOGY_TIMEOUT_MINUTES

获取 CPU 设备全局拓扑的超时分钟数;必须严格大于 –jax_cpu_get_local_topology_timeout_minutes

CPU 获取本地拓扑超时分钟数#

类型:

int

默认值:

2

配置字符串:

'jax_cpu_get_local_topology_timeout_minutes'

环境变量:

JAX_CPU_GET_LOCAL_TOPOLOGY_TIMEOUT_MINUTES

在构建全局拓扑时获取每个 CPU 设备本地拓扑的超时分钟数。

跨主机传输 Socket 地址#

类型:

str

默认值:

''

配置字符串:

'jax_cross_host_transfer_socket_address'

环境变量:

JAX_CROSS_HOST_TRANSFER_SOCKET_ADDRESS

通过 DCN 进行跨主机设备传输时使用的 Socket 地址。仅在 PjRt 插件不支持跨主机传输时才需要。

跨主机传输超时秒数#

类型:

int

默认值:

配置字符串:

'jax_cross_host_transfer_timeout_seconds'

环境变量:

JAX_CROSS_HOST_TRANSFER_TIMEOUT_SECONDS

通过 KV 存储进行跨主机传输元数据交换的超时时间。默认值为一分钟。

跨主机传输传输大小#

类型:

int

默认值:

配置字符串:

'jax_cross_host_transfer_transfer_size'

环境变量:

JAX_CROSS_HOST_TRANSFER_TRANSFER_SIZE

分块传输请求的分块大小。

跨主机传输地址#

类型:

str

默认值:

''

配置字符串:

'jax_cross_host_transport_addresses'

环境变量:

JAX_CROSS_HOST_TRANSPORT_ADDRESSES

用于通过 DCN 进行跨主机设备传输的以逗号分隔的传输地址列表。如果未设置,默认为 [0.0.0.0:0] * 4。

CUDA 可见设备#

类型:

str

默认值:

'all'

配置字符串:

'jax_cuda_visible_devices'

环境变量:

JAX_CUDA_VISIBLE_DEVICES

限制 JAX 将使用的 CUDA 设备集合。可以是 “all”,或者是逗号分隔的整数设备 ID 列表。

自定义 Vjp3#

类型:

bool

默认值:

False

配置字符串:

'jax_custom_vjp3'

环境变量:

JAX_CUSTOM_VJP3

如果为 True,拥抱自定义自动微分规则的未来。这将在未来的 JAX 版本中默认启用,届时所有对该标志的使用都将被视为已弃用(遵循 API 兼容性策略)。

调试 Inf#

类型:

bool

默认值:

False

配置字符串:

'jax_debug_infs'

环境变量:

JAX_DEBUG_INFS

为每个操作添加 Inf 检查。当在 JIT 编译计算的输出中检测到 Inf 时,调用非编译版本,以便更精确地识别生成该 Inf 的操作。

调试密钥重用#

类型:

bool

默认值:

False

配置字符串:

'jax_debug_key_reuse'

环境变量:

JAX_DEBUG_KEY_REUSE

开启实验性密钥重用检查。启用此配置后,将跟踪类型化 PRNG 密钥(即使用 jax.random.key() 创建的密钥)的使用情况,对已使用密钥的不正确重用将导致错误。目前启用此功能会导致每次调用带有密钥作为输入或输出的 JIT 编译函数时产生轻微的 Python 开销。

调试日志模块#

类型:

str

默认值:

''

配置字符串:

'jax_debug_log_modules'

环境变量:

JAX_DEBUG_LOG_MODULES

以逗号分隔的模块名称列表(例如 “jax” 或 “jax._src.xla_bridge,jax._src.dispatch”),用于启用调试日志记录。

调试 NaN#

类型:

bool

默认值:

False

配置字符串:

'jax_debug_nans'

环境变量:

JAX_DEBUG_NANS

为每个操作添加 NaN 检查。当在 JIT 编译计算的输出中检测到 NaN 时,调用非编译版本,以便更精确地识别生成该 NaN 的操作。

默认设备#

类型:

str

默认值:

配置字符串:

'jax_default_device'

环境变量:

JAX_DEFAULT_DEVICE

配置 JAX 操作的默认设备。设置为 Device 对象(例如 jax.devices("cpu")[0])以将该设备用作 JAX 操作和 JIT 编译函数调用的默认设备(对多设备计算,如 pmapped 函数调用,无影响)。设置为 None 则使用系统默认设备。

默认矩阵乘法精度#

类型:

枚举值: 'default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32', 'ANY_F8_ANY_F8_F32', 'ANY_F8_ANY_F8_F32_FAST_ACCUM', 'ANY_F8_ANY_F8_ANY', 'ANY_F8_ANY_F8_ANY_FAST_ACCUM', 'F16_F16_F16', 'F16_F16_F32', 'BF16_BF16_BF16', 'BF16_BF16_F32', 'BF16_BF16_F32_X3', 'BF16_BF16_F32_X6', 'BF16_BF16_F32_X9', 'TF32_TF32_F32', 'TF32_TF32_F32_X3', 'F32_F32_F32', 'F64_F64_F64'

默认值:

配置字符串:

'jax_default_matmul_precision'

环境变量:

JAX_DEFAULT_MATMUL_PRECISION

控制 32 位输入的默认矩阵乘法和卷积精度。

某些平台(如 TPU)为矩阵乘法和卷积计算提供可配置的精度级别,以速度换取精度。精度可以针对每个操作进行控制;例如,参见 jax.lax.conv_general_dilated()jax.lax.dot() 的文档字符串。但对于未指定特定精度的操作,控制其默认行为非常有用。

此选项可用于控制 32 位输入上涉及矩阵乘法和卷积的计算的默认精度级别。这些级别大致描述了标量积的计算精度。‘bfloat16’ 选项最快且精度最低;‘float32’ 类似于完全 float32 精度;‘tensorfloat32’ 则介于两者之间。

此参数还可用于指定执行矩阵乘法(如 jax.lax.dot())的函数的累积“算法”。要指定算法,请将此选项设置为 DotAlgorithmPreset 的名称。

默认伪随机数生成器实现#

类型:

枚举值: 'threefry2x32', 'rbg', 'unsafe_rbg'

默认值:

'threefry2x32'

配置字符串:

'jax_default_prng_impl'

环境变量:

JAX_DEFAULT_PRNG_IMPL

选择默认的 PRNG 实现,当在生成随机种子时未明确提供实现时使用。

禁用反向传播检查#

类型:

bool

默认值:

False

配置字符串:

'jax_disable_bwd_checks'

环境变量:

JAX_DISABLE_BWD_CHECKS

禁用所有反向传播检查。这将在未来的 JAX 版本中默认启用,届时所有对该标志的使用都将被视为已弃用(遵循 API 兼容性策略)。

禁用 JIT#

类型:

bool

默认值:

False

配置字符串:

'jax_disable_jit'

环境变量:

JAX_DISABLE_JIT

禁用 JIT 编译并直接调用原始 Python 代码。

禁用大部分优化#

类型:

bool

默认值:

False

配置字符串:

'jax_disable_most_optimizations'

环境变量:

JAX_DISABLE_MOST_OPTIMIZATIONS

禁用 Vmap Shmap 错误#

类型:

bool

默认值:

False

配置字符串:

'jax_disable_vmap_shmap_error'

环境变量:

JAX_DISABLE_VMAP_SHMAP_ERROR

禁用 vmap-of-shmap 中错误检查的临时补救措施。

禁止 Mesh 上下文管理器#

类型:

bool

默认值:

False

配置字符串:

'jax_disallow_mesh_context_manager'

环境变量:

JAX_DISALLOW_MESH_CONTEXT_MANAGER

如果设为 True,尝试将 Mesh 用作上下文管理器将导致 RuntimeError。

分布式调试#

类型:

bool

默认值:

False

配置字符串:

'jax_distributed_debug'

环境变量:

JAX_DISTRIBUTED_DEBUG

启用对调试多进程分布式计算有用的日志记录。日志通过 logging 在 WARNING 级别执行。

导出 IR 模式#

类型:

str

默认值:

'stablehlo'

配置字符串:

'jax_dump_ir_modes'

环境变量:

JAX_DUMP_IR_MODES

以逗号分隔的模式,用于转储 IR。可以是 ‘stablehlo’(默认)、‘jaxpr’ 或用于 jaxpr 方程计数 pprof 配置文件的 ‘eqn_count_pprof’。

导出 IR 路径#

类型:

str

默认值:

''

配置字符串:

'jax_dump_ir_to'

环境变量:

JAX_DUMP_IR_TO

JAX 发出的 IR 应转储为文本文件的路径。如果省略,JAX 将不会转储任何 IR。支持特殊值 ‘sponge’,从环境变量 TEST_UNDECLARED_OUTPUTS_DIR 获取路径。有关转储内容的选项,请参见 jax_dump_ir_modes。

启用检查#

类型:

bool

默认值:

False

配置字符串:

'jax_enable_checks'

环境变量:

JAX_ENABLE_CHECKS

开启 JAX 内部的不变量检查。会使程序变慢。

启用编译缓存#

类型:

bool

默认值:

True

配置字符串:

'jax_enable_compilation_cache'

环境变量:

JAX_ENABLE_COMPILATION_CACHE

如果设为 False,无论是否调用 set_cache_dir(),编译缓存都将被禁用。如果设为 True,路径可以设置为默认值或通过调用 set_cache_dir() 设置。

启用自定义伪随机数生成器#

类型:

bool

默认值:

False

配置字符串:

'jax_enable_custom_prng'

环境变量:

JAX_ENABLE_CUSTOM_PRNG

启用一项内部升级,允许定义自定义伪随机数生成器实现。这将在未来的 JAX 版本中默认启用,届时所有对该标志的使用都将被视为已弃用(遵循 API 兼容性策略)。

启用 PGLE#

类型:

bool

默认值:

False

配置字符串:

'jax_enable_pgle'

环境变量:

JAX_ENABLE_PGLE

如果设为 True 且属性 jax_pgle_profiling_runs 设为大于 0,则在运行指定次数后,将使用收集到的数据提供给分析引导的延迟估算器,并重新编译模块。

启用抢占服务#

类型:

bool

默认值:

True

配置字符串:

'jax_enable_preemption_service'

环境变量:

JAX_ENABLE_PREEMPTION_SERVICE

启用抢占服务。有关详细信息,请参阅 multihost_utils.reached_preemption_sync_point。

启用可恢复性#

类型:

bool

默认值:

False

配置字符串:

'jax_enable_recoverability'

环境变量:

JAX_ENABLE_RECOVERABILITY

允许在部分任务失败后,多控制器 JAX 作业继续运行。

启用 X64#

类型:

bool

默认值:

False

配置字符串:

'jax_enable_x64'

环境变量:

JAX_ENABLE_X64

启用 64 位类型的使用。

除法错误检查行为#

类型:

枚举值: 'ignore', 'raise'

默认值:

'ignore'

配置字符串:

'jax_error_checking_behavior_divide'

环境变量:

JAX_ERROR_CHECKING_BEHAVIOR_DIVIDE

指定遇到除以零时的行为。选项为 “ignore” 或 “raise”。

NaN 错误检查行为#

类型:

枚举值: 'ignore', 'raise'

默认值:

'ignore'

配置字符串:

'jax_error_checking_behavior_nan'

环境变量:

JAX_ERROR_CHECKING_BEHAVIOR_NAN

指定遇到 NaN 时的行为。选项为 “ignore” 或 “raise”。

越界错误检查行为#

类型:

枚举值: 'ignore', 'raise'

默认值:

'ignore'

配置字符串:

'jax_error_checking_behavior_oob'

环境变量:

JAX_ERROR_CHECKING_BEHAVIOR_OOB

指定遇到越界访问时的行为。选项为 “ignore” 或 “raise”。

执行时间优化力度#

类型:

float

默认值:

0.0

配置字符串:

'jax_exec_time_optimization_effort'

环境变量:

JAX_EXEC_TIME_OPTIMIZATION_EFFORT

最小化执行时间的力度(越高表示力度越大),有效范围 [-1.0, 1.0]。

实验性不安全 XLA 运行时错误#

类型:

bool

默认值:

False

配置字符串:

'jax_experimental_unsafe_xla_runtime_errors'

环境变量:

JAX_EXPERIMENTAL_UNSAFE_XLA_RUNTIME_ERRORS

在 CPU 和 GPU 上为 jax.experimental.checkify.checks 启用 XLA 运行时错误。这些错误是异步的,可能会丢失且不易读。但是,它们会使计算崩溃,并允许您编写可 JIT 的检查,而无需使用 checkify。在 pmap/pjit 下不工作。

解释缓存未命中#

类型:

bool

默认值:

False

配置字符串:

'jax_explain_cache_misses'

环境变量:

JAX_EXPLAIN_CACHE_MISSES

每当主要缓存(例如追踪缓存)发生未命中时,记录解释。日志记录通过 logging 执行。设置此选项时,日志级别为 WARNING;否则级别为 DEBUG。

显式 X64 数据类型#

类型:

枚举值: WARN, ERROR, ALLOW

默认值:

<ExplicitX64Mode.WARN: 1>

配置字符串:

'jax_explicit_x64_dtypes'

环境变量:

JAX_EXPLICIT_X64_DTYPES

如果设为 ALLOW,即使 enable_x64 为 false,也会尊重显式指定的 64 位类型。如果设为 WARN,将发出警告;如果设为 ERROR,将引发错误。

导出调用约定版本#

类型:

int

默认值:

10

配置字符串:

'jax_export_calling_convention_version'

环境变量:

JAX_EXPORT_CALLING_CONVENTION_VERSION

用于导出的调用约定版本号。这必须在您的部署环境中使用的 tf.XlaCallModule 所支持的版本范围内。参见 https://jax.net.cn/en/latest/export/shape_poly.html#calling-convention-versions

导出忽略前向兼容性#

类型:

bool

默认值:

False

配置字符串:

'jax_export_ignore_forward_compatibility'

环境变量:

JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY

是否忽略前向兼容性降级规则。参见 https://jax.net.cn/en/latest/export/export.html#compatibility-guarantees-for-custom-calls

强制 DCN 跨主机传输#

类型:

bool

默认值:

False

配置字符串:

'jax_force_dcn_cross_host_transfers'

环境变量:

JAX_FORCE_DCN_CROSS_HOST_TRANSFERS

即使插件支持跨主机传输,也强制跨主机传输使用 DCN socket 传输库。

高动态范围 Gumbel#

类型:

bool

默认值:

False

配置字符串:

'jax_high_dynamic_range_gumbel'

环境变量:

JAX_HIGH_DYNAMIC_RANGE_GUMBEL

如果为 True,Gumbel 噪声会抽取两个样本,以更高的精度覆盖低概率事件。

Hlo 源文件规范化正则表达式#

类型:

str

默认值:

配置字符串:

'jax_hlo_source_file_canonicalization_regex'

环境变量:

JAX_HLO_SOURCE_FILE_CANONICALIZATION_REGEX

用于通过移除给定的正则表达式来规范化 HLO 指令的 source_path 元数据。如果设置,将在每个 source_file 上调用 re.sub(),所有匹配项都将被移除。这可用于在使用持久编译缓存时避免虚假的缓存未命中,因为缓存键中包含了 HLO 元数据。

转储中包含调试信息#

类型:

bool

默认值:

True

配置字符串:

'jax_include_debug_info_in_dumps'

环境变量:

JAX_INCLUDE_DEBUG_INFO_IN_DUMPS

确定在转储 IR 代码时是否保留调试符号和位置信息。默认情况下,调试信息将保留在 IR 转储中。为避免暴露源代码和潜在的敏感信息,请设为 false。

位置信息中包含完整回溯#

类型:

bool

默认值:

True

配置字符串:

'jax_include_full_tracebacks_in_locations'

环境变量:

JAX_INCLUDE_FULL_TRACEBACKS_IN_LOCATIONS

在 JAX 发出的 IR 的 MLIR 位置中包含 Python 回溯。

遗留伪随机数密钥#

类型:

枚举值: ALLOW, WARN, ERROR

默认值:

<LegacyPrngKeyState.ALLOW: 'allow'>

配置字符串:

'jax_legacy_prng_key'

环境变量:

JAX_LEGACY_PRNG_KEY

指定将原始 PRNG 密钥传递给 jax.random API 时的行为。

记录检查点残差#

类型:

bool

默认值:

False

配置字符串:

'jax_log_checkpoint_residuals'

环境变量:

JAX_LOG_CHECKPOINT_RESIDUALS

每次部分评估 jax.checkpoint(即 jax.remat,例如在自动微分中)时记录一条消息,打印保存了哪些残差。

记录编译过程#

类型:

bool

默认值:

False

配置字符串:

'jax_log_compiles'

环境变量:

JAX_LOG_COMPILES

每次 jitpmap 编译 XLA 计算时记录一条消息。日志通过 logging 执行。设置此选项时,日志级别为 WARNING;否则级别为 DEBUG。

日志级别#

类型:

枚举值: 'NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'

默认值:

'NOTSET'

配置字符串:

'jax_logging_level'

环境变量:

JAX_LOGGING_LEVEL

在所有 JAX 日志记录器上设置相应的日志级别。仅接受 [“NOTSET”, “DEBUG”, “INFO”, “WARNING”, “ERROR”, “CRITICAL”] 中的字符串值。如果为 None,则不会设置日志级别。包含 C++ 日志记录。

内存拟合力度#

类型:

float

默认值:

0.0

配置字符串:

'jax_memory_fitting_effort'

环境变量:

JAX_MEMORY_FITTING_EFFORT

最小化内存使用的力度(越高表示力度越大),有效范围 [-1.0, 1.0]。

内存拟合级别#

类型:

枚举值: 'UNKNOWN', 'O0', 'O1', 'O2', 'O3'

默认值:

'O2'

配置字符串:

'jax_memory_fitting_level'

环境变量:

JAX_MEMORY_FITTING_LEVEL

编译器应尝试使程序适配内存的程度。

模拟 GPU 拓扑#

类型:

str

默认值:

''

配置字符串:

'jax_mock_gpu_topology'

环境变量:

JAX_MOCK_GPU_TOPOLOGY

在 GPU 客户端中模拟多主机 GPU 拓扑。该值应采用 “<切片数> x <每个切片的宿主机数> x <每个宿主机的设备数>” 的形式。空字符串关闭模拟。

Mosaic 允许 HLO#

类型:

bool

默认值:

False

配置字符串:

'jax_mosaic_allow_hlo'

环境变量:

JAX_MOSAIC_ALLOW_HLO

允许在 Mosaic 中使用 HLO 方言。

可变数组检查#

类型:

bool

默认值:

True

配置字符串:

'jax_mutable_array_checks'

环境变量:

JAX_MUTABLE_ARRAY_CHECKS

启用排除别名的可变数组错误检查。这将在未来的 JAX 版本中默认启用,届时所有对该标志的使用都将被视为已弃用(遵循 API 兼容性策略)。

禁止执行#

类型:

bool

默认值:

False

配置字符串:

'jax_no_execution'

环境变量:

JAX_NO_EXECUTION

禁止 JAX 执行。

禁止追踪#

类型:

bool

默认值:

False

配置字符串:

'jax_no_tracing'

环境变量:

JAX_NO_TRACING

禁止 JIT 编译的追踪。

CPU 设备数量#

类型:

int

默认值:

-1

配置字符串:

'jax_num_cpu_devices'

环境变量:

JAX_NUM_CPU_DEVICES

使用的 CPU 设备数量。如果未提供,将使用 XLA 标志 –xla_force_host_platform_device_count 的值。必须在 JAX 初始化前设置。

NumPy 数据类型提升#

类型:

枚举值: STANDARD, STRICT

默认值:

<NumpyDtypePromotion.STANDARD: 'standard'>

配置字符串:

'jax_numpy_dtype_promotion'

环境变量:

JAX_NUMPY_DTYPE_PROMOTION

指定在数组之间运算时隐式类型提升所使用的规则。选项为 “standard” 或 “strict”;在 strict 模式下,不同强指定数据类型的数组之间的二元运算将导致错误。

NumPy 秩提升#

类型:

枚举值: 'allow', 'warn', 'raise'

默认值:

'allow'

配置字符串:

'jax_numpy_rank_promotion'

环境变量:

JAX_NUMPY_RANK_PROMOTION

控制 NumPy 风格的自动秩提升广播(“allow”、“warn” 或 “raise”)。

优化级别#

类型:

枚举值: 'UNKNOWN', 'O0', 'O1', 'O2', 'O3'

默认值:

'UNKNOWN'

配置字符串:

'jax_optimization_level'

环境变量:

JAX_OPTIMIZATION_LEVEL

编译器应针对执行时间优化的程度。

Pallas 启用调试检查#

类型:

bool

默认值:

False

配置字符串:

'jax_pallas_enable_debug_checks'

环境变量:

JAX_PALLAS_ENABLE_DEBUG_CHECKS

如果设置,pl.debug_check 调用将在运行时被检查。否则,它们为 noop(空操作)。

Pallas 使用 Mosaic GPU#

类型:

bool

默认值:

True

配置字符串:

'jax_pallas_use_mosaic_gpu'

环境变量:

JAX_PALLAS_USE_MOSAIC_GPU

如果为 True,将 Pallas 内核降级到实验性的 Mosaic GPU 方言,而不是 Triton IR。

Pallas 详细错误信息#

类型:

bool

默认值:

False

配置字符串:

'jax_pallas_verbose_errors'

环境变量:

JAX_PALLAS_VERBOSE_ERRORS

如果为 True,打印 Pallas 内核的详细错误信息。

持久缓存启用 XLA 缓存#

类型:

str

默认值:

'xla_gpu_per_fusion_autotune_cache_dir'

配置字符串:

'jax_persistent_cache_enable_xla_caches'

环境变量:

JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES

启用持久缓存时,将自动启用额外的 XLA 缓存。此选项可用于配置启用哪些 XLA 缓存方法。

持久缓存最小编译时间秒数#

类型:

float

默认值:

1.0

配置字符串:

'jax_persistent_cache_min_compile_time_secs'

环境变量:

JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS

计算写入持久编译缓存的最小编译时间。可以提高此阈值以减少写入缓存的条目数量。

持久缓存最小条目大小字节数#

类型:

int

默认值:

0

配置字符串:

'jax_persistent_cache_min_entry_size_bytes'

环境变量:

JAX_PERSISTENT_CACHE_MIN_ENTRY_SIZE_BYTES

将缓存在持久编译缓存中的条目的最小大小(以字节为单位): * -1: 禁用大小限制并防止覆盖。 * 保持默认 (0) 以允许覆盖。覆盖通常会确保最小大小对正在使用的文件系统是最优的。 * > 0: 所需的实际最小大小;不覆盖。

PGLE 聚合百分位数#

类型:

int

默认值:

90

配置字符串:

'jax_pgle_aggregation_percentile'

环境变量:

JAX_PGLE_AGGREGATION_PERCENTILE

使用 PGLE 时,在设备间聚合性能数据所使用的百分位数。

PGLE 分析运行次数#

类型:

int

默认值:

3

配置字符串:

'jax_pgle_profiling_runs'

环境变量:

JAX_PGLE_PROFILING_RUNS

使用 PGLE 时,重新编译前模块应分析的次数。

PjRt 客户端创建选项#

类型:

str

默认值:

配置字符串:

'jax_pjrt_client_create_options'

环境变量:

JAX_PJRT_CLIENT_CREATE_OPTIONS

提供给设备平台 PjRt 客户端作为额外参数的 “k1:v1;k2:v2” 格式的一组键值对字符串。

平台名称#

类型:

str

默认值:

''

配置字符串:

'jax_platform_name'

环境变量:

JAX_PLATFORM_NAME

已弃用,请使用 –jax_platforms 代替。

平台#

类型:

str

默认值:

配置字符串:

'jax_platforms'

环境变量:

JAX_PLATFORMS

以逗号分隔的平台名称列表,指定 JAX 应初始化哪些平台。如果此列表中的任何平台未能成功初始化,将引发异常且程序将终止。列表中的第一个平台将作为默认平台。例如,config.jax_platforms=cpu,tpu 表示将初始化 CPU 和 TPU 后端,除非另有说明,否则将使用 CPU 后端。如果 TPU 初始化失败,它将引发异常。默认情况下,JAX 将尝试初始化所有可用平台,并默认使用 GPU 或 TPU(如果可用),否则回退到 CPU。

格式化打印使用颜色#

类型:

bool

默认值:

True

配置字符串:

'jax_pprint_use_color'

环境变量:

JAX_PPRINT_USE_COLOR

启用带有丰富语法高亮的 Jaxpr 格式化打印。

Ragged Dot 使用 GPU Pallas Triton 降级#

类型:

bool

默认值:

False

配置字符串:

'jax_ragged_dot_use_gpu_pallas_triton_lowering'

环境变量:

JAX_RAGGED_DOT_USE_GPU_PALLAS_TRITON_LOWERING

(仅限 GPU)如果为 True,则使用 Pallas Triton 降级进行 ragged_dot() 降级。否则,依赖于 ragged_dot_general_p 的默认降级规则。

Ragged Dot 使用 Ragged Dot 指令#

类型:

bool

默认值:

True

配置字符串:

'jax_ragged_dot_use_ragged_dot_instruction'

环境变量:

JAX_RAGGED_DOT_USE_RAGGED_DOT_INSTRUCTION

(仅限 TPU)如果为 True,则使用 chlo.ragged_dot 指令进行 ragged_dot() 降级。否则,依赖于 ragged_dot_general_p 降级规则中的推广逻辑。

引发持久缓存错误#

类型:

bool

默认值:

False

配置字符串:

'jax_raise_persistent_cache_errors'

环境变量:

JAX_RAISE_PERSISTENT_CACHE_ERRORS

如果为 true,读取或写入持久编译缓存时引发的异常将被允许通过,如果未手动捕获,则会中止程序执行。如果为 false,异常会被捕获并作为警告引发,从而允许程序继续执行。默认为 false,以便缓存错误或间歇性问题不会导致致命后果。

随机种子偏移量#

类型:

int

默认值:

0

配置字符串:

'jax_random_seed_offset'

环境变量:

JAX_RANDOM_SEED_OFFSET

所有随机种子的偏移量(例如 jax.random.key() 的参数)。

引用到固定缓冲区#

类型:

bool

默认值:

False

配置字符串:

'jax_refs_to_pins'

环境变量:

JAX_REFS_TO_PINS

将 HLO 中的引用降级为固定缓冲区。这将在未来的 JAX 版本中默认启用,届时所有对该标志的使用都将被视为已弃用(遵循 API 兼容性策略)。

缓存键中移除自定义分区指针#

类型:

bool

默认值:

False

配置字符串:

'jax_remove_custom_partitioning_ptr_from_cache_key'

环境变量:

JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY

如果设为 True,在计算缓存键期间哈希处理前,从预编译的 stableHLO 中移除自定义分区指针。这是一个潜在不安全的标志,只有确切知道自己在做什么的用户才应设置它。

类型中移除大小为 1 的 Mesh 轴#

类型:

bool

默认值:

False

配置字符串:

'jax_remove_size_one_mesh_axis_from_type'

环境变量:

JAX_REMOVE_SIZE_ONE_MESH_AXIS_FROM_TYPE

从 ShapedArray.sharding 和 vma 中移除大小为 1 的 Mesh 轴。

ROCM 可见设备#

类型:

str

默认值:

'all'

配置字符串:

'jax_rocm_visible_devices'

环境变量:

JAX_ROCM_VISIBLE_DEVICES

限制 JAX 将使用的 ROCM 设备集合。可以是 “all”,或者是逗号分隔的整数设备 ID 列表。

Scan3#

类型:

bool

默认值:

False

配置字符串:

'jax_scan3'

环境变量:

JAX_SCAN3

如果为 True,拥抱循环的未来。这将在未来的 JAX 版本中默认启用,届时所有对该标志的使用都将被视为已弃用(遵循 API 兼容性策略)。

发送回溯到运行时#

类型:

枚举值: OFF, ON, FULL

默认值:

<RuntimeTracebackMode.OFF: 'off'>

配置字符串:

'jax_send_traceback_to_runtime'

环境变量:

JAX_SEND_TRACEBACK_TO_RUNTIME

控制在调度时发送到运行时的 Python 回溯信息级别: - “OFF”:(默认)不发送 Python 回溯信息。 - “ON”: 仅发送最近的用户帧调用位置。 - “FULL”: 发送调用位置的完整 Python 回溯。这在调度路径上有很高的固定成本,仅应在调试时使用。

主机间共享二进制文件#

类型:

bool

默认值:

False

配置字符串:

'jax_share_binary_between_hosts'

环境变量:

JAX_SHARE_BINARY_BETWEEN_HOSTS

如果设为 True,已编译模块将直接在主机间共享。

主机间共享二进制文件超时毫秒数#

类型:

int

默认值:

1200000

配置字符串:

'jax_share_binary_between_hosts_timeout_ms'

环境变量:

JAX_SHARE_BINARY_BETWEEN_HOSTS_TIMEOUT_MS

已编译模块共享的超时时间。

Softmax 自定义 Jvp#

类型:

bool

默认值:

False

配置字符串:

'jax_softmax_custom_jvp'

环境变量:

JAX_SOFTMAX_CUSTOM_JVP

为 jax.nn.softmax 使用新的 custom_jvp 规则。新规则应能改善内存使用和稳定性。设为 True 以使用新行为。参见 jax-ml/jax#15677。这将在未来的 JAX 版本中默认启用,届时所有对该标志的使用都将被视为已弃用(遵循 API 兼容性策略)。

按进程索引对设备排序#

类型:

bool

默认值:

True

配置字符串:

'jax_sort_devices_by_process_index'

环境变量:

JAX_SORT_DEVICES_BY_PROCESS_INDEX

先按进程索引对 JAX 设备排序,再按设备 ID 排序。如果为 False,仅按设备 ID 排序,这保留了 PJRT 客户端分配的全局设备顺序。

线程防护#

类型:

bool

默认值:

False

配置字符串:

'jax_thread_guard'

环境变量:

JAX_THREAD_GUARD

如果为 True,当从非设置线程防护的线程调用多进程 JAX 操作时,将在运行时引发错误。这对于检测线程可能以不同顺序在不同进程中调度操作导致非确定性崩溃的情况非常有用。

Threefry GPU 内核降级#

类型:

bool

默认值:

False

配置字符串:

'jax_threefry_gpu_kernel_lowering'

环境变量:

JAX_THREEFRY_GPU_KERNEL_LOWERING

在 GPU 上,将 Threefry PRNG 操作降级为内核实现。这会以潜在的运行时内存成本换取更快的编译时间。

Threefry 可分区化#

类型:

bool

默认值:

True

配置字符串:

'jax_threefry_partitionable'

环境变量:

JAX_THREEFRY_PARTITIONABLE

启用内部 Threefry PRNG 实现更改,使其在某些情况下自动可分区。如果没有此标志,使用标准的 jax.random 伪随机数生成可能会导致额外的通信和/或冗余的分布式计算。使用此标志,在某些情况下通信开销会消失。这将在未来的 JAX 版本中默认启用,届时所有对该标志的使用都将被视为已弃用(遵循 API 兼容性策略)。

回溯过滤#

类型:

枚举值: 'off', 'tracebackhide', 'remove_frames', 'quiet_remove_frames', 'auto'

默认值:

'auto'

配置字符串:

'jax_traceback_filtering'

环境变量:

JAX_TRACEBACK_FILTERING

控制 JAX 如何从回溯中过滤掉内部帧。有效值为: - off: 禁用回溯过滤。 - auto: 如果在较新的 IPython 下运行,则使用 tracebackhide,否则使用 remove_frames。 - tracebackhide: 为隐藏的堆栈帧添加 __tracebackhide__ 注释,部分回溯打印程序支持此功能。 - remove_frames: 从回溯中删除隐藏的帧,并将未过滤的回溯添加为异常的 __cause__。 - quiet_remove_frames: 从回溯中删除隐藏的帧,并添加一条简短消息(至异常的 __cause__)说明已执行此操作。

位置信息中的回溯限制#

类型:

int

默认值:

10

配置字符串:

'jax_traceback_in_locations_limit'

环境变量:

JAX_TRACEBACK_IN_LOCATIONS_LIMIT

限制 MLIR 位置中包含的 Python 回溯帧数。如果设为负值,将不限制回溯。

追踪器错误回溯帧数#

类型:

int

默认值:

5

配置字符串:

'jax_tracer_error_num_traceback_frames'

环境变量:

JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES

设置 JAX 追踪器错误消息中的堆栈帧数。

传输防护#

类型:

枚举值: 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'

默认值:

配置字符串:

'jax_transfer_guard'

环境变量:

JAX_TRANSFER_GUARD

选择所有传输的传输防护级别。此选项为只写;特定方向的传输防护级别应使用每传输方向选项读取。默认为 “allow”。

设备到设备传输防护#

类型:

枚举值: 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'

默认值:

配置字符串:

'jax_transfer_guard_device_to_device'

环境变量:

JAX_TRANSFER_GUARD_DEVICE_TO_DEVICE

选择设备到设备传输的传输防护级别。默认为 “allow”。

设备到主机传输防护#

类型:

枚举值: 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'

默认值:

配置字符串:

'jax_transfer_guard_device_to_host'

环境变量:

JAX_TRANSFER_GUARD_DEVICE_TO_HOST

选择设备到主机传输的传输防护级别。默认为 “allow”。

主机到设备传输防护#

类型:

枚举值: 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'

默认值:

配置字符串:

'jax_transfer_guard_host_to_device'

环境变量:

JAX_TRANSFER_GUARD_HOST_TO_DEVICE

选择主机到设备传输的传输防护级别。默认为 “allow”。

使用直接线性化#

类型:

bool

默认值:

True

配置字符串:

'jax_use_direct_linearize'

环境变量:

JAX_USE_DIRECT_LINEARIZE

使用直接线性化,而不是 JVP 后跟部分评估。

使用 Magma#

类型:

枚举值: 'off', 'on', 'auto'

默认值:

'auto'

配置字符串:

'jax_use_magma'

环境变量:

JAX_USE_MAGMA

在 GPU 上启用对 MAGMA 支持的 lax.linalg.eig 的实验性支持。有关如何使用此功能的详细信息,请参阅 lax.linalg.eig 的文档。

使用 Shardy 分区器#

类型:

bool

默认值:

True

配置字符串:

'jax_use_shardy_partitioner'

环境变量:

JAX_USE_SHARDY_PARTITIONER

是否降级(lower)至 Shardy。更多信息请参阅迁移指南:https://jax.net.cn/en/latest/shardy_jax_migration.html。此功能在未来的 JAX 版本中将默认启用,届时所有对此标志的使用将被视为已弃用(遵循 API 兼容性策略)。

使用简化的 Jaxpr 常量#

类型:

bool

默认值:

False

配置字符串:

'jax_use_simplified_jaxpr_constants'

环境变量:

JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS

启用对 Jaxpr 中闭包常量处理的简化。设为 True 可启用此新行为。此标志仅在过渡期短暂存在。请参阅 jax-ml/jax#29679.DO —— 请勿依赖此标志。

Xla 后端#

类型:

str

默认值:

''

配置字符串:

'jax_xla_backend'

环境变量:

JAX_XLA_BACKEND

已弃用,请使用 –jax_platforms 代替。

Xla 配置文件版本#

类型:

int

默认值:

0

配置字符串:

'jax_xla_profile_version'

环境变量:

JAX_XLA_PROFILE_VERSION

用于 XLA 编译的可选配置文件版本。仅当 XLA 配置为支持远程编译配置文件功能时,此选项才有意义。

模拟 GPU 进程数量#

类型:

int

默认值:

0

配置字符串:

'mock_num_gpu_processes'

环境变量:

MOCK_NUM_GPU_PROCESSES

在 GPU 客户端中模拟 JAX 进程的数量。设为 0 可关闭模拟。