配置选项#
JAX 提供了各种配置选项来自定义其行为。这些选项控制着从数值精度到调试功能的方方面面。
如何使用配置选项#
JAX 配置选项可以通过以下几种方式设置
环境变量(在运行程序之前设置)
export JAX_ENABLE_X64=True python my_program.py
运行时配置(在您的 Python 代码中)
import jax jax.config.update("jax_enable_x64", True)
命令行标志(使用 Abseil)
# In your code: import jax jax.config.parse_flags_with_absl()
# When running: python my_program.py --jax_enable_x64=True
常用配置选项#
以下是一些最常用的配置选项
jax_enable_x64
– 启用 64 位浮点精度jax_disable_jit
– 禁用 JIT 编译以进行调试jax_debug_nans
– 检查 NaN 并在 NaN 上引发错误jax_platforms
– 控制 JAX 将初始化的后端 (CPU/GPU/TPU)jax_numpy_rank_promotion
– 控制自动秩提升行为jax_default_matmul_precision
– 设置矩阵乘法运算的默认精度
所有配置选项#
以下是所有可用的 JAX 配置选项的完整列表
检查 Rep#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'check_rep'
- 环境变量:
CHECK_REP
shard_map 的内部实现细节,请勿使用
Eager 常量折叠#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'eager_constant_folding'
- 环境变量:
EAGER_CONSTANT_FOLDING
在暂存期间尝试常量折叠。
Jax2Tf 结合性扫描规约#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax2tf_associative_scan_reductions'
- 环境变量:
JAX2TF_ASSOCIATIVE_SCAN_REDUCTIONS
JAX 为累积规约原语(cumsum、cumprod、cummax、cummin)提供了两个单独的 lowering 规则。在 CPU 和 GPU 上,它使用 lax.associative_scan,而对于 TPU,它使用 HLO ReduceWindow。后者在 CPU 和 GPU 上的实现速度较慢。默认情况下,jax2tf 使用 TPU lowering。将此标志设置为 True 以使用结合性扫描 lowering 用法,并且仅当它对您的应用程序产生影响时才使用。有关更多详细信息,请参阅 jax2tf README.md。
Jax2Tf 默认原生序列化#
- 类型:
布尔型
- 默认值:
True
- 配置字符串:
'jax2tf_default_native_serialization'
- 环境变量:
JAX2TF_DEFAULT_NATIVE_SERIALIZATION
设置 jax2tf.convert 的 native_serialization 参数的默认值。优先使用参数而不是标志,该标志将来可能会被删除。从 JAX 0.4.31 开始,非原生序列化已被弃用。
Array 垃圾回收保护#
- 类型:
枚举值:
'allow'
,'log'
,'fatal'
- 默认值:
None
- 配置字符串:
'jax_array_garbage_collection_guard'
- 环境变量:
JAX_ARRAY_GARBAGE_COLLECTION_GUARD
为 jax.Array
对象选择垃圾回收保护级别。
此选项可用于控制当 jax.Array
对象被垃圾回收时会发生什么。理想情况下,jax.Array
对象应通过 Python 引用计数而不是垃圾回收来释放,以避免设备内存被数组占用,直到垃圾回收发生。
有效值包括
allow
:不记录jax.Array
对象的垃圾回收。log
:当jax.Array
被垃圾回收时记录错误。fatal
:如果jax.Array
被垃圾回收,则发生致命错误。
默认为 allow
。请注意,并非所有循环都可能被检测到。
后端目标#
- 类型:
字符串
- 默认值:
''
- 配置字符串:
'jax_backend_target'
- 环境变量:
JAX_BACKEND_TARGET
Bcoo Cusparse 降低#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_bcoo_cusparse_lowering'
- 环境变量:
JAX_BCOO_CUSPARSE_LOWERING
启用将 BCOO 操作降低到 cuSparse。
检查代理环境#
- 类型:
布尔型
- 默认值:
True
- 配置字符串:
'jax_check_proxy_envs'
- 环境变量:
JAX_CHECK_PROXY_ENVS
检查用户环境中的代理变量并发出警告。
检查 Tracer 泄漏#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_check_tracer_leaks'
- 环境变量:
JAX_CHECK_TRACER_LEAKS
开启检查 Tracer 泄漏,一旦跟踪完成。启用泄漏检查可能会对性能产生影响:某些缓存被禁用,并且可能会添加其他开销。此外,请注意,某些 Python 调试器可能会导致误报,因此建议在启用泄漏检查时禁用任何调试器。
编译缓存目录#
- 类型:
字符串
- 默认值:
None
- 配置字符串:
'jax_compilation_cache_dir'
- 环境变量:
JAX_COMPILATION_CACHE_DIR
缓存的路径。优先级:1. 调用 compilation_cache.set_cache_dir()。 2. 在命令行或默认情况下设置的此标志的值。
编译缓存期望 Pgle#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_compilation_cache_expect_pgle'
- 环境变量:
JAX_COMPILATION_CACHE_EXPECT_PGLE
如果设置为 True,则即使当前未启用 PGLE,也会优先加载使用性能分析数据编译的编译缓存条目(即,启用了 PGLE 并且对所需数量的执行进行了性能分析)。当找不到首选缓存条目时,将打印警告。
编译缓存包含元数据在键中#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_compilation_cache_include_metadata_in_key'
- 环境变量:
JAX_COMPILATION_CACHE_INCLUDE_METADATA_IN_KEY
在编译缓存键中包含元数据,例如文件名和行号。如果为 false,即使函数或文件被移动等,缓存仍然会命中。但是,这意味着从缓存加载的可执行文件可能具有过时的元数据,这可能会在例如性能分析中显示出来。
编译缓存最大大小#
- 类型:
整数
- 默认值:
-1
- 配置字符串:
'jax_compilation_cache_max_size'
- 环境变量:
JAX_COMPILATION_CACHE_MAX_SIZE
持久编译缓存允许的最大大小(以字节为单位)。设置后,一旦总缓存目录大小超过指定限制,最近最少访问的缓存条目将被删除。如果此值设置为 0,则将禁用缓存。特殊值 -1 表示没有限制,允许缓存大小无限增长。
编译器详细日志记录最小操作数#
- 类型:
整数
- 默认值:
10
- 配置字符串:
'jax_compiler_detailed_logging_min_ops'
- 环境变量:
JAX_COMPILER_DETAILED_LOGGING_MIN_OPS
在 JAX 启用详细编译器日志记录之前,模块在 MLIR 操作中应该有多大?此标志的目的是抑制小型/不重要的计算的详细日志记录。
编译器启用 Remat Pass#
- 类型:
布尔型
- 默认值:
True
- 配置字符串:
'jax_compiler_enable_remat_pass'
- 环境变量:
JAX_COMPILER_ENABLE_REMAT_PASS
配置以启用/禁用 rematerialization HLO pass。当遇到 OOM 错误时,允许 XLA 自动权衡内存和计算非常有用。但是,您可能会使用 jax.checkpoint 手动获得更好的结果
CPU 集体通信实现#
- 类型:
枚举值:
'gloo'
,'mpi'
,'megascale'
- 默认值:
None
- 配置字符串:
'jax_cpu_collectives_implementation'
- 环境变量:
JAX_CPU_COLLECTIVES_IMPLEMENTATION
在 CPU 上使用的跨进程集体通信实现。必须是 (“gloo”, “mpi”) 之一
CPU 启用异步调度#
- 类型:
布尔型
- 默认值:
True
- 配置字符串:
'jax_cpu_enable_async_dispatch'
- 环境变量:
JAX_CPU_ENABLE_ASYNC_DISPATCH
仅适用于非并行计算。如果为 False,则以内联方式运行计算,而无需异步调度。
CPU 启用 Gloo 集体通信#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_cpu_enable_gloo_collectives'
- 环境变量:
JAX_CPU_ENABLE_GLOO_COLLECTIVES
已弃用,请改用 jax_cpu_collectives_implementation。
Cuda 可见设备#
- 类型:
字符串
- 默认值:
'all'
- 配置字符串:
'jax_cuda_visible_devices'
- 环境变量:
JAX_CUDA_VISIBLE_DEVICES
自定义 Vjp 禁用形状检查#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_custom_vjp_disable_shape_check'
- 环境变量:
JAX_CUSTOM_VJP_DISABLE_SHAPE_CHECK
禁用来自 #19009 的检查以启用一些 custom_vjp hacks。这将在未来版本的 JAX 中默认启用,届时该标志的所有用途都将被视为已弃用(遵循 API 兼容性策略)。
调试 Infs#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_debug_infs'
- 环境变量:
JAX_DEBUG_INFS
向每个操作添加 inf 检查。当在 jit 编译的计算的输出上检测到 inf 时,调用未编译的版本,以尝试更精确地识别产生 inf 的操作。
调试密钥重用#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_debug_key_reuse'
- 环境变量:
JAX_DEBUG_KEY_REUSE
开启实验性密钥重用检查。启用此配置后,将跟踪类型化 PRNG 密钥(即使用 jax.random.key() 创建的密钥)的使用情况,并且不正确地重用先前使用的密钥将导致错误。目前,启用此功能会导致每次调用 JIT 编译的函数时,都会产生少量 Python 开销,该函数将密钥作为输入或输出。
调试日志模块#
- 类型:
字符串
- 默认值:
''
- 配置字符串:
'jax_debug_log_modules'
- 环境变量:
JAX_DEBUG_LOG_MODULES
以逗号分隔的模块名称列表(例如“jax”或“jax._src.xla_bridge,jax._src.dispatch”)以启用调试日志记录。
调试 Nans#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_debug_nans'
- 环境变量:
JAX_DEBUG_NANS
向每个操作添加 nan 检查。当在 jit 编译的计算的输出上检测到 nan 时,调用未编译的版本,以尝试更精确地识别产生 nan 的操作。
默认设备#
- 类型:
字符串
- 默认值:
None
- 配置字符串:
'jax_default_device'
- 环境变量:
JAX_DEFAULT_DEVICE
配置 JAX 操作的默认设备。设置为 Device 对象(例如 jax.devices("cpu")[0]
)以将该 Device 用作 JAX 操作和 jit 函数调用的默认设备(对多设备计算(例如,pmapped 函数调用)没有影响)。设置为 None 以使用系统默认设备。有关设备放置的更多信息,请参阅 控制设备上的数据和计算放置。
默认 Matmul 精度#
- 类型:
枚举值:
'default'
,'high'
,'highest'
,'bfloat16'
,'tensorfloat32'
,'float32'
,'ANY_F8_ANY_F8_F32'
,'ANY_F8_ANY_F8_F32_FAST_ACCUM'
,'ANY_F8_ANY_F8_ANY'
,'ANY_F8_ANY_F8_ANY_FAST_ACCUM'
,'F16_F16_F16'
,'F16_F16_F32'
,'BF16_BF16_BF16'
,'BF16_BF16_F32'
,'BF16_BF16_F32_X3'
,'BF16_BF16_F32_X6'
,'BF16_BF16_F32_X9'
,'TF32_TF32_F32'
,'TF32_TF32_F32_X3'
,'F32_F32_F32'
,'F64_F64_F64'
- 默认值:
None
- 配置字符串:
'jax_default_matmul_precision'
- 环境变量:
JAX_DEFAULT_MATMUL_PRECISION
控制 32 位输入的默认 matmul 和 conv 精度。
某些平台(如 TPU)为矩阵乘法和卷积计算提供可配置的精度级别,以速度换取精度。可以为每个操作控制精度;例如,请参阅 jax.lax.conv_general_dilated()
和 jax.lax.dot()
文档字符串。但是,控制在未给出特定精度的情况下获得默认行为可能很有用。
此选项可用于控制 32 位输入上涉及矩阵乘法和卷积的计算的默认精度级别。这些级别大致描述了计算标量积的精度。“bfloat16”选项是最快且精度最低的;“float32”类似于全 float32 精度;“tensorfloat32”是中间值。
此参数还可用于为执行矩阵乘法的函数(如 jax.lax.dot()
)指定累积“算法”。要指定算法,请将此选项设置为 DotAlgorithmPreset
的名称。
默认 Prng 实现#
- 类型:
枚举值:
'threefry2x32'
,'rbg'
,'unsafe_rbg'
- 默认值:
'threefry2x32'
- 配置字符串:
'jax_default_prng_impl'
- 环境变量:
JAX_DEFAULT_PRNG_IMPL
选择默认的 PRNG 实现,当在设定种子时没有显式提供实现时使用。
禁用 Jit#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_disable_jit'
- 环境变量:
JAX_DISABLE_JIT
禁用 JIT 编译,仅调用原始 Python 代码。
禁用大多数优化#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_disable_most_optimizations'
- 环境变量:
JAX_DISABLE_MOST_OPTIMIZATIONS
禁用 Vmap Shmap 错误#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_disable_vmap_shmap_error'
- 环境变量:
JAX_DISABLE_VMAP_SHMAP_ERROR
禁用 vmap-of-shmap 中错误检查的临时性解决方法。
禁止 Mesh 上下文管理器#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_disallow_mesh_context_manager'
- 环境变量:
JAX_DISALLOW_MESH_CONTEXT_MANAGER
如果设置为 True,尝试使用 mesh 作为上下文管理器将导致 RuntimeError。
分布式调试#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_distributed_debug'
- 环境变量:
JAX_DISTRIBUTED_DEBUG
启用对多进程分布式计算进行调试有用的日志记录。日志记录使用 logging 模块,级别为 WARNING。
转储 IR 到#
- 类型:
字符串
- 默认值:
''
- 配置字符串:
'jax_dump_ir_to'
- 环境变量:
JAX_DUMP_IR_TO
JAX 发出的 IR 应该转储为文本文件的路径。如果省略,JAX 将不会转储 IR。支持特殊值 ‘sponge’,以从环境变量 TEST_UNDECLARED_OUTPUTS_DIR 中选择路径。
动态形状#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_dynamic_shapes'
- 环境变量:
JAX_DYNAMIC_SHAPES
启用用于分段输出具有动态形状的计算的实验性功能。
启用检查#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_enable_checks'
- 环境变量:
JAX_ENABLE_CHECKS
开启 JAX 内部的不变性检查。会使速度变慢。
启用编译缓存#
- 类型:
布尔型
- 默认值:
True
- 配置字符串:
'jax_enable_compilation_cache'
- 环境变量:
JAX_ENABLE_COMPILATION_CACHE
如果设置为 False,则无论是否调用了 set_cache_dir(),编译缓存都将被禁用。 如果设置为 True,则路径可以设置为默认值或通过调用 set_cache_dir() 设置。
启用自定义 Prng#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_enable_custom_prng'
- 环境变量:
JAX_ENABLE_CUSTOM_PRNG
启用内部升级,允许用户定义自定义伪随机数生成器实现。 此功能将在未来版本的 JAX 中默认启用,届时所有对此标志的使用都将被视为已弃用(遵循 API 兼容性策略)。
通过自定义转置启用自定义 Vjp#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_enable_custom_vjp_by_custom_transpose'
- 环境变量:
JAX_ENABLE_CUSTOM_VJP_BY_CUSTOM_TRANSPOSE
启用内部升级,通过归约到 jax.custom_jvp 和 jax.custom_transpose 来实现 jax.custom_vjp。 此功能将在未来版本的 JAX 中默认启用,届时所有对此标志的使用都将被视为已弃用(遵循 API 兼容性策略)。
启用空数组#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_enable_empty_arrays'
- 环境变量:
JAX_ENABLE_EMPTY_ARRAYS
允许从单设备数组的空列表创建数组。 这是为了支持 McJAX 中的 MPMD/流水线并行(WIP)。
启用 Pgle#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_enable_pgle'
- 环境变量:
JAX_ENABLE_PGLE
如果设置为 True 且属性 jax_pgle_profiling_runs 设置为大于 0,则模块将在运行指定次数后重新编译,并将收集的数据提供给 profile guided latency estimator。
启用 X64#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_enable_x64'
- 环境变量:
JAX_ENABLE_X64
启用 64 位类型的使用
错误检查行为 - 除法#
- 类型:
枚举值:
'ignore'
,'raise'
- 默认值:
'ignore'
- 配置字符串:
'jax_error_checking_behavior_divide'
- 环境变量:
JAX_ERROR_CHECKING_BEHAVIOR_DIVIDE
指定遇到除零时的行为。选项为 “ignore” 或 “raise”。
错误检查行为 - Nan#
- 类型:
枚举值:
'ignore'
,'raise'
- 默认值:
'ignore'
- 配置字符串:
'jax_error_checking_behavior_nan'
- 环境变量:
JAX_ERROR_CHECKING_BEHAVIOR_NAN
指定遇到 NaN 时的行为。选项为 “ignore” 或 “raise”。
错误检查行为 - Oob#
- 类型:
枚举值:
'ignore'
,'raise'
- 默认值:
'ignore'
- 配置字符串:
'jax_error_checking_behavior_oob'
- 环境变量:
JAX_ERROR_CHECKING_BEHAVIOR_OOB
指定遇到越界访问时的行为。选项为 “ignore” 或 “raise”。
执行时间优化力度#
- 类型:
浮点数
- 默认值:
0.0
- 配置字符串:
'jax_exec_time_optimization_effort'
- 环境变量:
JAX_EXEC_TIME_OPTIMIZATION_EFFORT
最小化执行时间的力度(值越高表示力度越大),有效范围 [-1.0, 1.0]。
实验性不安全 XLA 运行时错误#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_experimental_unsafe_xla_runtime_errors'
- 环境变量:
JAX_EXPERIMENTAL_UNSAFE_XLA_RUNTIME_ERRORS
为 CPU 和 GPU 上的 jax.experimental.checkify.checks 启用 XLA 运行时错误。 这些错误是异步的,可能会丢失,并且可读性不高。 但是,它们会使计算崩溃,并使您能够编写可 JIT 化的检查,而无需 checkify。 在 pmap/pjit 下不起作用。
解释缓存未命中#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_explain_cache_misses'
- 环境变量:
JAX_EXPLAIN_CACHE_MISSES
每次主缓存(例如跟踪缓存)发生未命中时,记录一个解释。日志记录使用 logging 模块。 设置此选项后,日志级别为 WARNING;否则级别为 DEBUG。
导出调用约定版本#
- 类型:
整数
- 默认值:
9
- 配置字符串:
'jax_export_calling_convention_version'
- 环境变量:
JAX_EXPORT_CALLING_CONVENTION_VERSION
用于导出的调用约定版本号。 这必须在您的部署环境中 tf.XlaCallModule 支持的版本范围内。 请参阅 https://jax.net.cn/en/latest/export/shape_poly.html#calling-convention-versions。
导出时忽略向前兼容性#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_export_ignore_forward_compatibility'
- 环境变量:
JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY
是否忽略向前兼容性降低规则。 请参阅 https://jax.net.cn/en/latest/export/export.html#compatibility-guarantees-for-custom-calls。
高动态范围 Gumbel#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_high_dynamic_range_gumbel'
- 环境变量:
JAX_HIGH_DYNAMIC_RANGE_GUMBEL
如果为 True,gumble 噪声会抽取两个样本,以更高的精度覆盖低概率事件。
Hlo 源文件规范化正则表达式#
- 类型:
字符串
- 默认值:
None
- 配置字符串:
'jax_hlo_source_file_canonicalization_regex'
- 环境变量:
JAX_HLO_SOURCE_FILE_CANONICALIZATION_REGEX
用于通过删除给定的正则表达式来规范化 HLO 指令的 source_path 元数据。 如果设置,则使用给定的正则表达式在每个 source_file 上调用 re.sub(),并删除所有匹配项。 这可以用于在使用持久编译缓存时避免虚假的缓存未命中,持久编译缓存将 HLO 元数据包含在缓存键中。
在转储中包含调试信息#
- 类型:
字符串
- 默认值:
'True'
- 配置字符串:
'jax_include_debug_info_in_dumps'
- 环境变量:
JAX_INCLUDE_DEBUG_INFO_IN_DUMPS
确定在转储 IR 代码时是否保留调试符号和位置信息。 默认情况下,调试信息将保留在 IR 转储中。 为了避免暴露源代码和潜在的敏感信息,请设置为 false
在位置中包含完整的回溯信息#
- 类型:
布尔型
- 默认值:
True
- 配置字符串:
'jax_include_full_tracebacks_in_locations'
- 环境变量:
JAX_INCLUDE_FULL_TRACEBACKS_IN_LOCATIONS
在 JAX 发出的 IR 中的 MLIR 位置中包含 Python 回溯信息。
旧版 Prng 密钥#
- 类型:
枚举值:
'allow'
,'warn'
,'error'
- 默认值:
'allow'
- 配置字符串:
'jax_legacy_prng_key'
- 环境变量:
JAX_LEGACY_PRNG_KEY
指定将原始 PRNG 密钥传递给 jax.random API 时的行为。
记录检查点残差#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_log_checkpoint_residuals'
- 环境变量:
JAX_LOG_CHECKPOINT_RESIDUALS
每次 jax.checkpoint (又名 jax.remat)被部分评估(例如,用于自动微分)时,记录一条消息,打印保存了哪些残差。
记录编译#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_log_compiles'
- 环境变量:
JAX_LOG_COMPILES
每次 jit 或 pmap 编译 XLA 计算时,记录一条消息。 日志记录使用 logging 模块。 设置此选项后,日志级别为 WARNING;否则级别为 DEBUG。
日志记录级别#
- 类型:
枚举值:
'NOTSET'
,'DEBUG'
,'INFO'
,'WARNING'
,'ERROR'
,'CRITICAL'
- 默认值:
'NOTSET'
- 配置字符串:
'jax_logging_level'
- 环境变量:
JAX_LOGGING_LEVEL
在所有 jax 记录器上设置相应的日志记录级别。 仅接受来自 [“NOTSET”, “DEBUG”, “INFO”, “WARNING”, “ERROR”, “CRITICAL”] 的字符串值。 如果为 None,则不会设置日志记录级别。 包括 C++ 日志记录。
内存拟合力度#
- 类型:
浮点数
- 默认值:
0.0
- 配置字符串:
'jax_memory_fitting_effort'
- 环境变量:
JAX_MEMORY_FITTING_EFFORT
最小化内存使用量的力度(值越高表示力度越大),有效范围 [-1.0, 1.0]。
内存拟合级别#
- 类型:
枚举值:
'UNKNOWN'
,'O0'
,'O1'
,'O2'
,'O3'
- 默认值:
'UNKNOWN'
- 配置字符串:
'jax_memory_fitting_level'
- 环境变量:
JAX_MEMORY_FITTING_LEVEL
编译器应尝试使程序适应内存的程度
模拟 GPU 拓扑#
- 类型:
字符串
- 默认值:
''
- 配置字符串:
'jax_mock_gpu_topology'
- 环境变量:
JAX_MOCK_GPU_TOPOLOGY
在 GPU 客户端中模拟多主机 GPU 拓扑。 该值应采用 “<切片数> x <每个切片的主机数> x <每个设备的主机数>” 的形式。 空字符串将关闭模拟。
Mosaic 允许 Hlo#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_mosaic_allow_hlo'
- 环境变量:
JAX_MOSAIC_ALLOW_HLO
允许 Mosaic 中使用 hlo 方言
可变数组检查#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_mutable_array_checks'
- 环境变量:
JAX_MUTABLE_ARRAY_CHECKS
启用用于排除别名的可变数组的错误检查。 此功能将在未来版本的 JAX 中默认启用,届时所有对此标志的使用都将被视为已弃用(遵循 API 兼容性策略)。
禁止跟踪#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_no_tracing'
- 环境变量:
JAX_NO_TRACING
禁止对 JIT 编译进行跟踪。
CPU 设备数量#
- 类型:
整数
- 默认值:
-1
- 配置字符串:
'jax_num_cpu_devices'
- 环境变量:
JAX_NUM_CPU_DEVICES
要使用的 CPU 设备数量。 如果未提供,则使用 XLA 标志 –xla_force_host_platform_device_count 的值。 必须在 JAX 初始化之前设置。
Numpy Dtype 提升#
- 类型:
枚举值:
'standard'
,'strict'
- 默认值:
'standard'
- 配置字符串:
'jax_numpy_dtype_promotion'
- 环境变量:
JAX_NUMPY_DTYPE_PROMOTION
指定用于数组之间操作中的隐式类型提升的规则。 选项为 “standard” 或 “strict”;在 strict 模式下,不同强类型指定的 dtype 数组之间的二元运算将导致错误。
Numpy 秩提升#
- 类型:
枚举值:
'allow'
,'warn'
,'raise'
- 默认值:
'allow'
- 配置字符串:
'jax_numpy_rank_promotion'
- 环境变量:
JAX_NUMPY_RANK_PROMOTION
控制 NumPy 风格的自动秩提升广播(“allow”、“warn” 或 “raise”)。
优化级别#
- 类型:
枚举值:
'UNKNOWN'
,'O0'
,'O1'
,'O2'
,'O3'
- 默认值:
'UNKNOWN'
- 配置字符串:
'jax_optimization_level'
- 环境变量:
JAX_OPTIMIZATION_LEVEL
编译器应优化执行时间的程度
Pallas 转储 Promela 到#
- 类型:
字符串
- 默认值:
''
- 配置字符串:
'jax_pallas_dump_promela_to'
- 环境变量:
JAX_PALLAS_DUMP_PROMELA_TO
如果设置,则将内核的 Promela 模型转储到指定的目录。 该模型可以验证内核是否没有数据竞争、死锁等。
Pallas 启用运行时断言#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_pallas_enable_runtime_assert'
- 环境变量:
JAX_PALLAS_ENABLE_RUNTIME_ASSERT
如果设置,则通过 checkify.check 在内核中启用运行时断言。 否则,运行时断言将被忽略,除非使用 checkify.checkify 进行函数化。
Pallas 使用 Mosaic Gpu#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_pallas_use_mosaic_gpu'
- 环境变量:
JAX_PALLAS_USE_MOSAIC_GPU
如果为 True,则将 Pallas 内核降级到实验性的 Mosaic GPU 方言,而不是 Triton IR。
Pallas 详细错误#
- 类型:
布尔型
- 默认值:
True
- 配置字符串:
'jax_pallas_verbose_errors'
- 环境变量:
JAX_PALLAS_VERBOSE_ERRORS
如果为 True,则为 Pallas 内核打印详细的错误消息。
持久缓存启用 XLA 缓存#
- 类型:
字符串
- 默认值:
'xla_gpu_per_fusion_autotune_cache_dir'
- 配置字符串:
'jax_persistent_cache_enable_xla_caches'
- 环境变量:
JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES
启用持久缓存后,还将自动启用额外的 XLA 缓存。 此选项可用于配置将启用哪些 XLA 缓存方法。
持久缓存最小编译时间(秒)#
- 类型:
浮点数
- 默认值:
1.0
- 配置字符串:
'jax_persistent_cache_min_compile_time_secs'
- 环境变量:
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS
要写入持久编译缓存的计算的最小编译时间。 可以提高此阈值以减少写入缓存的条目数。
持久缓存最小条目大小(字节)#
- 类型:
整数
- 默认值:
0
- 配置字符串:
'jax_persistent_cache_min_entry_size_bytes'
- 环境变量:
JAX_PERSISTENT_CACHE_MIN_ENTRY_SIZE_BYTES
将缓存在持久编译缓存中的条目的最小大小(以字节为单位): * -1:禁用大小限制并防止覆盖。 * 保留为默认值 (0) 以允许覆盖。 覆盖通常将确保最小大小对于用于缓存的文件系统是最佳的。 * > 0:所需的实际最小大小;没有覆盖。
Pgle 聚合百分位数#
- 类型:
整数
- 默认值:
90
- 配置字符串:
'jax_pgle_aggregation_percentile'
- 环境变量:
JAX_PGLE_AGGREGATION_PERCENTILE
使用 PGLE 时,用于聚合设备之间性能数据的百分位数。
Pgle 性能分析运行次数#
- 类型:
整数
- 默认值:
3
- 配置字符串:
'jax_pgle_profiling_runs'
- 环境变量:
JAX_PGLE_PROFILING_RUNS
使用 PGLE 时,模块在重新编译之前应进行性能分析的次数。
Pjrt 客户端创建选项#
- 类型:
字符串
- 默认值:
None
- 配置字符串:
'jax_pjrt_client_create_options'
- 环境变量:
JAX_PJRT_CLIENT_CREATE_OPTIONS
以 “k1:v1;k2:v2” 字符串格式的一组键值对,作为额外的参数提供给设备平台 pjrt 客户端。
平台名称#
- 类型:
字符串
- 默认值:
''
- 配置字符串:
'jax_platform_name'
- 环境变量:
JAX_PLATFORM_NAME
平台#
- 类型:
字符串
- 默认值:
None
- 配置字符串:
'jax_platforms'
- 环境变量:
JAX_PLATFORMS
以逗号分隔的平台名称列表,指定 jax 应初始化的平台。 如果此列表中的任何平台未成功初始化,则会引发异常并中止程序。 列表中的第一个平台将是默认平台。 例如,config.jax_platforms=cpu,tpu 表示将初始化 CPU 和 TPU 后端,并且将使用 CPU 后端,除非另有说明。 如果 TPU 初始化失败,则会引发异常。 默认情况下,jax 将尝试初始化所有可用平台,如果 GPU 或 TPU 可用,则默认为 GPU 或 TPU,否则回退到 CPU。
Pmap 无秩缩减#
- 类型:
布尔型
- 默认值:
True
- 配置字符串:
'jax_pmap_no_rank_reduction'
- 环境变量:
JAX_PMAP_NO_RANK_REDUCTION
如果为 True,则 pmap 分片与其封闭数组具有相同的秩。
Pmap Shmap 合并#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_pmap_shmap_merge'
- 环境变量:
JAX_PMAP_SHMAP_MERGE
如果为 True,则 pmap 和 shard_map API 将被合并。 此功能将在未来版本的 JAX 中默认启用,届时所有对此标志的使用都将被视为已弃用(遵循 API 兼容性策略)。
Pprint 使用颜色#
- 类型:
布尔型
- 默认值:
True
- 配置字符串:
'jax_pprint_use_color'
- 环境变量:
JAX_PPRINT_USE_COLOR
启用使用彩色语法高亮显示进行 jaxpr 漂亮打印。
引发持久缓存错误#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_raise_persistent_cache_errors'
- 环境变量:
JAX_RAISE_PERSISTENT_CACHE_ERRORS
如果为 true,则允许在读取或写入持久编译缓存时引发的异常,如果未手动捕获,则会停止程序执行。 如果为 false,则异常将被捕获并作为警告引发,从而允许程序执行继续。 默认为 false,因此缓存错误或间歇性问题是非致命的。
随机种子偏移量#
- 类型:
整数
- 默认值:
0
- 配置字符串:
'jax_random_seed_offset'
- 环境变量:
JAX_RANDOM_SEED_OFFSET
所有随机种子的偏移量(例如 jax.random.key() 的参数)。
从缓存键中删除自定义分区指针#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_remove_custom_partitioning_ptr_from_cache_key'
- 环境变量:
JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY
如果设置为 True,则在缓存键计算期间哈希之前,删除预编译 stableHLO 中存在的自定义分区指针。 这是一个可能不安全的标志,只有确定自己要实现的目标的用户才应设置它。
Rocm 可见设备#
- 类型:
字符串
- 默认值:
'all'
- 配置字符串:
'jax_rocm_visible_devices'
- 环境变量:
JAX_ROCM_VISIBLE_DEVICES
Softmax 自定义 Jvp#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_softmax_custom_jvp'
- 环境变量:
JAX_SOFTMAX_CUSTOM_JVP
为 jax.nn.softmax 使用新的 custom_jvp 规则。 新规则应提高内存使用率和稳定性。 设置为 True 以使用新行为。 请参阅 jax-ml/jax#15677 此功能将在未来版本的 JAX 中默认启用,届时所有对此标志的使用都将被视为已弃用(遵循 API 兼容性策略)。
SPMD 模式#
- 类型:
枚举值:
'allow_all'
,'allow_jit'
- 默认值:
'allow_jit'
- 配置字符串:
'jax_spmd_mode'
- 环境变量:
JAX_SPMD_MODE
决定是否允许对非完全可寻址(即跨多个进程)的 jax.Array
对象执行数学运算。 选项包括
allow_jit
:默认值,pjit
和jax.jit
计算允许在非完全可寻址的jax.Array
对象上执行allow_all
:jnp
、普通数学运算(如a + b
等)、pjit
、jax.jit
和所有其他操作都允许在非完全可寻址的jax.Array
对象上执行。
Threefry Gpu 内核降级#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_threefry_gpu_kernel_lowering'
- 环境变量:
JAX_THREEFRY_GPU_KERNEL_LOWERING
在 GPU 上,将 threefry PRNG 操作降级为内核实现。 这可以加快编译时间,但可能会增加运行时内存成本。
Threefry 可分区化#
- 类型:
布尔型
- 默认值:
True
- 配置字符串:
'jax_threefry_partitionable'
- 环境变量:
JAX_THREEFRY_PARTITIONABLE
启用内部 threefry PRNG 实现更改,使其在某些情况下自动可分区。 如果没有此标志,使用标准的 jax.random 伪随机数生成可能会导致额外的通信和/或冗余的分布式计算。 使用此标志,通信开销在某些情况下会消失。 此功能将在未来版本的 JAX 中默认启用,届时所有对此标志的使用都将被视为已弃用(遵循 API 兼容性策略)。
回溯过滤#
- 类型:
枚举值:
'off'
,'tracebackhide'
,'remove_frames'
,'quiet_remove_frames'
,'auto'
- 默认值:
'auto'
- 配置字符串:
'jax_traceback_filtering'
- 环境变量:
JAX_TRACEBACK_FILTERING
控制 JAX 如何从回溯中过滤内部帧。 有效值包括: - off
:禁用回溯过滤。 - auto
:如果在足够新的 IPython 下运行,则使用 tracebackhide
,否则使用 remove_frames
。 - tracebackhide
:将 __tracebackhide__
注释添加到隐藏的堆栈帧,某些回溯打印机支持此功能。 - remove_frames
:从回溯中删除隐藏的帧,并将未过滤的回溯添加为异常的 __cause__
。 - quiet_remove_frames
:从回溯中删除隐藏的帧,并添加一条简短消息(到异常的 __cause__
),描述已发生这种情况。
位置中的回溯限制#
- 类型:
整数
- 默认值:
10
- 配置字符串:
'jax_traceback_in_locations_limit'
- 环境变量:
JAX_TRACEBACK_IN_LOCATIONS_LIMIT
限制 MLIR 位置中包含的 Python 回溯帧的帧数。 如果设置为负值,则回溯将不受限制。
Tracer 错误回溯帧数#
- 类型:
整数
- 默认值:
5
- 配置字符串:
'jax_tracer_error_num_traceback_frames'
- 环境变量:
JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES
设置 JAX tracer 错误消息中的堆栈帧数。
传输保护#
- 类型:
枚举值:
'allow'
,'log'
,'disallow'
,'log_explicit'
,'disallow_explicit'
- 默认值:
None
- 配置字符串:
'jax_transfer_guard'
- 环境变量:
JAX_TRANSFER_GUARD
为所有传输选择传输保护级别。 此选项为只设置;特定方向的传输保护级别应使用每个传输方向的选项读取。 默认为 “allow”。
传输保护 - 设备到设备#
- 类型:
枚举值:
'allow'
,'log'
,'disallow'
,'log_explicit'
,'disallow_explicit'
- 默认值:
None
- 配置字符串:
'jax_transfer_guard_device_to_device'
- 环境变量:
JAX_TRANSFER_GUARD_DEVICE_TO_DEVICE
选择设备到设备传输的传输保护级别。 默认为 “allow”。
传输保护 - 设备到主机#
- 类型:
枚举值:
'allow'
,'log'
,'disallow'
,'log_explicit'
,'disallow_explicit'
- 默认值:
None
- 配置字符串:
'jax_transfer_guard_device_to_host'
- 环境变量:
JAX_TRANSFER_GUARD_DEVICE_TO_HOST
选择设备到主机传输的传输保护级别。 默认为 “allow”。
传输保护 - 主机到设备#
- 类型:
枚举值:
'allow'
,'log'
,'disallow'
,'log_explicit'
,'disallow_explicit'
- 默认值:
None
- 配置字符串:
'jax_transfer_guard_host_to_device'
- 环境变量:
JAX_TRANSFER_GUARD_HOST_TO_DEVICE
选择主机到设备传输的传输保护级别。 默认为 “allow”。
使用直接线性化#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_use_direct_linearize'
- 环境变量:
JAX_USE_DIRECT_LINEARIZE
使用直接线性化而不是 JVP,后跟部分评估
使用 Magma#
- 类型:
枚举值:
'off'
,'on'
,'auto'
- 默认值:
'auto'
- 配置字符串:
'jax_use_magma'
- 环境变量:
JAX_USE_MAGMA
启用对 GPU 上 MAGMA 支持的 lax.linalg.eig 的实验性支持。 有关如何使用此功能的更多详细信息,请参阅 lax.linalg.eig 的文档。
使用 Shardy 分区器#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'jax_use_shardy_partitioner'
- 环境变量:
JAX_USE_SHARDY_PARTITIONER
是否降级到 Shardy。 Shardy 是 MLIR 的一个新的开源传播框架。 目前,Shardy 在 JAX 中是实验性的。 请参阅 www.github.com/openxla/shardy 此功能将在未来版本的 JAX 中默认启用,届时所有对此标志的使用都将被视为已弃用(遵循 API 兼容性策略)。
Xla 后端#
- 类型:
字符串
- 默认值:
''
- 配置字符串:
'jax_xla_backend'
- 环境变量:
JAX_XLA_BACKEND
Xla 性能分析版本#
- 类型:
整数
- 默认值:
0
- 配置字符串:
'jax_xla_profile_version'
- 环境变量:
JAX_XLA_PROFILE_VERSION
XLA 编译的可选性能分析版本。 仅当 XLA 配置为支持远程编译性能分析功能时,此选项才有意义。
模拟 GPU 进程数#
- 类型:
整数
- 默认值:
0
- 配置字符串:
'mock_num_gpu_processes'
- 环境变量:
MOCK_NUM_GPU_PROCESSES
模拟 GPU 客户端中的 JAX 进程数。 值为零将关闭模拟。
Mosaic 使用 Python 流水线#
- 类型:
布尔型
- 默认值:
False
- 配置字符串:
'mosaic_use_python_pipeline'
- 环境变量:
MOSAIC_USE_PYTHON_PIPELINE
当调用 as_tpu_kernel 时(对于 Pallas,这发生在 JAX 降级时),从 Python 运行初始 Mosaic MLIR 传递,而不是稍后在 XLA 中运行。