JAX 调试标志#
JAX 提供了一些标志和上下文管理器,可以更轻松地捕获错误。
jax_debug_nans 配置选项和上下文管理器#
摘要: 启用 jax_debug_nans 标志,可以自动检测 jax.jit 编译的代码中何时生成了 NaN(但在 jax.pmap 或 jax.pjit 编译的代码中则不行)。
jax_debug_nans 是一个 JAX 标志,启用后,会在检测到 NaN 时自动引发错误。它对 JIT 编译的代码有特殊处理——当从 JIT 编译的函数中检测到 NaN 输出时,该函数将被重新以即时执行(即不编译)的方式运行,并会在生成 NaN 的具体原始操作处抛出错误。
用法#
如果您想追踪函数或梯度中 NaN 的发生位置,可以通过以下方式启用 NaN 检查器:
设置
JAX_DEBUG_NANS=True环境变量;在您的主文件中靠近顶部添加
jax.config.update("jax_debug_nans", True);在您的主文件中添加
jax.config.parse_flags_with_absl(),然后使用类似--jax_debug_nans=True的命令行标志来设置该选项;
示例#
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
jax_debug_nans 的优点和局限性#
优点#
易于应用
精确检测 NaN 的生成位置
抛出标准的 Python 异常,并且兼容 PDB 命令行调试
局限性#
与
jax.pmap或jax.pjit不兼容重新以即时执行方式运行函数可能速度较慢
对误报(例如,有意生成的 NaN)也会报错
jax_disable_jit 配置选项和上下文管理器#
摘要: 启用 jax_disable_jit 标志以禁用 JIT 编译,从而可以使用传统的 Python 调试工具,如 print 和 pdb。
jax_disable_jit 是一个 JAX 标志,启用后,它会在整个 JAX 中禁用 JIT 编译(包括在 jax.lax.cond 和 jax.lax.scan 等控制流函数中)。
用法#
您可以通过以下方式禁用 JIT 编译:
设置
JAX_DISABLE_JIT=True环境变量;在您的主文件中靠近顶部添加
jax.config.update("jax_disable_jit", True);在您的主文件中添加
jax.config.parse_flags_with_absl(),然后使用类似--jax_disable_jit=True的命令行标志来设置该选项;
示例#
import jax
jax.config.update("jax_disable_jit", True)
def f(x):
y = jnp.log(x)
if jnp.isnan(y):
breakpoint()
return y
jax.jit(f)(-2.) # ==> Enters PDB breakpoint!
jax_disable_jit 的优点和局限性#
优点#
易于应用
启用 Python 内置的
breakpoint和print功能抛出标准的 Python 异常,并且兼容 PDB 命令行调试
局限性#
与
jax.pmap或jax.pjit不兼容不进行 JIT 编译运行函数可能会很慢