checkify 转换#

摘要:checkify 允许您为 JAX 代码添加可 `jit` 编译的运行时错误检查(例如,越界索引)。结合使用 `checkify.checkify` 转换和类似断言的 `checkify.check` 函数,可为 JAX 代码添加运行时检查。

from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))

您也可以使用 checkify 自动添加常见检查

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

err, z = checked_f(jnp.array([5, 1]), 0)
err.throw()  # if no error occurred, throw does nothing!

函数化检查#

类似断言的检查 API 本身不是函数式纯净的:它会像断言一样,作为副作用引发 Python 异常。因此,它不能通过 `jit`、`pmap`、`pjit` 或 `scan` 进行阶段化。

jax.jit(f)(jnp.ones((5,)), -1)  # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.

但 checkify 转换会将这些副作用函数化(或解除)。经 checkify 转换的函数会以新的输出形式返回一个错误,并保持函数式纯净。这种函数化意味着 checkify 转换的函数可以根据需要与阶段化/转换组合使用

err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
..  at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
..  at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""

为什么 JAX 需要 checkify?#

在某些 JAX 转换下,您可以使用普通的 Python 断言来表达运行时错误检查,例如仅使用 `jax.grad` 和 `jax.numpy` 时

def f(x):
  assert x > 0., "must be positive!"
  return jnp.log(x)

jax.grad(f)(0.)
# ValueError: "must be positive!"

但普通断言在 `jit`、`pmap`、`pjit` 或 `scan` 内部不起作用。在这些情况下,数值计算会被阶段化,而不是在 Python 执行期间急切地评估,因此数值不可用

jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."

JAX 转换语义依赖于函数式纯净性,尤其是在组合多个转换时,那么我们如何在不破坏这一切的情况下提供错误机制呢?除了需要一个新的 API,情况还更加棘手:XLA HLO 不支持断言或抛出错误,因此即使我们有一个能够阶段化断言的 JAX API,我们又将如何把这些断言下沉到 XLA 呢?

您可以想象手动为函数添加运行时检查,并导出表示错误的数值

def f_checked(x):
  error = x <= 0.
  result = jnp.log(x)
  return error, result

err, y = jax.jit(f_checked)(0.)
if err:
  raise ValueError("must be positive!")
# ValueError: "must be positive!"

错误是函数计算出的一个常规值,并且错误在 `f_checked` 外部被引发。`f_checked` 是函数式纯净的,因此我们知道,通过其构造,它将已经与 `jit`、pmap、pjit、scan 以及所有 JAX 转换兼容。唯一的问题是这种数据传递可能会很麻烦!

`checkify` 会为您完成此重写:这包括通过函数传递错误值,将检查重写为布尔运算并将结果与跟踪的错误值合并,以及将最终错误值作为 checkified 函数的输出返回

def f(x):
  checkify.check(x > 0., "{} must be positive!", x)  # convenient but effectful API
  return jnp.log(x)

f_checked = checkify(f)

err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1. must be positive! (check failed at <...>:2 (f))

我们称之为函数化或解除调用 check 所引入的副作用。(在上面的“手动”示例中,错误值只是一个布尔值。checkify 的错误值在概念上相似,但也会跟踪错误消息并公开 throw 和 get 方法;参见 jax.experimental.checkify)。`checkify.check` 还允许您通过将运行时值作为格式参数提供给错误消息来向其添加运行时值。

您现在可以手动为代码添加运行时检查,但 `checkify` 也可以自动添加常见错误的检查!请考虑以下错误情况:

jnp.arange(3)[5]                # out of bounds
jnp.sin(jnp.inf)                # NaN generated
jnp.ones((5,)) / jnp.arange(5)  # division by zero

默认情况下,`checkify` 只解除 `checkify.check` 的副作用,并且不会采取任何措施来捕获上述错误。但是,如果您要求它这样做,`checkify` 也会自动为您的代码添加检查。

def f(x, i):
  y = x[i]        # i could be out of bounds.
  z = jnp.sin(y)  # z could become NaN
  return z

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

选择启用哪些自动检查的 API 基于集合(Sets)。有关更多详细信息,请参阅 jax.experimental.checkify

JAX 转换下的 `checkify`。#

如上例所示,checkified 函数可以愉快地 jit 编译。以下是 `checkify` 与其他 JAX 转换的一些示例。请注意,checkified 函数是函数式纯净的,并且应该可以轻松地与所有 JAX 转换组合!

`jit`#

您可以安全地将 `jax.jit` 添加到 checkified 函数,或者对已 jit 编译的函数进行 `checkify` 转换,两者都将有效。

def f(x, i):
  return x[i]

checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ =  checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)

`vmap`/pmap#

您可以对 checkified 函数进行 `vmap` 和 `pmap` 转换(或对已映射的函数进行 `checkify` 转换)。映射 checkified 函数将为您提供一个映射错误,该错误可以包含映射维度中每个元素的不同错误。

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
  at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
  at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""

然而,对 vmap 进行 checkify 转换将生成一个单一的(未映射的)错误!

@jax.vmap
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))

`pjit`#

对 checkified 函数进行 `pjit` 直接有效,您只需为错误值输出指定一个额外的 `out_axis_resources` 为 `None`。

def f(x):
  return x / x

f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
  f,
  in_shardings=PartitionSpec('x', None),
  out_shardings=(None, PartitionSpec('x', None)))

with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
 err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)

`grad`#

如果您对 grad 进行 checkify 转换,您的梯度计算也将被检测

def f(x):
 return x / (1 + jnp.sqrt(x))

grad_f = jax.grad(f)

err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f)

请注意,`f` 中没有乘法,但其梯度计算中有乘法(此处生成了 NaN!)。因此,使用 checkify-of-grad 可以为前向和后向传播操作添加自动检查。

`checkify.check` 将只应用于函数的主值。如果您想在梯度值上使用 `check`,请使用 `custom_vjp`。

@jax.custom_vjp
def assert_gradient_negative(x):
 return x

def fwd(x):
 return assert_gradient_negative(x), None

def bwd(_, grad):
 checkify.check(grad < 0, "gradient needs to be negative!")
 return (grad,)

assert_gradient_negative.defvjp(fwd, bwd)

jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!

`jax.experimental.checkify` 的优点与局限性#

优点#

  • 您可以随处使用它(错误“只是值”,并且在转换下像其他值一样直观地表现)

  • 自动检测:您无需对代码进行局部修改。相反,`checkify` 可以对全部代码进行检测!

局限性#

  • 添加大量运行时检查可能会很昂贵(例如,为每个原语添加 NaN 检查会增加大量计算操作)

  • 需要将错误值从函数中传递出来并手动抛出错误。如果未明确抛出错误,您可能会错过错误!

  • 抛出错误值会在主机上具体化该错误值,这意味着它是一个阻塞操作,会破坏 JAX 的异步预执行。