调试简介#
您是否遇到了梯度爆炸?NaN 是否让您咬牙切齿?只是想查看计算中的中间值?本节将向您介绍一组内置的 JAX 调试方法,您可以在各种 JAX 转换中使用它们。
概述
使用
jax.debug.print()在jax.jit-、jax.pmap- 和pjit- 装饰的函数中将值打印到标准输出,并使用jax.debug.breakpoint()暂停编译函数的执行以检查调用堆栈中的值。jax.experimental.checkify允许您向 JAX 代码添加jit-able 的运行时错误检查(例如,越界索引)。JAX 提供了配置标志和上下文管理器,可以更轻松地捕获错误。例如,启用
jax_debug_nans标志以自动检测jax.jit编译的代码中何时产生了 NaN,并启用jax_disable_jit标志以禁用 JIT 编译。
jax.debug.print 用于简单检查#
这里有一个经验法则
使用
jax.debug.print()来打印jax.jit()、jax.vmap()等的跟踪(动态)数组值。对于静态值(例如 dtype 和数组形状),请使用 Python 的
print()。
回想一下 JIT 编译 中,当使用 jax.jit() 转换函数时,Python 代码会使用抽象的追踪器(tracer)代替您的数组来执行。因此,Python 的 print() 函数只会打印这个追踪器值。
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
print("print(x) ->", x)
y = jnp.sin(x)
print("print(y) ->", y)
return y
result = f(2.)
print(x) -> JitTracer<~float32[]>
print(y) -> JitTracer<~float32[]>
Python 的 print 在追踪时执行,在运行时值存在之前。如果您想打印实际的运行时值,可以使用 jax.debug.print()。
@jax.jit
def f(x):
jax.debug.print("jax.debug.print(x) -> {x}", x=x)
y = jnp.sin(x)
jax.debug.print("jax.debug.print(y) -> {y}", y=y)
return y
result = f(2.)
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
同样,在 jax.vmap() 中,使用 Python 的 print 也只会打印追踪器;要打印被映射的值,请使用 jax.debug.print()。
def f(x):
jax.debug.print("jax.debug.print(x) -> {}", x)
y = jnp.sin(x)
jax.debug.print("jax.debug.print(y) -> {}", y)
return y
xs = jnp.arange(3.)
result = jax.vmap(f)(xs)
jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314
以下是使用 jax.lax.map() 的结果,它是一个顺序映射而不是向量化。
result = jax.lax.map(f, xs)
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.9092974066734314
jax.debug.print(x) -> 2.0
请注意顺序不同,因为 jax.vmap() 和 jax.lax.map() 以不同的方式计算相同的结果。调试时,正是执行顺序的细节可能需要您进行检查。
下面是使用 jax.grad() 的示例,其中 jax.debug.print() 只打印前向传播。在这种情况下,行为与 Python 的 print() 类似,但如果您在调用期间应用 jax.jit(),则会保持一致。
def f(x):
jax.debug.print("jax.debug.print(x) -> {}", x)
return x ** 2
result = jax.grad(f)(1.)
jax.debug.print(x) -> 1.0
有时,当参数不相互依赖时,调用 jax.debug.print() 时,在 JAX 转换进行阶段化输出时,可能会以不同的顺序打印它们。如果您需要原始顺序(例如,先是 x: ...,然后是 y: ...),请添加 ordered=True 参数。
例如
@jax.jit
def f(x, y):
jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
return x + y
f(1, 2)
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2
Array(3, dtype=int32, weak_type=True)
要了解有关 jax.debug.print() 及其“Sharp Bits”的更多信息,请参阅 高级调试。
jax.debug.breakpoint 用于 pdb 式调试#
摘要:使用 jax.debug.breakpoint() 暂停 JAX 程序的执行以检查值。
要在调试期间的特定点暂停编译后的 JAX 程序,您可以使用 jax.debug.breakpoint()。提示信息类似于 Python 的 pdb,并且允许您检查调用堆栈中的值。事实上,jax.debug.breakpoint() 是 jax.debug.callback() 的一个应用,它捕获有关调用堆栈的信息。
要打印调试会话期间所有可用的命令,请使用 help 命令。(完整的调试器命令,“Sharp Bits”,其优点和局限性都在 高级调试 中介绍。)
以下是一个调试会话可能的样子示例
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution

对于依赖于值的断点,您可以使用运行时条件,如 jax.lax.cond()。
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
jax.lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 1.) # ==> No breakpoint
Array(2., dtype=float32, weak_type=True)
f(2., 0.) # ==> Pauses during execution
jax.debug.callback 用于在调试时进行更多控制#
jax.debug.print() 和 jax.debug.breakpoint() 都是使用更灵活的 jax.debug.callback() 实现的,它可以通过 Python 回调提供对主机端逻辑执行的更大控制。它与 jax.jit()、jax.vmap()、jax.grad() 以及其他转换兼容(有关更多信息,请参阅 外部回调 中的“回调的种类”表)。
例如
import logging
def log_value(x):
logging.warning(f'Logged value: {x}')
@jax.jit
def f(x):
jax.debug.callback(log_value, x)
return x
f(1.0);
WARNING:root:Logged value: 1.0
此回调与包括 jax.vmap() 和 jax.grad() 在内的其他转换兼容。
x = jnp.arange(5.0)
jax.vmap(f)(x);
WARNING:root:Logged value: 0.0
WARNING:root:Logged value: 1.0
WARNING:root:Logged value: 2.0
WARNING:root:Logged value: 3.0
WARNING:root:Logged value: 4.0
jax.grad(f)(1.0);
WARNING:root:Logged value: 1.0
这使得 jax.debug.callback() 可用于通用调试。
您可以在 外部回调 中了解有关 jax.debug.callback() 和其他 JAX 回调类型的更多信息。
请参阅 编译后的打印和断点 以了解更多信息。
使用 jax.experimental.checkify 进行函数式错误检查#
摘要: Checkify 允许您向 JAX 代码添加 jit-able 的运行时错误检查(例如,越界索引)。使用 checkify.checkify 转换以及类似 assert 的 checkify.check 函数来为 JAX 代码添加运行时检查。
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
您也可以使用 checkify 自动添加常见的检查。
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
请参阅 checkify 转换指南 以了解更多信息。
使用 JAX 的调试标志抛出 Python 错误#
摘要: 启用 jax_debug_nans 标志以自动检测 jax.jit 编译的代码中何时产生了 NaN(但在 jax.pmap 或 jax.pjit 编译的代码中不会),并启用 jax_disable_jit 标志以禁用 JIT 编译,从而可以使用 print 和 pdb 等传统的 Python 调试工具。
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
请参阅 JAX 调试标志 以了解更多信息。
下一步#
查看 高级调试 以了解更多关于 JAX 调试的信息。