调试运行时值#

您是否遇到了梯度爆炸? NaN 是否让您烦恼不已? 只是想查看计算过程中的中间值? 看看下面这些 JAX 调试工具!本页包含摘要,您可以点击底部的“阅读更多”链接了解更多信息。

目录

使用 jax.debug 进行交互式检查#

完整指南 在此处

摘要: 使用 jax.debug.print()jax.jit-、jax.pmap- 和 pjit- 装饰的函数中将值打印到标准输出,并使用 jax.debug.breakpoint() 暂停编译函数的执行以检查调用堆栈中的值

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯

阅读更多.

使用 jax.experimental.checkify 进行函数式错误检查#

完整指南 在此处

摘要: Checkify 允许您向 JAX 代码添加 jit-able 的运行时错误检查(例如,越界索引)。使用 checkify.checkify 转换以及类似 assert 的 checkify.check 函数来为 JAX 代码添加运行时检查。

from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))

您也可以使用 checkify 自动添加常见的检查。

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

阅读更多.

使用 JAX 的调试标志抛出 Python 错误#

完整指南 在此处

摘要: 启用 jax_debug_nans 标志以自动检测 jax.jit 编译的代码中何时产生 NaN(但在 jax.pmapjax.pjit 编译的代码中不会),并启用 jax_disable_jit 标志以禁用 JIT 编译,从而可以使用传统的 Python 调试工具,如 printpdb

import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!

阅读更多.

使用 set_xla_metadata 附加 XLA 元数据#

完整指南 在此处

摘要: set_xla_metadata 允许您将元数据附加到 JAX 代码中的操作。这些元数据将被作为 frontend_attributes 传递给 XLA 编译器,并可用于启用编译器级别的调试工具,例如 XLA-TPU 调试器。

注意: set_xla_metadata 是一项实验性功能,其 API 可能会发生变化。

import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging an individual operation
def value_tagging(x):
  y = jnp.sin(x)
  z = jnp.cos(x)
  return set_xla_metadata(y * z, breakpoint=True)

print(jax.jit(value_tagging).lower(1.0).as_text("hlo"))

结果

ENTRY main.5 {
  x.1 = f32[] parameter(0)
  sin.2 = f32[] sine(x.1)
  cos.3 = f32[] cosine(x.1)
  ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={breakpoint="true"}
}

阅读更多.