配置选项#

JAX 提供了各种配置选项以自定义其行为。这些选项控制着从数值精度到调试功能的一切。

如何使用配置选项#

JAX 配置选项可以通过以下几种方式设置:

  1. 环境变量(在运行程序前设置)

    export JAX_ENABLE_X64=True
    python my_program.py
    
  2. 运行时配置(在您的 Python 代码中)

    import jax
    jax.config.update("jax_enable_x64", True)
    
  3. 命令行标志(使用 Abseil)

    # In your code:
    import jax
    jax.config.parse_flags_with_absl()
    
    # When running:
    python my_program.py --jax_enable_x64=True
    

常用配置选项#

以下是一些最常用的配置选项:

  • jax_enable_x64 – 启用 64 位浮点精度

  • jax_disable_jit – 禁用 JIT 编译以进行调试

  • jax_debug_nans – 检查并报告 NaNs 错误

  • jax_platforms – 控制 JAX 初始化哪些后端(CPU/GPU/TPU)

  • jax_numpy_rank_promotion – 控制自动秩提升行为

  • jax_default_matmul_precision – 设置矩阵乘法操作的默认精度

所有配置选项#

以下是所有可用的 JAX 配置选项的完整列表:

检查 Vma#

类型:

布尔值

默认值:

False

配置字符串:

'check_vma'

环境变量:

CHECK_VMA

shard_map 的内部实现细节,请勿使用

急切常量折叠#

类型:

布尔值

默认值:

False

配置字符串:

'eager_constant_folding'

环境变量:

EAGER_CONSTANT_FOLDING

在暂存期间尝试常量折叠。

Jax2Tf 关联扫描约简#

类型:

布尔值

默认值:

False

配置字符串:

'jax2tf_associative_scan_reductions'

环境变量:

JAX2TF_ASSOCIATIVE_SCAN_REDUCTIONS

JAX 为累积约简原语(cumsum、cumprod、cummax、cummin)提供了两种独立的降级规则。在 CPU 和 GPU 上,它使用 lax.associative_scan,而在 TPU 上,它使用 HLO ReduceWindow。后者在 CPU 和 GPU 上的实现速度较慢。默认情况下,jax2tf 使用 TPU 降级。将此标志设置为 True 以使用关联扫描降级用法,仅当它对您的应用程序有影响时才这样做。有关更多详细信息,请参阅 jax2tf README.md。

Jax2Tf 默认原生序列化#

类型:

布尔值

默认值:

True

配置字符串:

'jax2tf_default_native_serialization'

环境变量:

JAX2TF_DEFAULT_NATIVE_SERIALIZATION

将 native_serialization 参数的默认值设置为 jax2tf.convert。建议使用参数而不是标志,标志将来可能会被移除。从 JAX 0.4.31 开始,非原生序列化已被弃用。

数组垃圾回收保护#

类型:

枚举值: 'allow', 'log', 'fatal'

默认值:

配置字符串:

'jax_array_garbage_collection_guard'

环境变量:

JAX_ARRAY_GARBAGE_COLLECTION_GUARD

jax.Array 对象选择垃圾回收保护级别。

此选项可用于控制当 jax.Array 对象被垃圾回收时发生的情况。理想情况下,jax.Array 对象应通过 Python 引用计数而不是垃圾回收来释放,以避免设备内存被数组占用直到垃圾回收发生。

有效值包括:

  • allow: 不记录 jax.Array 对象的垃圾回收。

  • log: 当 jax.Array 被垃圾回收时记录错误。

  • fatal: 如果 jax.Array 被垃圾回收,则报告致命错误。

默认值为 allow。请注意,并非所有循环都能被检测到。

后端目标#

类型:

字符串

默认值:

''

配置字符串:

'jax_backend_target'

环境变量:

JAX_BACKEND_TARGET

BCOO Cusparse 降级#

类型:

布尔值

默认值:

False

配置字符串:

'jax_bcoo_cusparse_lowering'

环境变量:

JAX_BCOO_CUSPARSE_LOWERING

启用将 BCOO 操作降级到 cuSparse。

捕获常量报告帧#

类型:

整数

默认值:

0

配置字符串:

'jax_captured_constants_report_frames'

环境变量:

JAX_CAPTURED_CONSTANTS_REPORT_FRAMES

为每个捕获的常量报告的堆栈帧数,指示捕获常量的文件和操作。设置为 -1 可打印完整的帧集,设置为 0 可禁用。注意:只有当捕获的常量总大小超过 jax_captured_constants_warn_bytes 时才会生成报告,因为生成报告的成本很高。

捕获常量警告字节#

类型:

整数

默认值:

2000000000

配置字符串:

'jax_captured_constants_warn_bytes'

环境变量:

JAX_CAPTURED_CONSTANTS_WARN_BYTES

在发出警告之前,可作为常量捕获的参数字节数。默认为大约 2GB。设置为 -1 可禁用警告。

检查代理环境变量#

类型:

布尔值

默认值:

True

配置字符串:

'jax_check_proxy_envs'

环境变量:

JAX_CHECK_PROXY_ENVS

检查用户环境变量中的代理变量并发出警告。

检查 Tracer 泄漏#

类型:

布尔值

默认值:

False

配置字符串:

'jax_check_tracer_leaks'

环境变量:

JAX_CHECK_TRACER_LEAKS

在跟踪完成后立即开启对泄漏 Tracer 的检查。启用泄漏检查可能会对性能产生影响:某些缓存将被禁用,并且可能会增加其他开销。此外,请注意某些 Python 调试器可能会导致误报,因此建议在启用泄漏检查时禁用任何调试器。

编译缓存目录#

类型:

字符串

默认值:

配置字符串:

'jax_compilation_cache_dir'

环境变量:

JAX_COMPILATION_CACHE_DIR

缓存路径。优先级:1. 调用 compilation_cache.set_cache_dir()。2. 通过命令行或默认设置的此标志值。

编译缓存预期 PGLE#

类型:

布尔值

默认值:

False

配置字符串:

'jax_compilation_cache_expect_pgle'

环境变量:

JAX_COMPILATION_CACHE_EXPECT_PGLE

如果设置为 True,即使当前未启用 PGLE,也会优先加载使用配置文件数据(即 PGLE 已启用且已对所需执行次数进行分析)编译的编译缓存条目。当未找到首选缓存条目时,将打印警告。

编译缓存键中包含元数据#

类型:

布尔值

默认值:

False

配置字符串:

'jax_compilation_cache_include_metadata_in_key'

环境变量:

JAX_COMPILATION_CACHE_INCLUDE_METADATA_IN_KEY

在编译缓存键中包含元数据,例如文件名和行号。如果为 false,即使函数或文件被移动等,缓存仍然会命中。但是,这意味着从缓存加载的可执行文件可能包含过时的元数据,这可能会在例如配置文件中显示出来。

编译缓存最大大小#

类型:

整数

默认值:

-1

配置字符串:

'jax_compilation_cache_max_size'

环境变量:

JAX_COMPILATION_CACHE_MAX_SIZE

持久编译缓存允许的最大大小(以字节为单位)。设置后,一旦总缓存目录大小超过指定限制,最近最少访问的缓存条目将被删除。如果此值设置为 0,则缓存将被禁用。特殊值 -1 表示没有限制,允许缓存大小无限增长。

编译器详细日志最小操作数#

类型:

整数

默认值:

10

配置字符串:

'jax_compiler_detailed_logging_min_ops'

环境变量:

JAX_COMPILER_DETAILED_LOGGING_MIN_OPS

在 JAX 启用详细编译器日志记录之前,MLIR 操作中的模块应该有多大?此标志的目的是抑制小型/不重要的计算的详细日志记录。

编译器启用 Remat Pass#

类型:

布尔值

默认值:

True

配置字符串:

'jax_compiler_enable_remat_pass'

环境变量:

JAX_COMPILER_ENABLE_REMAT_PASS

用于启用/禁用重物化 HLO pass 的配置。在遇到 OOM 错误时,允许 XLA 自动权衡内存和计算非常有用。但是,您使用 jax.checkpoint 手动操作可能会获得更好的结果。

CPU 集合操作实现#

类型:

枚举值: 'gloo', 'mpi', 'megascale'

默认值:

'gloo'

配置字符串:

'jax_cpu_collectives_implementation'

环境变量:

JAX_CPU_COLLECTIVES_IMPLEMENTATION

在 CPU 上使用的跨进程集合操作实现。必须是 (“gloo”, “mpi”) 之一

CPU 启用异步调度#

类型:

布尔值

默认值:

True

配置字符串:

'jax_cpu_enable_async_dispatch'

环境变量:

JAX_CPU_ENABLE_ASYNC_DISPATCH

仅适用于非并行计算。如果为 False,则以内联方式运行计算,不进行异步调度。

CPU 启用 Gloo 集合操作#

类型:

布尔值

默认值:

False

配置字符串:

'jax_cpu_enable_gloo_collectives'

环境变量:

JAX_CPU_ENABLE_GLOO_COLLECTIVES

已弃用,请改用 jax_cpu_collectives_implementation。

CPU 获取全局拓扑超时(分钟)#

类型:

整数

默认值:

5

配置字符串:

'jax_cpu_get_global_topology_timeout_minutes'

环境变量:

JAX_CPU_GET_GLOBAL_TOPOLOGY_TIMEOUT_MINUTES

获取 CPU 设备全局拓扑的超时时间(分钟);应严格大于 –jax_cpu_get_local_topology_timeout_minutes

CPU 获取本地拓扑超时(分钟)#

类型:

整数

默认值:

2

配置字符串:

'jax_cpu_get_local_topology_timeout_minutes'

环境变量:

JAX_CPU_GET_LOCAL_TOPOLOGY_TIMEOUT_MINUTES

构建全局拓扑时,获取每个 CPU 设备的本地拓扑的超时时间(分钟)。

跨主机传输套接字地址#

类型:

字符串

默认值:

''

配置字符串:

'jax_cross_host_transfer_socket_address'

环境变量:

JAX_CROSS_HOST_TRANSFER_SOCKET_ADDRESS

通过 DCN 进行跨主机设备传输时使用的套接字地址。仅当 PjRt 插件不支持跨主机传输时才需要。

跨主机传输超时(秒)#

类型:

整数

默认值:

配置字符串:

'jax_cross_host_transfer_timeout_seconds'

环境变量:

JAX_CROSS_HOST_TRANSFER_TIMEOUT_SECONDS

跨主机传输大小#

类型:

整数

默认值:

配置字符串:

'jax_cross_host_transfer_transfer_size'

环境变量:

JAX_CROSS_HOST_TRANSFER_TRANSFER_SIZE

跨主机传输地址#

类型:

字符串

默认值:

''

配置字符串:

'jax_cross_host_transport_addresses'

环境变量:

JAX_CROSS_HOST_TRANSPORT_ADDRESSES

用于通过 DCN 进行跨主机设备传输的传输地址的逗号分隔列表。如果未设置,则默认为 [0.0.0.0:0] * 4。

CUDA 可见设备#

类型:

字符串

默认值:

'all'

配置字符串:

'jax_cuda_visible_devices'

环境变量:

JAX_CUDA_VISIBLE_DEVICES

自定义 VJP 禁用形状检查#

类型:

布尔值

默认值:

False

配置字符串:

'jax_custom_vjp_disable_shape_check'

环境变量:

JAX_CUSTOM_VJP_DISABLE_SHAPE_CHECK

禁用 #19009 中的检查以启用一些 custom_vjp 技巧。此功能将在未来版本的 JAX 中默认启用,届时所有使用此标志的行为都将被视为弃用(遵循 API 兼容性策略)。

调试 Infs#

类型:

布尔值

默认值:

False

配置字符串:

'jax_debug_infs'

环境变量:

JAX_DEBUG_INFS

为每个操作添加无穷大检查。当在 JIT 编译计算的输出中检测到无穷大时,调用未编译版本,以尝试更精确地识别产生无穷大的操作。

调试键重用#

类型:

布尔值

默认值:

False

配置字符串:

'jax_debug_key_reuse'

环境变量:

JAX_DEBUG_KEY_REUSE

开启实验性键重用检查。启用此配置后,类型化的 PRNG 键(即使用 jax.random.key() 创建的键)将跟踪其使用情况,并且错误地重用先前使用的键将导致错误。当前启用此功能会导致每次调用 JIT 编译函数时,如果键作为输入或输出,都会产生少量的 Python 开销。

调试日志模块#

类型:

字符串

默认值:

''

配置字符串:

'jax_debug_log_modules'

环境变量:

JAX_DEBUG_LOG_MODULES

逗号分隔的模块名称列表(例如“jax”或“jax._src.xla_bridge,jax._src.dispatch”),用于启用调试日志记录。

调试 NaNs#

类型:

布尔值

默认值:

False

配置字符串:

'jax_debug_nans'

环境变量:

JAX_DEBUG_NANS

为每个操作添加 NaN 检查。当在 JIT 编译计算的输出中检测到 NaN 时,调用未编译版本,以尝试更精确地识别产生 NaN 的操作。

默认设备#

类型:

字符串

默认值:

配置字符串:

'jax_default_device'

环境变量:

JAX_DEFAULT_DEVICE

配置 JAX 操作的默认设备。设置为 Device 对象(例如 jax.devices("cpu")[0])以将其用作 JAX 操作和 jit 调用的默认设备(对多设备计算,例如 pmap 函数调用,没有影响)。设置为 None 可使用系统默认设备。有关设备放置的更多信息,请参阅控制数据和计算在设备上的放置

默认矩阵乘法精度#

类型:

枚举值: 'default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32', 'ANY_F8_ANY_F8_F32', 'ANY_F8_ANY_F8_F32_FAST_ACCUM', 'ANY_F8_ANY_F8_ANY', 'ANY_F8_ANY_F8_ANY_FAST_ACCUM', 'F16_F16_F16', 'F16_F16_F32', 'BF16_BF16_BF16', 'BF16_BF16_F32', 'BF16_BF16_F32_X3', 'BF16_BF16_F32_X6', 'BF16_BF16_F32_X9', 'TF32_TF32_F32', 'TF32_TF32_F32_X3', 'F32_F32_F32', 'F64_F64_F64'

默认值:

配置字符串:

'jax_default_matmul_precision'

环境变量:

JAX_DEFAULT_MATMUL_PRECISION

控制 32 位输入的默认矩阵乘法和卷积精度。

某些平台(如 TPU)为矩阵乘法和卷积计算提供可配置的精度级别,以权衡精度和速度。精度可以针对每个操作进行控制;例如,请参阅 jax.lax.conv_general_dilated()jax.lax.dot() 的文档字符串。但控制操作未给定特定精度时获得的默认行为会很有用。

此选项可用于控制涉及 32 位输入矩阵乘法和卷积计算的默认精度级别。这些级别大致描述了计算标量积的精度。'bfloat16' 选项最快且精度最低;'float32' 类似于完整的 float32 精度;'tensorfloat32' 介于两者之间。

此参数还可用于为执行矩阵乘法的函数(如 jax.lax.dot())指定累积“算法”。要指定算法,请将此选项设置为 DotAlgorithmPreset 的名称。

默认 PRNG 实现#

类型:

枚举值: 'threefry2x32', 'rbg', 'unsafe_rbg'

默认值:

'threefry2x32'

配置字符串:

'jax_default_prng_impl'

环境变量:

JAX_DEFAULT_PRNG_IMPL

选择默认的 PRNG 实现,当在种子生成时未明确提供时使用。

禁用 JIT#

类型:

布尔值

默认值:

False

配置字符串:

'jax_disable_jit'

环境变量:

JAX_DISABLE_JIT

禁用 JIT 编译,直接调用原始 Python。

禁用大多数优化#

类型:

布尔值

默认值:

False

配置字符串:

'jax_disable_most_optimizations'

环境变量:

JAX_DISABLE_MOST_OPTIMIZATIONS

禁用 Vmap Shmap 错误#

类型:

布尔值

默认值:

False

配置字符串:

'jax_disable_vmap_shmap_error'

环境变量:

JAX_DISABLE_VMAP_SHMAP_ERROR

禁用 vmap-of-shmap 中错误检查的临时解决方案。

禁止 Mesh 上下文管理器#

类型:

布尔值

默认值:

False

配置字符串:

'jax_disallow_mesh_context_manager'

环境变量:

JAX_DISALLOW_MESH_CONTEXT_MANAGER

如果设置为 True,尝试将 Mesh 用作上下文管理器将导致 RuntimeError。

分布式调试#

类型:

布尔值

默认值:

False

配置字符串:

'jax_distributed_debug'

环境变量:

JAX_DISTRIBUTED_DEBUG

启用对调试多进程分布式计算有用的日志记录。日志记录以 WARNING 级别使用 logging 执行。

转储 IR 模式#

类型:

字符串

默认值:

'stablehlo'

配置字符串:

'jax_dump_ir_modes'

环境变量:

JAX_DUMP_IR_MODES

以逗号分隔的 IR 转储模式。可以是“stablehlo”(默认)、“jaxpr”或“eqn_count_pprof”(用于 jaxpr 方程计数 pprof 配置文件)。

转储 IR 到#

类型:

字符串

默认值:

''

配置字符串:

'jax_dump_ir_to'

环境变量:

JAX_DUMP_IR_TO

JAX 发出的 IR 应转储为文本文件的路径。如果省略,JAX 将不转储任何 IR。支持特殊值“sponge”以从环境变量 TEST_UNDECLARED_OUTPUTS_DIR 中选择路径。有关控制转储内容的选项,请参阅 jax_dump_ir_modes。

动态形状#

类型:

布尔值

默认值:

False

配置字符串:

'jax_dynamic_shapes'

环境变量:

JAX_DYNAMIC_SHAPES

启用用于暂存动态形状计算的实验性功能。

启用检查#

类型:

布尔值

默认值:

False

配置字符串:

'jax_enable_checks'

环境变量:

JAX_ENABLE_CHECKS

开启 JAX 内部的不变式检查。这会降低速度。

启用编译缓存#

类型:

布尔值

默认值:

True

配置字符串:

'jax_enable_compilation_cache'

环境变量:

JAX_ENABLE_COMPILATION_CACHE

如果设置为 False,无论是否调用 set_cache_dir(),编译缓存都将被禁用。如果设置为 True,路径可以设置为默认值或通过调用 set_cache_dir() 设置。

启用自定义 PRNG#

类型:

布尔值

默认值:

False

配置字符串:

'jax_enable_custom_prng'

环境变量:

JAX_ENABLE_CUSTOM_PRNG

启用内部升级,允许定义自定义伪随机数生成器实现。此功能将在未来版本的 JAX 中默认启用,届时所有使用此标志的行为都将被视为弃用(遵循 API 兼容性策略)。

通过自定义转置启用自定义 VJP#

类型:

布尔值

默认值:

False

配置字符串:

'jax_enable_custom_vjp_by_custom_transpose'

环境变量:

JAX_ENABLE_CUSTOM_VJP_BY_CUSTOM_TRANSPOSE

启用内部升级,通过归约为 jax.custom_jvpjax.custom_transpose 来实现 jax.custom_vjp。此功能将在未来版本的 JAX 中默认启用,届时所有使用此标志的行为都将被视为弃用(遵循 API 兼容性策略)。

启用 PGLE#

类型:

布尔值

默认值:

False

配置字符串:

'jax_enable_pgle'

环境变量:

JAX_ENABLE_PGLE

如果设置为 True 且属性 jax_pgle_profiling_runs 设置为大于 0,则模块将在运行指定次数后重新编译,并将收集到的数据提供给配置文件引导的延迟估计器。

启用 X64#

类型:

布尔值

默认值:

False

配置字符串:

'jax_enable_x64'

环境变量:

JAX_ENABLE_X64

启用 64 位类型的使用

错误检查行为:除法#

类型:

枚举值: 'ignore', 'raise'

默认值:

'ignore'

配置字符串:

'jax_error_checking_behavior_divide'

环境变量:

JAX_ERROR_CHECKING_BEHAVIOR_DIVIDE

指定遇到除零时的行为。选项为“ignore”或“raise”。

错误检查行为:NaN#

类型:

枚举值: 'ignore', 'raise'

默认值:

'ignore'

配置字符串:

'jax_error_checking_behavior_nan'

环境变量:

JAX_ERROR_CHECKING_BEHAVIOR_NAN

指定遇到 NaN 时的行为。选项为“ignore”或“raise”。

错误检查行为:OOB#

类型:

枚举值: 'ignore', 'raise'

默认值:

'ignore'

配置字符串:

'jax_error_checking_behavior_oob'

环境变量:

JAX_ERROR_CHECKING_BEHAVIOR_OOB

指定遇到越界访问时的行为。选项为“ignore”或“raise”。

执行时间优化力度#

类型:

浮点值

默认值:

0.0

配置字符串:

'jax_exec_time_optimization_effort'

环境变量:

JAX_EXEC_TIME_OPTIMIZATION_EFFORT

最小化执行时间的力度(值越高表示力度越大),有效范围为 [-1.0, 1.0]。

实验性不安全 XLA 运行时错误#

类型:

布尔值

默认值:

False

配置字符串:

'jax_experimental_unsafe_xla_runtime_errors'

环境变量:

JAX_EXPERIMENTAL_UNSAFE_XLA_RUNTIME_ERRORS

在 CPU 和 GPU 上为 jax.experimental.checkify.checks 启用 XLA 运行时错误。这些错误是异步的,可能会丢失且可读性不高。但是,它们会使计算崩溃,并使您能够编写可 jit 检查的代码而无需进行 checkify。在 pmap/pjit 下不起作用。

解释缓存未命中#

类型:

布尔值

默认值:

False

配置字符串:

'jax_explain_cache_misses'

环境变量:

JAX_EXPLAIN_CACHE_MISSES

每次主缓存(例如跟踪缓存)未命中时,记录一条解释。日志记录使用 logging 在 WARNING 级别执行。当此选项设置时,日志级别为 WARNING;否则为 DEBUG。

导出调用约定版本#

类型:

整数

默认值:

10

配置字符串:

'jax_export_calling_convention_version'

环境变量:

JAX_EXPORT_CALLING_CONVENTION_VERSION

用于导出的调用约定版本号。此版本号必须在您的部署环境中使用的 tf.XlaCallModule 支持的版本范围内。请参阅 https://jax.net.cn/en/latest/export/shape_poly.html#calling-convention-versions

导出忽略前向兼容性#

类型:

布尔值

默认值:

False

配置字符串:

'jax_export_ignore_forward_compatibility'

环境变量:

JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY

是否忽略前向兼容性降级规则。请参阅 https://jax.net.cn/en/latest/export/export.html#compatibility-guarantees-for-custom-calls

高动态范围 Gumbel#

类型:

布尔值

默认值:

False

配置字符串:

'jax_high_dynamic_range_gumbel'

环境变量:

JAX_HIGH_DYNAMIC_RANGE_GUMBEL

如果为 True,Gumbel 噪声会抽取两个样本,以更精确地覆盖低概率事件。

HLO 源文件规范化正则表达式#

类型:

字符串

默认值:

配置字符串:

'jax_hlo_source_file_canonicalization_regex'

环境变量:

JAX_HLO_SOURCE_FILE_CANONICALIZATION_REGEX

用于通过删除给定正则表达式来规范化 HLO 指令的 source_path 元数据。如果设置,则对每个 source_file 调用 re.sub(),并删除所有匹配项。这可用于在使用持久编译缓存时避免虚假缓存未命中,该缓存将 HLO 元数据包含在缓存键中。

转储中包含调试信息#

类型:

布尔值

默认值:

True

配置字符串:

'jax_include_debug_info_in_dumps'

环境变量:

JAX_INCLUDE_DEBUG_INFO_IN_DUMPS

确定在转储 IR 代码时是否保留调试符号和位置信息。默认情况下,调试信息将保留在 IR 转储中。为避免暴露源代码和潜在敏感信息,请设置为 false

位置中包含完整堆栈跟踪#

类型:

布尔值

默认值:

True

配置字符串:

'jax_include_full_tracebacks_in_locations'

环境变量:

JAX_INCLUDE_FULL_TRACEBACKS_IN_LOCATIONS

在 JAX 发出的 IR 的 MLIR 位置中包含 Python 堆栈跟踪。

传统 PRNG 键#

类型:

枚举值: 'allow', 'warn', 'error'

默认值:

'allow'

配置字符串:

'jax_legacy_prng_key'

环境变量:

JAX_LEGACY_PRNG_KEY

指定将原始 PRNG 键传递给 jax.random API 时的行为。

日志检查点残差#

类型:

布尔值

默认值:

False

配置字符串:

'jax_log_checkpoint_residuals'

环境变量:

JAX_LOG_CHECKPOINT_RESIDUALS

每次 jax.checkpoint(又称 jax.remat)被部分求值(例如用于自动微分)时,记录一条消息,打印保存了哪些残差。

日志编译#

类型:

布尔值

默认值:

False

配置字符串:

'jax_log_compiles'

环境变量:

JAX_LOG_COMPILES

每次 jitpmap 编译 XLA 计算时,记录一条消息。日志记录使用 logging 在 WARNING 级别执行。当此选项设置时,日志级别为 WARNING;否则为 DEBUG。

日志级别#

类型:

枚举值: 'NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'

默认值:

'NOTSET'

配置字符串:

'jax_logging_level'

环境变量:

JAX_LOGGING_LEVEL

在所有 JAX 日志记录器上设置相应的日志级别。只接受 [“NOTSET”、“DEBUG”、“INFO”、“WARNING”、“ERROR”、“CRITICAL”] 中的字符串值。如果为 None,则不设置日志级别。包括 C++ 日志记录。

内存适配力度#

类型:

浮点值

默认值:

0.0

配置字符串:

'jax_memory_fitting_effort'

环境变量:

JAX_MEMORY_FITTING_EFFORT

最小化内存使用的力度(值越高表示力度越大),有效范围为 [-1.0, 1.0]。

内存适配级别#

类型:

枚举值: 'UNKNOWN', 'O0', 'O1', 'O2', 'O3'

默认值:

'O2'

配置字符串:

'jax_memory_fitting_level'

环境变量:

JAX_MEMORY_FITTING_LEVEL

编译器应尝试使程序适应内存的程度

模拟 GPU 拓扑#

类型:

字符串

默认值:

''

配置字符串:

'jax_mock_gpu_topology'

环境变量:

JAX_MOCK_GPU_TOPOLOGY

在 GPU 客户端中模拟多主机 GPU 拓扑。该值应采用“<切片数量> x <每切片主机数量> x <每主机设备数量>”的形式。空字符串表示关闭模拟。

Mosaic 允许 HLO#

类型:

布尔值

默认值:

False

配置字符串:

'jax_mosaic_allow_hlo'

环境变量:

JAX_MOSAIC_ALLOW_HLO

在 Mosaic 中允许 HLO 方言

可变数组检查#

类型:

布尔值

默认值:

False

配置字符串:

'jax_mutable_array_checks'

环境变量:

JAX_MUTABLE_ARRAY_CHECKS

对可变数组启用错误检查以排除别名。此功能将在未来版本的 JAX 中默认启用,届时所有使用此标志的行为都将被视为弃用(遵循 API 兼容性策略)。

无跟踪#

类型:

布尔值

默认值:

False

配置字符串:

'jax_no_tracing'

环境变量:

JAX_NO_TRACING

不允许 JIT 编译的跟踪。

CPU 设备数量#

类型:

整数

默认值:

-1

配置字符串:

'jax_num_cpu_devices'

环境变量:

JAX_NUM_CPU_DEVICES

要使用的 CPU 设备数量。如果未提供,则使用 XLA 标志 –xla_force_host_platform_device_count 的值。必须在 JAX 初始化之前设置。

NumPy 数据类型提升#

类型:

枚举值: 'standard', 'strict'

默认值:

'standard'

配置字符串:

'jax_numpy_dtype_promotion'

环境变量:

JAX_NUMPY_DTYPE_PROMOTION

指定数组之间操作中隐式类型提升的规则。选项为“standard”或“strict”;在严格模式下,不同强指定数据类型数组之间的二进制操作将导致错误。

NumPy 秩提升#

类型:

枚举值: 'allow', 'warn', 'raise'

默认值:

'allow'

配置字符串:

'jax_numpy_rank_promotion'

环境变量:

JAX_NUMPY_RANK_PROMOTION

控制 NumPy 风格的自动秩提升广播(“allow”、“warn”或“raise”)。

优化级别#

类型:

枚举值: 'UNKNOWN', 'O0', 'O1', 'O2', 'O3'

默认值:

'UNKNOWN'

配置字符串:

'jax_optimization_level'

环境变量:

JAX_OPTIMIZATION_LEVEL

编译器应为执行时间进行优化的程度

Pallas 转储 Promela 到#

类型:

字符串

默认值:

''

配置字符串:

'jax_pallas_dump_promela_to'

环境变量:

JAX_PALLAS_DUMP_PROMELA_TO

如果设置,将内核的 Promela 模型转储到指定目录。该模型可以验证内核是否没有数据竞争、死锁等问题。

Pallas 启用调试检查#

类型:

布尔值

默认值:

False

配置字符串:

'jax_pallas_enable_debug_checks'

环境变量:

JAX_PALLAS_ENABLE_DEBUG_CHECKS

如果设置,pl.debug_check 调用会在运行时检查。否则,它们是空操作。

Pallas 使用 Mosaic GPU#

类型:

布尔值

默认值:

False

配置字符串:

'jax_pallas_use_mosaic_gpu'

环境变量:

JAX_PALLAS_USE_MOSAIC_GPU

如果为 True,将 Pallas 内核降级到实验性 Mosaic GPU 方言,而不是 Triton IR。

Pallas 详细错误#

类型:

布尔值

默认值:

False

配置字符串:

'jax_pallas_verbose_errors'

环境变量:

JAX_PALLAS_VERBOSE_ERRORS

如果为 True,为 Pallas 内核打印详细错误消息。

持久缓存启用 XLA 缓存#

类型:

字符串

默认值:

'xla_gpu_per_fusion_autotune_cache_dir'

配置字符串:

'jax_persistent_cache_enable_xla_caches'

环境变量:

JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES

当持久缓存启用时,额外的 XLA 缓存也将自动启用。此选项可用于配置将启用哪些 XLA 缓存方法。

持久缓存最小编译时间(秒)#

类型:

浮点值

默认值:

1.0

配置字符串:

'jax_persistent_cache_min_compile_time_secs'

环境变量:

JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS

写入持久编译缓存的计算的最小编译时间。可以提高此阈值以减少写入缓存的条目数。

持久缓存最小条目大小(字节)#

类型:

整数

默认值:

0

配置字符串:

'jax_persistent_cache_min_entry_size_bytes'

环境变量:

JAX_PERSISTENT_CACHE_MIN_ENTRY_SIZE_BYTES

持久编译缓存中将缓存的条目的最小大小(以字节为单位):* -1:禁用大小限制并阻止覆盖。* 保留默认值 (0) 以允许覆盖。覆盖通常会确保最小大小对于用于缓存的文件系统是最佳的。* > 0:所需的实际最小大小;无覆盖。

PGLE 聚合百分位#

类型:

整数

默认值:

90

配置字符串:

'jax_pgle_aggregation_percentile'

环境变量:

JAX_PGLE_AGGREGATION_PERCENTILE

使用 PGLE 时,用于聚合设备间性能数据的百分位。

PGLE 分析运行次数#

类型:

整数

默认值:

3

配置字符串:

'jax_pgle_profiling_runs'

环境变量:

JAX_PGLE_PROFILING_RUNS

使用 PGLE 时,模块在重新编译前应进行分析的次数。

Pjrt 客户端创建选项#

类型:

字符串

默认值:

配置字符串:

'jax_pjrt_client_create_options'

环境变量:

JAX_PJRT_CLIENT_CREATE_OPTIONS

以“k1:v1;k2:v2”字符串格式提供给设备平台 PjRt 客户端的一组键值对,作为额外参数。

平台名称#

类型:

字符串

默认值:

''

配置字符串:

'jax_platform_name'

环境变量:

JAX_PLATFORM_NAME

平台#

类型:

字符串

默认值:

配置字符串:

'jax_platforms'

环境变量:

JAX_PLATFORMS

逗号分隔的平台名称列表,指定 JAX 应初始化哪些平台。如果此列表中的任何平台未能成功初始化,则将引发异常并中止程序。列表中的第一个平台将是默认平台。例如,config.jax_platforms=cpu,tpu 表示将初始化 CPU 和 TPU 后端,并且除非另有指定,否则将使用 CPU 后端。如果 TPU 初始化失败,将引发异常。默认情况下,JAX 将尝试初始化所有可用平台,如果 GPU 或 TPU 可用,则默认为 GPU 或 TPU,否则回退到 CPU。

Pmap 无秩约简#

类型:

布尔值

默认值:

True

配置字符串:

'jax_pmap_no_rank_reduction'

环境变量:

JAX_PMAP_NO_RANK_REDUCTION

如果为 True,pmap 分片将具有与其包含数组相同的秩。

Pmap Shmap 合并#

类型:

布尔值

默认值:

False

配置字符串:

'jax_pmap_shmap_merge'

环境变量:

JAX_PMAP_SHMAP_MERGE

如果为 True,pmap 和 shard_map API 将合并。此功能将在未来版本的 JAX 中默认启用,届时所有使用此标志的行为都将被视为弃用(遵循 API 兼容性策略)。

美观打印使用颜色#

类型:

布尔值

默认值:

True

配置字符串:

'jax_pprint_use_color'

环境变量:

JAX_PPRINT_USE_COLOR

启用带有彩色语法高亮显示的 jaxpr 美观打印。

不规则点乘使用不规则点乘指令#

类型:

布尔值

默认值:

False

配置字符串:

'jax_ragged_dot_use_ragged_dot_instruction'

环境变量:

JAX_RAGGED_DOT_USE_RAGGED_DOT_INSTRUCTION

(仅限 TPU)如果为 True,则使用 chlo.ragged_dot 指令进行 ragged_dot() 降级。否则,依赖于 ragged_dot_general_p 降级规则中的展开逻辑。此功能将在未来版本的 JAX 中默认启用,届时所有使用此标志的行为都将被视为弃用(遵循 API 兼容性策略)。

抛出持久缓存错误#

类型:

布尔值

默认值:

False

配置字符串:

'jax_raise_persistent_cache_errors'

环境变量:

JAX_RAISE_PERSISTENT_CACHE_ERRORS

如果为 true,则在读取或写入持久编译缓存时引发的异常将被允许通过,如果未手动捕获,则会中止程序执行。如果为 false,则捕获异常并作为警告引发,允许程序继续执行。默认为 false,因此缓存错误或间歇性问题不会导致致命错误。

随机种子偏移#

类型:

整数

默认值:

0

配置字符串:

'jax_random_seed_offset'

环境变量:

JAX_RANDOM_SEED_OFFSET

所有随机种子的偏移量(例如 jax.random.key() 的参数)。

从缓存键中移除自定义分区指针#

类型:

布尔值

默认值:

False

配置字符串:

'jax_remove_custom_partitioning_ptr_from_cache_key'

环境变量:

JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY

如果设置为 True,则在缓存键计算期间,在散列之前从预编译的 stableHLO 中移除自定义分区指针。这是一个可能不安全的标志,只有确定自己要实现什么的用户才应设置它。

ROCm 可见设备#

类型:

字符串

默认值:

'all'

配置字符串:

'jax_rocm_visible_devices'

环境变量:

JAX_ROCM_VISIBLE_DEVICES

主机间共享二进制#

类型:

布尔值

默认值:

False

配置字符串:

'jax_share_binary_between_hosts'

环境变量:

JAX_SHARE_BINARY_BETWEEN_HOSTS

如果设置为 True,编译后的模块将直接在主机之间共享。

主机间共享二进制超时(毫秒)#

类型:

整数

默认值:

1200000

配置字符串:

'jax_share_binary_between_hosts_timeout_ms'

环境变量:

JAX_SHARE_BINARY_BETWEEN_HOSTS_TIMEOUT_MS

编译模块共享的超时时间。

Softmax 自定义 JVP#

类型:

布尔值

默认值:

False

配置字符串:

'jax_softmax_custom_jvp'

环境变量:

JAX_SOFTMAX_CUSTOM_JVP

为 jax.nn.softmax 使用新的 custom_jvp 规则。新规则应改进内存使用和稳定性。设置为 True 以使用新行为。请参阅 jax-ml/jax#15677 此功能将在未来版本的 JAX 中默认启用,届时所有使用此标志的行为都将被视为弃用(遵循 API 兼容性策略)。

Threefry GPU 内核降级#

类型:

布尔值

默认值:

False

配置字符串:

'jax_threefry_gpu_kernel_lowering'

环境变量:

JAX_THREEFRY_GPU_KERNEL_LOWERING

在 GPU 上,将 Threefry PRNG 操作降级为内核实现。这会加快编译时间,但可能会牺牲运行时内存性能。

Threefry 可分区#

类型:

布尔值

默认值:

True

配置字符串:

'jax_threefry_partitionable'

环境变量:

JAX_THREEFRY_PARTITIONABLE

启用 Threefry PRNG 内部实现更改,使其在某些情况下可自动分区。如果不使用此标志,使用标准 jax.random 伪随机数生成可能会导致不必要的通信和/或冗余分布式计算。使用此标志,在某些情况下通信开销会消失。此功能将在未来版本的 JAX 中默认启用,届时所有使用此标志的行为都将被视为弃用(遵循 API 兼容性策略)。

堆栈跟踪过滤#

类型:

枚举值: 'off', 'tracebackhide', 'remove_frames', 'quiet_remove_frames', 'auto'

默认值:

'auto'

配置字符串:

'jax_traceback_filtering'

环境变量:

JAX_TRACEBACK_FILTERING

控制 JAX 如何从堆栈跟踪中过滤掉内部帧。有效值包括: - off:禁用堆栈跟踪过滤。 - auto:如果在足够新的 IPython 下运行,则使用 tracebackhide,否则使用 remove_frames。 - tracebackhide:向隐藏堆栈帧添加 __tracebackhide__ 注解,某些堆栈跟踪打印机支持此功能。 - remove_frames:从堆栈跟踪中移除隐藏帧,并将未过滤的堆栈跟踪作为异常的 __cause__ 添加。 - quiet_remove_frames:从堆栈跟踪中移除隐藏帧,并添加一条简短消息(到异常的 __cause__)描述此情况的发生。

位置中堆栈跟踪限制#

类型:

整数

默认值:

10

配置字符串:

'jax_traceback_in_locations_limit'

环境变量:

JAX_TRACEBACK_IN_LOCATIONS_LIMIT

限制 MLIR 位置中包含的 Python 堆栈跟踪帧数。如果设置为负值,则堆栈跟踪将不受限制。

Tracer 错误堆栈跟踪帧数#

类型:

整数

默认值:

5

配置字符串:

'jax_tracer_error_num_traceback_frames'

环境变量:

JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES

设置 JAX Tracer 错误消息中的堆栈帧数。

传输守卫#

类型:

枚举值: 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'

默认值:

配置字符串:

'jax_transfer_guard'

环境变量:

JAX_TRANSFER_GUARD

选择所有传输的传输守卫级别。此选项仅用于设置;特定方向的传输守卫级别应使用按传输方向的选项读取。默认值为“allow”。

传输守卫:设备到设备#

类型:

枚举值: 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'

默认值:

配置字符串:

'jax_transfer_guard_device_to_device'

环境变量:

JAX_TRANSFER_GUARD_DEVICE_TO_DEVICE

选择设备到设备传输的传输守卫级别。默认值为“allow”。

传输守卫:设备到主机#

类型:

枚举值: 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'

默认值:

配置字符串:

'jax_transfer_guard_device_to_host'

环境变量:

JAX_TRANSFER_GUARD_DEVICE_TO_HOST

选择设备到主机传输的传输守卫级别。默认值为“allow”。

传输守卫:主机到设备#

类型:

枚举值: 'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'

默认值:

配置字符串:

'jax_transfer_guard_host_to_device'

环境变量:

JAX_TRANSFER_GUARD_HOST_TO_DEVICE

选择主机到设备传输的传输守卫级别。默认值为“allow”。

使用直接线性化#

类型:

布尔值

默认值:

True

配置字符串:

'jax_use_direct_linearize'

环境变量:

JAX_USE_DIRECT_LINEARIZE

使用直接线性化而不是 JVP 后跟部分求值

使用 Magma#

类型:

枚举值: 'off', 'on', 'auto'

默认值:

'auto'

配置字符串:

'jax_use_magma'

环境变量:

JAX_USE_MAGMA

启用 GPU 上对 MAGMA 支持的 lax.linalg.eig 的实验性支持。有关如何使用此功能的更多详细信息,请参阅 lax.linalg.eig 的文档。

使用 Shardy 分区器#

类型:

布尔值

默认值:

True

配置字符串:

'jax_use_shardy_partitioner'

环境变量:

JAX_USE_SHARDY_PARTITIONER

是否降级到 Shardy。有关更多信息,请参阅迁移指南:https://jax.net.cn/en/latest/shardy_jax_migration.html。此功能将在未来版本的 JAX 中默认启用,届时所有使用此标志的行为都将被视为弃用(遵循 API 兼容性策略)。

使用简化 Jaxpr 常量#

类型:

布尔值

默认值:

False

配置字符串:

'jax_use_simplified_jaxpr_constants'

环境变量:

JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS

启用 Jaxpr 中封闭常量的简化处理。值 True 启用新行为。此标志将短暂存在,以便我们过渡用户。请参阅 jax-ml/jax#29679.DO 请勿依赖此标志。

XLA 后端#

类型:

字符串

默认值:

''

配置字符串:

'jax_xla_backend'

环境变量:

JAX_XLA_BACKEND

XLA 配置文件版本#

类型:

整数

默认值:

0

配置字符串:

'jax_xla_profile_version'

环境变量:

JAX_XLA_PROFILE_VERSION

XLA 编译的可选配置文件版本。仅当 XLA 配置为支持远程编译配置文件功能时,此选项才有意义。

模拟 GPU 进程数量#

类型:

整数

默认值:

0

配置字符串:

'mock_num_gpu_processes'

环境变量:

MOCK_NUM_GPU_PROCESSES

模拟 GPU 客户端中的 JAX 进程数量。值为零时关闭模拟。