编译后的打印和断点#
jax.debug 包提供了一些有用的工具,用于检查编译函数内部的值。
使用 jax.debug.print 和其他调试回调进行调试#
摘要: 在编译后的(例如,经过 jax.jit 或 jax.pmap 装饰的)函数中,使用 jax.debug.print() 将追踪到的数组值打印到标准输出。
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
对于 jax.grad 和 jax.vmap 等一些转换,您可以使用 Python 内置的 print 函数来打印数值。但是 print 不适用于 jax.jit 或 jax.pmap,因为这些转换会延迟数值计算。所以请改用 jax.debug.print!
语义上,jax.debug.print 大致等同于以下 Python 函数
def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
print(fmt.format(*args, **kwargs))
只是它可以被 JAX 阶段化和转换。有关更多详细信息,请参阅 API 参考。
请注意,fmt 不能是 f-string,因为 f-string 会立即格式化,而对于 jax.debug.print,我们希望延迟格式化到稍后。
何时使用“debug”打印?#
您应该在 JAX 转换(如 jit、vmap 等)中使用 jax.debug.print 来动态(即追踪)打印数组值。对于静态值(如数组形状或数据类型)的打印,您可以使用普通的 Python print 语句。
为什么使用“debug”打印?#
为了调试的目的,jax.debug.print 可以揭示有关计算如何评估的信息。
xs = jnp.arange(3.)
def f(x):
jax.debug.print("x: {}", x)
y = jnp.sin(x)
jax.debug.print("y: {}", y)
return y
jax.vmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# x: 2.0
# y: 0.0
# y: 0.841471
# y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
# y: 0.0
# x: 1.0
# y: 0.841471
# x: 2.0
# y: 0.9092974
请注意,打印结果的顺序不同!
通过揭示这些内部工作原理,jax.debug.print 的输出不遵守 JAX 的常规语义保证,例如 jax.vmap(f)(xs) 和 jax.lax.map(f, xs) 计算相同的内容(以不同的方式)。但这些评估顺序的细节正是我们在调试时可能想要看到的!
因此,请使用 jax.debug.print 进行调试,而不是在需要语义保证时使用。
更多 jax.debug.print 的示例#
除了上面使用 jit 和 vmap 的示例外,这里还有一些其他需要注意的。
在 jax.pmap 下打印#
在 jax.pmap 装饰后,jax.debug.print 可能会被重新排序!
xs = jnp.arange(2.)
def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# OR
# Prints: x: 1.0
# x: 0.0
在 jax.grad 下打印#
在 jax.grad 下,jax.debug.print 只会在前向传播时打印。
def f(x):
jax.debug.print("x: {}", x)
return x * 2.
jax.grad(f)(1.)
# Prints: x: 1.0
这种行为类似于 Python 内置的 print 在 jax.grad 下的工作方式。但通过在这里使用 jax.debug.print,即使调用者应用了 jax.jit,其行为也是相同的。
要在反向传播时打印,只需使用 jax.custom_vjp
@jax.custom_vjp
def print_grad(x):
return x
def print_grad_fwd(x):
return x, None
def print_grad_bwd(_, x_grad):
jax.debug.print("x_grad: {}", x_grad)
return (x_grad,)
print_grad.defvjp(print_grad_fwd, print_grad_bwd)
def f(x):
x = print_grad(x)
return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
在其他转换中打印#
jax.debug.print 在 pjit 等其他转换中也有效。
通过 jax.debug.callback 进行更多控制#
事实上,jax.debug.print 是 jax.debug.callback 的一个简单便捷的包装器,您可以直接使用它来更精确地控制字符串格式化,甚至输出的类型。
语义上,jax.debug.callback 大致等同于以下 Python 函数
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
fun(*args, **kwargs)
return None
与 jax.debug.print 一样,这些回调应该只用于调试输出,如打印或绘图。打印和绘图是无害的,但如果您将其用于其他目的,其行为在转换下可能会让您感到意外。例如,使用 jax.debug.callback 来计时操作是不安全的,因为回调可能会被重新排序和异步(见下文)。
jax.debug.print 的优缺点#
优点#
打印调试简单直观
jax.debug.callback可用于其他无害的副作用
局限性#
添加打印语句是一个手动过程
可能产生性能影响
使用 jax.debug.breakpoint() 进行交互式检查#
摘要: 使用 jax.debug.breakpoint() 来暂停 JAX 程序的执行以检查值。
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution!

jax.debug.breakpoint() 实际上是 jax.debug.callback(...) 的一个应用,它捕获有关调用堆栈的信息。因此,它具有与 jax.debug.print 相同的转换行为(例如,对 jax.debug.breakpoint() 进行 vmap 会在其映射的轴上展开它)。
用法#
在编译后的 JAX 函数中调用 jax.debug.breakpoint() 将在程序命中断点时暂停。您将看到一个类似 pdb 的提示,允许您检查调用堆栈中的值。与 pdb 不同,您将无法单步执行,但您可以恢复它。
调试器命令
help- 打印可用命令p- 评估表达式并打印其结果pp- 评估表达式并漂亮地打印其结果u(p)- 向上移动一个堆栈帧d(own)- 向下移动一个堆栈帧w(here)/bt- 打印回溯l(ist)- 打印代码上下文c(ont(inue))- 恢复程序执行q(uit)/exit- 退出程序(在 TPU 上无效)
示例#
与 jax.lax.cond 的用法#
与 jax.lax.cond 结合使用时,调试器可以成为检测 nan 或 inf 的有用工具。
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 0.) # ==> Pauses during execution!
注意事项#
因为 jax.debug.breakpoint 只是 jax.debug.callback 的一个应用,它具有与 jax.debug.print 相同的注意事项,并有一些额外的提示。
jax.debug.breakpoint比jax.debug.print物化更多中间值,因为它强制物化调用堆栈中的所有值。jax.debug.breakpoint的运行时开销比jax.debug.print更大,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。
jax.debug.breakpoint() 的优缺点#
优点#
简单、直观且(在某种程度上)标准
可以同时检查调用堆栈中向上和向下的许多值
缺点#
可能需要使用许多断点才能精确定位错误源
物化许多中间值