配置选项#
JAX 提供各种配置选项来自定义其行为。这些选项控制从数值精度到调试功能的方方面面。
如何使用配置选项#
JAX 配置选项可以通过以下几种方式进行设置:
环境变量(在运行程序之前设置)
export JAX_ENABLE_X64=True python my_program.py
运行时配置(在您的 Python 代码中)
import jax jax.config.update("jax_enable_x64", True)
命令行标志(使用 Abseil)
# In your code: import jax jax.config.parse_flags_with_absl()
# When running: python my_program.py --jax_enable_x64=True
常用配置选项#
以下是一些最常用的配置选项:
jax_enable_x64– 启用 64 位浮点精度jax_disable_jit– 禁用 JIT 编译进行调试jax_debug_nans– 检查 NaN 并就此抛出错误jax_platforms– 控制 JAX 将初始化哪些后端(CPU/GPU/TPU)jax_numpy_rank_promotion– 控制自动秩提升行为jax_default_matmul_precision– 设置矩阵乘法运算的默认精度
所有配置选项#
以下是所有可用 JAX 配置选项的完整列表:
Check Vma#
- 类型:
bool- 默认值:
False- 配置字符串:
'check_vma'- 环境变量:
CHECK_VMA
shard_map 的内部实现细节,请勿使用
Eager Constant Folding#
- 类型:
bool- 默认值:
False- 配置字符串:
'eager_constant_folding'- 环境变量:
EAGER_CONSTANT_FOLDING
在暂存期间尝试常量折叠。
Jax2Tf Associative Scan Reductions#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax2tf_associative_scan_reductions'- 环境变量:
JAX2TF_ASSOCIATIVE_SCAN_REDUCTIONS
JAX 具有两个独立的累积归约原语(cumsum、cumprod、cummax、cummin)的降低规则。在 CPU 和 GPU 上,它使用 lax.associative_scan;而在 TPU 上,它使用 HLO ReduceWindow。后者在 CPU 和 GPU 上的实现速度较慢。默认情况下,jax2tf 使用 TPU 降低规则。将此标志设置为 True 以使用 associative scan 降低用法,并且仅当它对您的应用程序产生影响时才这样做。有关更多详细信息,请参阅 jax2tf README.md。
Jax2Tf Default Native Serialization#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax2tf_default_native_serialization'- 环境变量:
JAX2TF_DEFAULT_NATIVE_SERIALIZATION
将 native_serialization 参数的默认值设置为 jax2tf.convert。优先使用该参数而不是此标志,因为该标志可能会在未来版本中被移除。从 JAX 0.4.31 开始,非原生序列化已被弃用。
Array Garbage Collection Guard#
- 类型:
枚举值:
'allow','log','fatal'- 默认值:
无- 配置字符串:
'jax_array_garbage_collection_guard'- 环境变量:
JAX_ARRAY_GARBAGE_COLLECTION_GUARD
为 jax.Array 对象选择垃圾回收保护级别。
此选项可用于控制当 jax.Array 对象被垃圾回收时会发生什么。理想情况下,jax.Array 对象应通过 Python 引用计数而不是垃圾回收来释放,以避免设备内存被数组持有直到垃圾回收发生。
有效值包括:
allow:不记录jax.Array对象的垃圾回收。log:当jax.Array被垃圾回收时记录错误。fatal:如果jax.Array被垃圾回收,则会发生致命错误。
默认值为 allow。请注意,并非所有循环都可能被检测到。
Backend Target#
- 类型:
str- 默认值:
''- 配置字符串:
'jax_backend_target'- 环境变量:
JAX_BACKEND_TARGET
Bcoo Cusparse Lowering#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_bcoo_cusparse_lowering'- 环境变量:
JAX_BCOO_CUSPARSE_LOWERING
启用将 BCOO 操作降低到 cuSparse。
Captured Constants Report Frames#
- 类型:
int- 默认值:
0- 配置字符串:
'jax_captured_constants_report_frames'- 环境变量:
JAX_CAPTURED_CONSTANTS_REPORT_FRAMES
为每个捕获的常量报告的堆栈帧数,指示常量被捕获的文件和操作。设置为 -1 以打印完整的帧集,或设置为 0 以禁用。注意:仅当捕获的常量总数超过 jax_captured_constants_warn_bytes 时才会生成报告,因为生成报告的成本很高。
Captured Constants Warn Bytes#
- 类型:
int- 默认值:
2000000000- 配置字符串:
'jax_captured_constants_warn_bytes'- 环境变量:
JAX_CAPTURED_CONSTANTS_WARN_BYTES
在发出警告之前,可能被捕获为常量的参数字节数。默认值约为 2GB。设置为 -1 以禁用发出警告。
Check Proxy Envs#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_check_proxy_envs'- 环境变量:
JAX_CHECK_PROXY_ENVS
检查用户环境变量中的代理变量并发出警告。
Check Tracer Leaks#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_check_tracer_leaks'- 环境变量:
JAX_CHECK_TRACER_LEAKS
在跟踪完成时立即开始检查跟踪器泄漏。启用泄漏检查可能会影响性能:一些缓存将被禁用,并可能增加其他开销。此外,请注意,某些 Python 调试器可能导致误报,因此建议在启用泄漏检查时禁用任何调试器。
Collectives Common Channel Id#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_collectives_common_channel_id'- 环境变量:
JAX_COLLECTIVES_COMMON_CHANNEL_ID
集体操作是否应使用公共通道 ID?临时功能标志。
Compilation Cache Dir#
- 类型:
str- 默认值:
无- 配置字符串:
'jax_compilation_cache_dir'- 环境变量:
JAX_COMPILATION_CACHE_DIR
缓存的路径。优先级:1. 调用 compilation_cache.set_cache_dir()。2. 在命令行或默认设置的此标志的值。
Compilation Cache Expect Pgle#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_compilation_cache_expect_pgle'- 环境变量:
JAX_COMPILATION_CACHE_EXPECT_PGLE
如果设置为 True,将优先加载使用配置文件(即启用了 PGLE 并执行了所需次数的分析)编译的编译缓存条目,即使 PGLE 当前未启用。如果找不到首选缓存条目,将打印警告。
Compilation Cache Include Metadata In Key#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_compilation_cache_include_metadata_in_key'- 环境变量:
JAX_COMPILATION_CACHE_INCLUDE_METADATA_IN_KEY
在编译缓存键中包含元数据,例如文件名和行号。如果设置为 false,即使函数或文件已移动等,缓存仍然会命中。但是,这意味着从缓存加载的可执行文件可能包含过时的元数据,这可能会出现在例如配置文件中。
Compilation Cache Max Size#
- 类型:
int- 默认值:
-1- 配置字符串:
'jax_compilation_cache_max_size'- 环境变量:
JAX_COMPILATION_CACHE_MAX_SIZE
持久编译缓存允许的最大大小(以字节为单位)。设置后,一旦缓存目录的总大小超过指定限制,将删除最近最少访问的缓存条目。如果此值设置为 0,则禁用缓存。特殊值 -1 表示无限制,允许缓存大小无限增长。
Compiler Detailed Logging Min Ops#
- 类型:
int- 默认值:
10- 配置字符串:
'jax_compiler_detailed_logging_min_ops'- 环境变量:
JAX_COMPILER_DETAILED_LOGGING_MIN_OPS
在 JAX 启用详细编译器日志记录之前,MLIR 操作应该有多大?此标志的目的是抑制对小型/不感兴趣的计算的详细日志记录。
Compiler Enable Remat Pass#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_compiler_enable_remat_pass'- 环境变量:
JAX_COMPILER_ENABLE_REMAT_PASS
配置启用/禁用重构(rematerialization)HLO 传递。这有助于允许 XLA 在遇到 OOM 错误时自动权衡内存和计算。但是,您可能通过 jax.checkpoint 手动获得更好的结果。
Cpu Collectives Implementation#
- 类型:
枚举值:
'gloo','mpi','megascale'- 默认值:
'gloo'- 配置字符串:
'jax_cpu_collectives_implementation'- 环境变量:
JAX_CPU_COLLECTIVES_IMPLEMENTATION
CPU 上使用的跨进程集体实现。必须是 (“gloo”、“mpi”) 之一。
Cpu Enable Async Dispatch#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_cpu_enable_async_dispatch'- 环境变量:
JAX_CPU_ENABLE_ASYNC_DISPATCH
仅适用于非并行计算。如果为 False,则在没有异步分派的情况下内联运行计算。
Cpu Get Global Topology Timeout Minutes#
- 类型:
int- 默认值:
5- 配置字符串:
'jax_cpu_get_global_topology_timeout_minutes'- 环境变量:
JAX_CPU_GET_GLOBAL_TOPOLOGY_TIMEOUT_MINUTES
获取 CPU 设备全局拓扑的超时时间(分钟);应严格大于 –jax_cpu_get_local_topology_timeout_minutes。
Cpu Get Local Topology Timeout Minutes#
- 类型:
int- 默认值:
2- 配置字符串:
'jax_cpu_get_local_topology_timeout_minutes'- 环境变量:
JAX_CPU_GET_LOCAL_TOPOLOGY_TIMEOUT_MINUTES
构建全局拓扑时获取每个 CPU 设备局部拓扑的超时时间(分钟)。
Cross Host Transfer Socket Address#
- 类型:
str- 默认值:
''- 配置字符串:
'jax_cross_host_transfer_socket_address'- 环境变量:
JAX_CROSS_HOST_TRANSFER_SOCKET_ADDRESS
通过 DCN 进行跨主机设备传输使用的套接字地址。仅当 PjRt 插件不支持跨主机传输时才需要。
Cross Host Transfer Timeout Seconds#
- 类型:
int- 默认值:
无- 配置字符串:
'jax_cross_host_transfer_timeout_seconds'- 环境变量:
JAX_CROSS_HOST_TRANSFER_TIMEOUT_SECONDS
Cross Host Transfer Transfer Size#
- 类型:
int- 默认值:
无- 配置字符串:
'jax_cross_host_transfer_transfer_size'- 环境变量:
JAX_CROSS_HOST_TRANSFER_TRANSFER_SIZE
Cross Host Transport Addresses#
- 类型:
str- 默认值:
''- 配置字符串:
'jax_cross_host_transport_addresses'- 环境变量:
JAX_CROSS_HOST_TRANSPORT_ADDRESSES
用于通过 DCN 进行跨主机设备传输的传输地址的逗号分隔列表。如果未设置,则默认为 [0.0.0.0:0] * 4。
Cuda Visible Devices#
- 类型:
str- 默认值:
'all'- 配置字符串:
'jax_cuda_visible_devices'- 环境变量:
JAX_CUDA_VISIBLE_DEVICES
Custom Vjp Disable Shape Check#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_custom_vjp_disable_shape_check'- 环境变量:
JAX_CUSTOM_VJP_DISABLE_SHAPE_CHECK
禁用 #19009 的检查以启用一些 custom_vjp 技巧。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Debug Infs#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_debug_infs'- 环境变量:
JAX_DEBUG_INFS
为每个操作添加 inf 检查。当检测到 jit 编译计算的输出存在 inf 时,将调用未编译的版本,以尝试更精确地识别产生 inf 的操作。
Debug Key Reuse#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_debug_key_reuse'- 环境变量:
JAX_DEBUG_KEY_REUSE
启用实验性的密钥重用检查。启用此配置后,将跟踪类型化 PRNG 密钥(即使用 jax.random.key() 创建的密钥)的使用情况,并且重复使用已使用的密钥将导致错误。目前启用此功能会在每次调用带有密钥作为输入或输出的 JIT 编译函数时产生少量 Python 开销。
Debug Log Modules#
- 类型:
str- 默认值:
''- 配置字符串:
'jax_debug_log_modules'- 环境变量:
JAX_DEBUG_LOG_MODULES
要启用调试日志记录的模块名称的逗号分隔列表(例如,“jax”或“jax._src.xla_bridge,jax._src.dispatch”)。
Debug Nans#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_debug_nans'- 环境变量:
JAX_DEBUG_NANS
为每个操作添加 nan 检查。当检测到 jit 编译计算的输出存在 nan 时,将调用未编译的版本,以尝试更精确地识别产生 nan 的操作。
Default Device#
- 类型:
str- 默认值:
无- 配置字符串:
'jax_default_device'- 环境变量:
JAX_DEFAULT_DEVICE
配置 JAX 操作的默认设备。设置为 Device 对象(例如 jax.devices("cpu")[0])以使用该设备作为 JAX 操作和 jit 编译函数调用的默认设备(对多设备计算,例如 pmapped 函数调用没有影响)。设置为 None 以使用系统默认设备。有关设备放置的更多信息,请参阅 控制数据和计算在设备上的放置。
Default Matmul Precision#
- 类型:
枚举值:
'default','high','highest','bfloat16','tensorfloat32','float32','ANY_F8_ANY_F8_F32','ANY_F8_ANY_F8_F32_FAST_ACCUM','ANY_F8_ANY_F8_ANY','ANY_F8_ANY_F8_ANY_FAST_ACCUM','F16_F16_F16','F16_F16_F32','BF16_BF16_BF16','BF16_BF16_F32','BF16_BF16_F32_X3','BF16_BF16_F32_X6','BF16_BF16_F32_X9','TF32_TF32_F32','TF32_TF32_F32_X3','F32_F32_F32','F64_F64_F64'- 默认值:
无- 配置字符串:
'jax_default_matmul_precision'- 环境变量:
JAX_DEFAULT_MATMUL_PRECISION
控制 32 位输入的默认矩阵乘法和卷积精度。
某些平台(如 TPU)为矩阵乘法和卷积计算提供可配置的精度级别,以在准确性和速度之间进行权衡。可以为每个操作控制精度;例如,请参阅 jax.lax.conv_general_dilated() 和 jax.lax.dot() 的文档字符串。但是,控制未指定特定精度的操作时获得的默认行为很有用。
此选项可用于控制 32 位输入上矩阵乘法和卷积计算的默认精度级别。这些级别大致描述了标量乘积的计算精度。“bfloat16”选项是最快也是最不精确的;“float32”类似于完整的 float32 精度;“tensorfloat32”介于两者之间。
此参数还可用于为执行矩阵乘法的函数(如 jax.lax.dot())指定累积“算法”。要指定算法,请将此选项设置为 DotAlgorithmPreset 的名称。
Default Prng Impl#
- 类型:
枚举值:
'threefry2x32','rbg','unsafe_rbg'- 默认值:
'threefry2x32'- 配置字符串:
'jax_default_prng_impl'- 环境变量:
JAX_DEFAULT_PRNG_IMPL
选择默认的 PRNG 实现,当在播种时未显式提供时使用。
Disable Jit#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_disable_jit'- 环境变量:
JAX_DISABLE_JIT
禁用 JIT 编译,直接调用原始 Python。
Disable Most Optimizations#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_disable_most_optimizations'- 环境变量:
JAX_DISABLE_MOST_OPTIMIZATIONS
Disable Vmap Shmap Error#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_disable_vmap_shmap_error'- 环境变量:
JAX_DISABLE_VMAP_SHMAP_ERROR
临时解决方法,用于禁用 vmap-of-shmap 中的错误检查。
Disallow Mesh Context Manager#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_disallow_mesh_context_manager'- 环境变量:
JAX_DISALLOW_MESH_CONTEXT_MANAGER
如果设置为 True,尝试将 mesh 用作上下文管理器将导致 RuntimeError。
Distributed Debug#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_distributed_debug'- 环境变量:
JAX_DISTRIBUTED_DEBUG
启用有助于调试多进程分布式计算的日志记录。日志记录使用 logging 在 WARNING 级别执行。
Dump Ir Modes#
- 类型:
str- 默认值:
'stablehlo'- 配置字符串:
'jax_dump_ir_modes'- 环境变量:
JAX_DUMP_IR_MODES
用于转储 IR 的逗号分隔模式。可以是“stablehlo”(默认)、“jaxpr”或“eqn_count_pprof”(用于 jaxpr 方程计数 pprof 配置文件)。
Dump Ir To#
- 类型:
str- 默认值:
''- 配置字符串:
'jax_dump_ir_to'- 环境变量:
JAX_DUMP_IR_TO
JAX 发出的 IR 将被转储为文本文件的路径。如果省略,JAX 将不会转储任何 IR。支持特殊值“sponge”以从环境变量 TEST_UNDECLARED_OUTPUTS_DIR 中选择路径。有关转储内容的选项,请参阅 jax_dump_ir_modes。
Dynamic Shapes#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_dynamic_shapes'- 环境变量:
JAX_DYNAMIC_SHAPES
启用具有动态形状的计算暂存的实验性功能。
Enable Checks#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_enable_checks'- 环境变量:
JAX_ENABLE_CHECKS
启用 JAX 内部的不变性检查。会降低速度。
Enable Compilation Cache#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_enable_compilation_cache'- 环境变量:
JAX_ENABLE_COMPILATION_CACHE
如果设置为 False,则无论是否调用 set_cache_dir(),编译缓存都将被禁用。如果设置为 True,则路径可以设置为默认值或通过调用 set_cache_dir()。
Enable Custom Prng#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_enable_custom_prng'- 环境变量:
JAX_ENABLE_CUSTOM_PRNG
启用内部升级,允许用户定义自定义伪随机数生成器实现。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Enable Custom Vjp By Custom Transpose#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_enable_custom_vjp_by_custom_transpose'- 环境变量:
JAX_ENABLE_CUSTOM_VJP_BY_CUSTOM_TRANSPOSE
启用内部升级,通过降低到 jax.custom_jvp 和 jax.custom_transpose 来实现 jax.custom_vjp。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Enable Pgle#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_enable_pgle'- 环境变量:
JAX_ENABLE_PGLE
如果设置为 True 并且属性 jax_pgle_profiling_runs 设置大于 0,则模块将在运行指定次数后重新编译,并提供收集到的数据给配置文件引导的延迟估计器。
Enable Recoverability#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_enable_recoverability'- 环境变量:
JAX_ENABLE_RECOVERABILITY
允许多控制器 JAX 作业继续运行,即使某些任务失败。
Enable X64#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_enable_x64'- 环境变量:
JAX_ENABLE_X64
启用 64 位类型的使用
Error Checking Behavior Divide#
- 类型:
枚举值:
'ignore','raise'- 默认值:
'ignore'- 配置字符串:
'jax_error_checking_behavior_divide'- 环境变量:
JAX_ERROR_CHECKING_BEHAVIOR_DIVIDE
指定遇到除零错误时的行为。选项为“ignore”或“raise”。
Error Checking Behavior Nan#
- 类型:
枚举值:
'ignore','raise'- 默认值:
'ignore'- 配置字符串:
'jax_error_checking_behavior_nan'- 环境变量:
JAX_ERROR_CHECKING_BEHAVIOR_NAN
指定遇到 NaN 时的行为。选项为“ignore”或“raise”。
Error Checking Behavior Oob#
- 类型:
枚举值:
'ignore','raise'- 默认值:
'ignore'- 配置字符串:
'jax_error_checking_behavior_oob'- 环境变量:
JAX_ERROR_CHECKING_BEHAVIOR_OOB
指定遇到越界访问时的行为。选项为“ignore”或“raise”。
Exec Time Optimization Effort#
- 类型:
float- 默认值:
0.0- 配置字符串:
'jax_exec_time_optimization_effort'- 环境变量:
JAX_EXEC_TIME_OPTIMIZATION_EFFORT
最小化执行时间的精力(越高表示精力越多),有效范围 [-1.0, 1.0]。
Experimental Unsafe Xla Runtime Errors#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_experimental_unsafe_xla_runtime_errors'- 环境变量:
JAX_EXPERIMENTAL_UNSAFE_XLA_RUNTIME_ERRORS
为 jax.experimental.checkify.checks 在 CPU 和 GPU 上启用 XLA 运行时错误。这些错误是异步的,可能会丢失且不太易读。但是,它们会中断计算,并允许您编写可 JIT 编译的检查,而无需 checkify。在 pmap/pjit 下不起作用。
Explain Cache Misses#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_explain_cache_misses'- 环境变量:
JAX_EXPLAIN_CACHE_MISSES
每次主要缓存(例如,跟踪缓存)发生未命中时,记录解释。日志记录使用 logging 执行。当此选项设置时,日志级别为 WARNING;否则级别为 DEBUG。
Explicit X64 Dtypes#
- 类型:
枚举值:
WARN,ERROR,ALLOW- 默认值:
<ExplicitX64Mode.WARN: 1>- 配置字符串:
'jax_explicit_x64_dtypes'- 环境变量:
JAX_EXPLICIT_X64_DTYPES
如果设置为 ALLOW,即使 enable_x64 为 false,也会尊重显式指定的 64 位类型。如果设置为 WARN,将发出警告;如果设置为 ERROR,将引发错误。
Export Calling Convention Version#
- 类型:
int- 默认值:
10- 配置字符串:
'jax_export_calling_convention_version'- 环境变量:
JAX_EXPORT_CALLING_CONVENTION_VERSION
用于导出的调用约定版本号。这必须在您部署环境使用的 tf.XlaCallModule 支持的版本范围内。请参阅 https://jax.net.cn/en/latest/export/shape_poly.html#calling-convention-versions。
Export Ignore Forward Compatibility#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_export_ignore_forward_compatibility'- 环境变量:
JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY
是否忽略向前兼容性降低规则。请参阅 https://jax.net.cn/en/latest/export/export.html#compatibility-guarantees-for-custom-calls。
High Dynamic Range Gumbel#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_high_dynamic_range_gumbel'- 环境变量:
JAX_HIGH_DYNAMIC_RANGE_GUMBEL
如果为 True,则 gumble 噪声抽取两个样本以更精确地覆盖低概率事件。
Hlo Source File Canonicalization Regex#
- 类型:
str- 默认值:
无- 配置字符串:
'jax_hlo_source_file_canonicalization_regex'- 环境变量:
JAX_HLO_SOURCE_FILE_CANONICALIZATION_REGEX
用于通过删除给定正则表达式来规范化 HLO 指令的 source_path 元数据。如果设置,将对每个 source_file 调用 re.sub() 匹配给定正则表达式,并删除所有匹配项。这可以用来避免在使用包含 HLO 元数据的持久编译缓存时出现不必要的缓存未命中。
Include Debug Info In Dumps#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_include_debug_info_in_dumps'- 环境变量:
JAX_INCLUDE_DEBUG_INFO_IN_DUMPS
确定在转储 IR 代码时是否保留调试符号和位置信息。默认情况下,调试信息将保留在 IR 转储中。为避免泄露源代码和敏感信息,请设置为 false。
Include Full Tracebacks In Locations#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_include_full_tracebacks_in_locations'- 环境变量:
JAX_INCLUDE_FULL_TRACEBACKS_IN_LOCATIONS
在 JAX 发出的 IR 的 MLIR 位置中包含 Python 堆栈跟踪。
Legacy Prng Key#
- 类型:
枚举值:
ALLOW,WARN,ERROR- 默认值:
<LegacyPrngKeyState.ALLOW: 'allow'>- 配置字符串:
'jax_legacy_prng_key'- 环境变量:
JAX_LEGACY_PRNG_KEY
指定将原始 PRNG 密钥传递给 jax.random API 时的行为。
Log Checkpoint Residuals#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_log_checkpoint_residuals'- 环境变量:
JAX_LOG_CHECKPOINT_RESIDUALS
每次 jax.checkpoint(也称为 jax.remat)被部分求值时(例如,用于自动微分)记录一条消息,打印保存了哪些残差。
Log Compiles#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_log_compiles'- 环境变量:
JAX_LOG_COMPILES
每次 jit 或 pmap 编译 XLA 计算时记录一条消息。日志记录使用 logging 执行。当此选项设置时,日志级别为 WARNING;否则级别为 DEBUG。
Logging Level#
- 类型:
枚举值:
'NOTSET','DEBUG','INFO','WARNING','ERROR','CRITICAL'- 默认值:
'NOTSET'- 配置字符串:
'jax_logging_level'- 环境变量:
JAX_LOGGING_LEVEL
在所有 jax 日志记录器上设置相应的日志级别。仅接受 [“NOTSET”, “DEBUG”, “INFO”, “WARNING”, “ERROR”, “CRITICAL”] 中的字符串值。如果为 None,则不设置日志级别。包括 C++ 日志记录。
Memory Fitting Effort#
- 类型:
float- 默认值:
0.0- 配置字符串:
'jax_memory_fitting_effort'- 环境变量:
JAX_MEMORY_FITTING_EFFORT
最小化内存使用的精力(越高表示精力越多),有效范围 [-1.0, 1.0]。
Memory Fitting Level#
- 类型:
枚举值:
'UNKNOWN','O0','O1','O2','O3'- 默认值:
'O2'- 配置字符串:
'jax_memory_fitting_level'- 环境变量:
JAX_MEMORY_FITTING_LEVEL
编译器应在多大程度上尝试使程序适应内存。
Mock Gpu Topology#
- 类型:
str- 默认值:
''- 配置字符串:
'jax_mock_gpu_topology'- 环境变量:
JAX_MOCK_GPU_TOPOLOGY
在 GPU 客户端中模拟多主机 GPU 拓扑。该值应为“<切片数> x <每切片主机数> x <每主机设备数>”的格式。空字符串关闭模拟。
Mosaic Allow Hlo#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_mosaic_allow_hlo'- 环境变量:
JAX_MOSAIC_ALLOW_HLO
允许 Mosaic 中的 hlo 方言。
Mutable Array Checks#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_mutable_array_checks'- 环境变量:
JAX_MUTABLE_ARRAY_CHECKS
启用对可变数组的错误检查,以排除别名。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
No Execution#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_no_execution'- 环境变量:
JAX_NO_EXECUTION
禁止 JAX 执行。
No Tracing#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_no_tracing'- 环境变量:
JAX_NO_TRACING
禁止 JIT 编译的跟踪。
NumCpu Devices#
- 类型:
int- 默认值:
-1- 配置字符串:
'jax_num_cpu_devices'- 环境变量:
JAX_NUM_CPU_DEVICES
要使用的 CPU 设备数量。如果未提供,则使用 XLA 标志 –xla_force_host_platform_device_count 的值。必须在 JAX 初始化之前设置。
Numpy Dtype Promotion#
- 类型:
枚举值:
STANDARD,STRICT- 默认值:
<NumpyDtypePromotion.STANDARD: 'standard'>- 配置字符串:
'jax_numpy_dtype_promotion'- 环境变量:
JAX_NUMPY_DTYPE_PROMOTION
指定用于数组之间操作的隐式类型提升规则。选项为“standard”或“strict”;在 strict 模式下,不同强指定 dtype 的数组之间的二进制操作将导致错误。
Numpy Rank Promotion#
- 类型:
枚举值:
'allow','warn','raise'- 默认值:
'allow'- 配置字符串:
'jax_numpy_rank_promotion'- 环境变量:
JAX_NUMPY_RANK_PROMOTION
控制 NumPy 风格的自动秩提升广播(“allow”、“warn”或“raise”)。
Optimization Level#
- 类型:
枚举值:
'UNKNOWN','O0','O1','O2','O3'- 默认值:
'UNKNOWN'- 配置字符串:
'jax_optimization_level'- 环境变量:
JAX_OPTIMIZATION_LEVEL
编译器应为执行时间进行优化的程度。
Pallas Dump Promela To#
- 类型:
str- 默认值:
''- 配置字符串:
'jax_pallas_dump_promela_to'- 环境变量:
JAX_PALLAS_DUMP_PROMELA_TO
如果设置,会将内核的 Promela 模型转储到指定目录。该模型可以验证内核没有数据竞争、死锁等。
Pallas Enable Debug Checks#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_pallas_enable_debug_checks'- 环境变量:
JAX_PALLAS_ENABLE_DEBUG_CHECKS
如果设置,则在运行时检查 pl.debug_check 调用。否则,它们是无操作。
Pallas Use Mosaic Gpu#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_pallas_use_mosaic_gpu'- 环境变量:
JAX_PALLAS_USE_MOSAIC_GPU
如果为 True,则将 Pallas 内核降低到实验性的 Mosaic GPU 方言,而不是 Triton IR。
Pallas Verbose Errors#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_pallas_verbose_errors'- 环境变量:
JAX_PALLAS_VERBOSE_ERRORS
如果为 True,则为 Pallas 内核打印详细的错误消息。
Persistent Cache Enable Xla Caches#
- 类型:
str- 默认值:
'xla_gpu_per_fusion_autotune_cache_dir'- 配置字符串:
'jax_persistent_cache_enable_xla_caches'- 环境变量:
JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES
当持久缓存启用时,还将自动启用额外的 XLA 缓存。此选项可用于配置将启用哪些 XLA 缓存方法。
Persistent Cache Min Compile Time Secs#
- 类型:
float- 默认值:
1.0- 配置字符串:
'jax_persistent_cache_min_compile_time_secs'- 环境变量:
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS
写入持久编译缓存的计算的最小编译时间。此阈值可以提高,以减少写入缓存的条目数量。
Persistent Cache Min Entry Size Bytes#
- 类型:
int- 默认值:
0- 配置字符串:
'jax_persistent_cache_min_entry_size_bytes'- 环境变量:
JAX_PERSISTENT_CACHE_MIN_ENTRY_SIZE_BYTES
将在持久编译缓存中缓存的条目的最小大小(以字节为单位):* -1:禁用大小限制并防止覆盖。* 保留默认值(0)以允许覆盖。覆盖通常会确保最小大小对于用于缓存的文件系统是最佳的。* > 0:实际所需的最小大小;无覆盖。
Pgle Aggregation Percentile#
- 类型:
int- 默认值:
90- 配置字符串:
'jax_pgle_aggregation_percentile'- 环境变量:
JAX_PGLE_AGGREGATION_PERCENTILE
使用 PGLE 时,用于聚合设备之间性能数据的百分位数。
Pgle Profiling Runs#
- 类型:
int- 默认值:
3- 配置字符串:
'jax_pgle_profiling_runs'- 环境变量:
JAX_PGLE_PROFILING_RUNS
使用 PGLE 进行模块分析的次数,然后再进行重新编译。
Pjrt Client Create Options#
- 类型:
str- 默认值:
无- 配置字符串:
'jax_pjrt_client_create_options'- 环境变量:
JAX_PJRT_CLIENT_CREATE_OPTIONS
一组键值对,格式为“k1:v1;k2:v2”字符串,作为额外参数提供给设备平台 pjrt 客户端。
Platform Name#
- 类型:
str- 默认值:
''- 配置字符串:
'jax_platform_name'- 环境变量:
JAX_PLATFORM_NAME
Platforms#
- 类型:
str- 默认值:
无- 配置字符串:
'jax_platforms'- 环境变量:
JAX_PLATFORMS
平台名称的逗号分隔列表,指定 JAX 应初始化哪些平台。如果此列表中的任何平台未成功初始化,将引发异常并中止程序。列表中的第一个平台将是默认平台。例如,config.jax_platforms=cpu,tpu 表示将初始化 CPU 和 TPU 后端,并且除非另行指定,否则将使用 CPU 后端。如果 TPU 初始化失败,将引发异常。默认情况下,jax 将尝试初始化所有可用平台,并在可用时默认为 GPU 或 TPU,否则回退到 CPU。
Pmap No Rank Reduction#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_pmap_no_rank_reduction'- 环境变量:
JAX_PMAP_NO_RANK_REDUCTION
[已弃用] 如果为 True,则 pmap 分片与它们所属的数组具有相同的秩。设置为 False 已弃用,未来所有 pmap 调用都将在没有秩缩减的情况下进行。
Pmap Shmap Merge#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_pmap_shmap_merge'- 环境变量:
JAX_PMAP_SHMAP_MERGE
如果为 True,则将合并 pmap 和 shard_map API。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Pprint Use Color#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_pprint_use_color'- 环境变量:
JAX_PPRINT_USE_COLOR
启用带彩色语法高亮的 jaxpr 美化打印。
Ragged Dot Use Ragged Dot Instruction#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_ragged_dot_use_ragged_dot_instruction'- 环境变量:
JAX_RAGGED_DOT_USE_RAGGED_DOT_INSTRUCTION
(仅限 TPU) 如果为 True,则将 chlo.ragged_dot 指令用于 ragged_dot() 降低。否则,依赖于 ragged_dot_general_p 的降低规则中的回滚逻辑。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Raise Persistent Cache Errors#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_raise_persistent_cache_errors'- 环境变量:
JAX_RAISE_PERSISTENT_CACHE_ERRORS
如果为 true,则允许读取或写入持久编译缓存时引发的异常,并在未手动捕获时终止程序执行。如果为 false,则会捕获异常并以警告形式引发,允许程序继续执行。默认为 false,因此缓存错误或间歇性问题不会致命。
Random Seed Offset#
- 类型:
int- 默认值:
0- 配置字符串:
'jax_random_seed_offset'- 环境变量:
JAX_RANDOM_SEED_OFFSET
所有随机种子(例如 jax.random.key() 的参数)的偏移量。
Refs To Pins#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_refs_to_pins'- 环境变量:
JAX_REFS_TO_PINS
在 HLO 中将 refs 降低为 pinned buffers。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Remove Custom Partitioning Ptr From Cache Key#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_remove_custom_partitioning_ptr_from_cache_key'- 环境变量:
JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY
如果设置为 True,则在计算缓存键期间对预编译的 stableHLO 进行哈希处理之前,删除其中存在的自定义分区指针。这是一个可能不安全的标志,只有那些确切知道自己在做什么的用户才应该设置它。
Remove Size One Mesh Axis From Type#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_remove_size_one_mesh_axis_from_type'- 环境变量:
JAX_REMOVE_SIZE_ONE_MESH_AXIS_FROM_TYPE
从 ShapedArray.sharding 中删除大小为 1 的 mesh 轴。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Rocm Visible Devices#
- 类型:
str- 默认值:
'all'- 配置字符串:
'jax_rocm_visible_devices'- 环境变量:
JAX_ROCM_VISIBLE_DEVICES
Safer Randint#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_safer_randint'- 环境变量:
JAX_SAFER_RANDINT
对 8 位和 16 位 dtype 使用更安全的 randint 算法。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Softmax Custom Jvp#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_softmax_custom_jvp'- 环境变量:
JAX_SOFTMAX_CUSTOM_JVP
为 jax.nn.softmax 使用新的 custom_jvp 规则。新规则应提高内存使用量和稳定性。设置为 True 以使用新行为。请参阅 jax-ml/jax#15677。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Threefry Gpu Kernel Lowering#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_threefry_gpu_kernel_lowering'- 环境变量:
JAX_THREEFRY_GPU_KERNEL_LOWERING
在 GPU 上,将 threefry PRNG 操作降低到内核实现。这可以加快编译速度,但可能会增加运行时内存成本。
Threefry Partitionable#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_threefry_partitionable'- 环境变量:
JAX_THREEFRY_PARTITIONABLE
启用内部 threefry PRNG 实现更改,使其在某些情况下可以自动分区。没有此标志,使用标准的 jax.random 伪随机数生成可能会导致不必要的通信和/或冗余的分布式计算。有了这个标志,在某些情况下通信开销就会消失。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Traceback Filtering#
- 类型:
枚举值:
'off','tracebackhide','remove_frames','quiet_remove_frames','auto'- 默认值:
'auto'- 配置字符串:
'jax_traceback_filtering'- 环境变量:
JAX_TRACEBACK_FILTERING
控制 JAX 如何过滤掉堆栈跟踪中的内部帧。有效值包括:- off:禁用堆栈跟踪过滤。- auto:如果运行在足够新的 IPython 下,则使用 tracebackhide,否则使用 remove_frames。- tracebackhide:向隐藏的堆栈帧添加 __tracebackhide__ 注释,一些堆栈跟踪打印器支持此功能。- remove_frames:从堆栈跟踪中删除隐藏的帧,并将未过滤的堆栈跟踪添加为异常的 __cause__。- quiet_remove_frames:从堆栈跟踪中删除隐藏的帧,并在(异常的 __cause__ 中)添加一条简短消息,说明已发生这种情况。
Traceback In Locations Limit#
- 类型:
int- 默认值:
10- 配置字符串:
'jax_traceback_in_locations_limit'- 环境变量:
JAX_TRACEBACK_IN_LOCATIONS_LIMIT
限制包含在 MLIR 位置中的 Python 堆栈跟踪帧的数量。如果设置为负值,则不限制堆栈跟踪。
Tracer Error Num Traceback Frames#
- 类型:
int- 默认值:
5- 配置字符串:
'jax_tracer_error_num_traceback_frames'- 环境变量:
JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES
设置 JAX 跟踪器错误消息中的堆栈帧数。
Transfer Guard#
- 类型:
枚举值:
'allow','log','disallow','log_explicit','disallow_explicit'- 默认值:
无- 配置字符串:
'jax_transfer_guard'- 环境变量:
JAX_TRANSFER_GUARD
为所有传输选择传输保护级别。此选项仅用于设置;特定方向的传输保护级别应使用特定方向的选项读取。默认为“allow”。
Transfer Guard Device To Device#
- 类型:
枚举值:
'allow','log','disallow','log_explicit','disallow_explicit'- 默认值:
无- 配置字符串:
'jax_transfer_guard_device_to_device'- 环境变量:
JAX_TRANSFER_GUARD_DEVICE_TO_DEVICE
为设备到设备传输选择传输保护级别。默认为“allow”。
Transfer Guard Device To Host#
- 类型:
枚举值:
'allow','log','disallow','log_explicit','disallow_explicit'- 默认值:
无- 配置字符串:
'jax_transfer_guard_device_to_host'- 环境变量:
JAX_TRANSFER_GUARD_DEVICE_TO_HOST
为设备到主机传输选择传输保护级别。默认为“allow”。
Transfer Guard Host To Device#
- 类型:
枚举值:
'allow','log','disallow','log_explicit','disallow_explicit'- 默认值:
无- 配置字符串:
'jax_transfer_guard_host_to_device'- 环境变量:
JAX_TRANSFER_GUARD_HOST_TO_DEVICE
为主机到设备传输选择传输保护级别。默认为“allow”。
Use Direct Linearize#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_use_direct_linearize'- 环境变量:
JAX_USE_DIRECT_LINEARIZE
使用直接线性化代替 JVP 后跟部分求值。
Use Magma#
- 类型:
枚举值:
'off','on','auto'- 默认值:
'auto'- 配置字符串:
'jax_use_magma'- 环境变量:
JAX_USE_MAGMA
启用对 GPU 上 lax.linalg.eig 的 MAGMA 后端的实验性支持。有关如何使用此功能的更多详细信息,请参阅 lax.linalg.eig 的文档。
Use Shardy Partitioner#
- 类型:
bool- 默认值:
True- 配置字符串:
'jax_use_shardy_partitioner'- 环境变量:
JAX_USE_SHARDY_PARTITIONER
是否降低到 Shardy。有关更多信息,请参阅迁移指南:https://jax.net.cn/en/latest/shardy_jax_migration.html。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Use Simplified Jaxpr Constants#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_use_simplified_jaxpr_constants'- 环境变量:
JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS
启用对 Jaxpr 中闭包常量的处理进行简化。值为 True 可启用新行为。此标志将仅短暂存在,在我们迁移用户时。请参阅 jax-ml/jax#29679.DO 请勿依赖此标志。
Vjp3#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_vjp3'- 环境变量:
JAX_VJP3
在 jax.vjp 中使用新的反向传播代码。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Vmap Primitive#
- 类型:
bool- 默认值:
False- 配置字符串:
'jax_vmap_primitive'- 环境变量:
JAX_VMAP_PRIMITIVE
将 vmap 设为 hijax 原语。这将在未来的 JAX 版本中默认启用,届时所有使用此标志的情况都将被视为已弃用(遵循 API 兼容性策略)。
Xla Backend#
- 类型:
str- 默认值:
''- 配置字符串:
'jax_xla_backend'- 环境变量:
JAX_XLA_BACKEND
Xla Profile Version#
- 类型:
int- 默认值:
0- 配置字符串:
'jax_xla_profile_version'- 环境变量:
JAX_XLA_PROFILE_VERSION
XLA 编译的可选配置文件版本。仅当 XLA 配置支持远程编译配置文件功能时才有意义。
Mock Num Gpu Processes#
- 类型:
int- 默认值:
0- 配置字符串:
'mock_num_gpu_processes'- 环境变量:
MOCK_NUM_GPU_PROCESSES
模拟 GPU 客户端中的 JAX 进程数。值为零会关闭模拟。