编译打印和断点#
jax.debug
包提供了一些有用的工具,用于检查编译函数内部的值。
使用 jax.debug.print
和其他调试回调函数进行调试#
总结: 使用 jax.debug.print()
在编译(例如,用 jax.jit
或 jax.pmap
装饰)函数中将追踪的数组值打印到标准输出。
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
对于某些转换,如 jax.grad
和 jax.vmap
,您可以使用 Python 内置的 print
函数来打印数值。但 print
不适用于 jax.jit
或 jax.pmap
,因为这些转换会延迟数值求值。因此,请改用 jax.debug.print
!
在语义上,jax.debug.print
大致等同于以下 Python 函数
def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
print(fmt.format(*args, **kwargs))
不同之处在于它可以被 JAX 暂存并转换。有关更多详细信息,请参阅API 参考
。
请注意,fmt
不能是 f-string,因为 f-string 是立即格式化的,而对于 jax.debug.print
,我们希望延迟格式化直到后续。
何时使用“调试”打印?#
您应该在 JAX 转换(如 jit
、vmap
等)中使用 jax.debug.print
来处理动态(即已追踪)的数组值。对于静态值(如数组形状或数据类型)的打印,您可以使用普通的 Python print
语句。
为何使用“调试”打印?#
为了调试,jax.debug.print
可以揭示关于计算如何求值的信息。
xs = jnp.arange(3.)
def f(x):
jax.debug.print("x: {}", x)
y = jnp.sin(x)
jax.debug.print("y: {}", y)
return y
jax.vmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# x: 2.0
# y: 0.0
# y: 0.841471
# y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
# y: 0.0
# x: 1.0
# y: 0.841471
# x: 2.0
# y: 0.9092974
请注意,打印的结果顺序不同!
通过揭示这些内部工作原理,jax.debug.print
的输出不遵守 JAX 通常的语义保证,例如 jax.vmap(f)(xs)
和 jax.lax.map(f, xs)
计算相同的东西(以不同的方式)。然而,这些求值顺序的细节正是我们在调试时可能想要看到的!
因此,将 jax.debug.print
用于调试,而不是在语义保证很重要的时候使用。
jax.debug.print
的更多示例#
除了上面使用 jit
和 vmap
的示例之外,这里还有一些需要记住的例子。
在 jax.pmap
下打印#
当在 jax.pmap
下时,jax.debug.print
的输出可能会被重新排序!
xs = jnp.arange(2.)
def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# OR
# Prints: x: 1.0
# x: 0.0
在 jax.grad
下打印#
在 jax.grad
下,jax.debug.print
将只在前向传播时打印。
def f(x):
jax.debug.print("x: {}", x)
return x * 2.
jax.grad(f)(1.)
# Prints: x: 1.0
此行为类似于 Python 内置的 print
在 jax.grad
下的工作方式。但通过在此处使用 jax.debug.print
,即使调用者应用了 jax.jit
,行为也相同。
要在反向传播时打印,只需使用 jax.custom_vjp
@jax.custom_vjp
def print_grad(x):
return x
def print_grad_fwd(x):
return x, None
def print_grad_bwd(_, x_grad):
jax.debug.print("x_grad: {}", x_grad)
return (x_grad,)
print_grad.defvjp(print_grad_fwd, print_grad_bwd)
def f(x):
x = print_grad(x)
return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
在其他转换中打印#
jax.debug.print
也适用于其他转换,例如 pjit
。
使用 jax.debug.callback
进行更多控制#
实际上,jax.debug.print
是 jax.debug.callback
的一个简单便捷包装器,可以直接使用它来更好地控制字符串格式,甚至输出类型。
在语义上,jax.debug.callback
大致等同于以下 Python 函数
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
fun(*args, **kwargs)
return None
与 jax.debug.print
一样,这些回调函数只能用于调试输出,例如打印或绘图。打印和绘图是相当无害的,但如果将其用于其他任何目的,其行为在转换下可能会让您感到惊讶。例如,使用 jax.debug.callback
进行计时操作是不安全的,因为回调函数可能会被重新排序和异步执行(见下文)。
jax.debug.print
的优点和局限性#
优点#
打印调试简单直观
jax.debug.callback
可用于其他无害的副作用
局限性#
添加打印语句是一个手动过程
可能对性能产生影响
使用 jax.debug.breakpoint()
进行交互式检查#
总结: 使用 jax.debug.breakpoint()
暂停 JAX 程序的执行以检查值
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution!
jax.debug.breakpoint()
实际上只是 jax.debug.callback(...)
的一个应用,它捕获有关调用堆栈的信息。因此,它具有与 jax.debug.print
相同的转换行为(例如,对 jax.debug.breakpoint()
进行 vmap
将其沿着映射轴展开)。
用法#
在编译后的 JAX 函数中调用 jax.debug.breakpoint()
将在命中断点时暂停您的程序。您将看到一个类似 pdb
的提示符,允许您检查调用堆栈中的值。与 pdb
不同,您将无法单步执行,但可以恢复执行。
调试器命令
help
- 打印可用命令p
- 求值表达式并打印其结果pp
- 求值表达式并美观地打印其结果u(p)
- 向上移动一个堆栈帧d(own)
- 向下移动一个堆栈帧w(here)/bt
- 打印回溯信息l(ist)
- 打印代码上下文c(ont(inue))
- 恢复程序执行q(uit)/exit
- 退出程序(在 TPU 上无效)
示例#
与 jax.lax.cond
的用法#
与 jax.lax.cond
结合使用时,调试器可以成为检测 nan
或 inf
的有用工具。
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 0.) # ==> Pauses during execution!
尖锐问题#
因为 jax.debug.breakpoint
只是 jax.debug.callback
的一个应用,所以它具有与 jax.debug.print
相同的尖锐问题,还有一些额外的注意事项。
jax.debug.breakpoint
比jax.debug.print
具体化更多中间值,因为它强制具体化调用堆栈中的所有值。jax.debug.breakpoint
比jax.debug.print
具有更多的运行时开销,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。
jax.debug.breakpoint()
的优点和局限性#
优点#
简单、直观且(某种程度上)标准
可以同时检查调用堆栈中上下许多值
局限性#
可能需要使用许多断点才能找出错误的根源
具体化许多中间值