The checkify transformation#
摘要: Checkify 允许您向 JAX 代码添加 jit-able 的运行时错误检查(例如,越界索引)。使用 checkify.checkify 转换以及类似 assert 的 checkify.check 函数来为 JAX 代码添加运行时检查。
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))
您也可以使用 checkify 自动添加常见的检查。
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
err, z = checked_f(jnp.array([5, 1]), 0)
err.throw() # if no error occurred, throw does nothing!
Functionalizing checks#
assert 风格的检查 API 本身不是纯函数的:它可以像 assert 一样产生副作用,即引发 Python 异常。因此,它不能被 jit、pmap、pjit 或 scan 进行阶段化处理。
jax.jit(f)(jnp.ones((5,)), -1) # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
但是 checkify 转换可以将这些副作用函数化(或称为“解除”)。经过 checkify 转换的函数会返回一个错误 *值* 作为新的输出,并保持纯函数性。这种函数化意味着经过 checkify 转换的函数可以与任何我们想要的阶段化/转换进行组合。
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
.. at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
.. at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
Why does JAX need checkify?#
在某些 JAX 转换下,您可以使用普通的 Python 断言来表达运行时错误检查,例如,仅使用 jax.grad 和 jax.numpy 时。
def f(x):
assert x > 0., "must be positive!"
return jnp.log(x)
jax.grad(f)(0.)
# ValueError: "must be positive!"
但是,普通的断言在 jit、pmap、pjit 或 scan 中不起作用。因此,数值计算会被阶段化处理而不是在 Python 执行期间即时评估,结果就是数值不可用。
jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."
JAX 的转换语义依赖于函数式纯粹性,尤其是在组合多个转换时。那么,如何在不破坏这一切的情况下提供一个错误机制呢?除了需要新的 API,情况还更加棘手:XLA HLO 不支持断言或抛出错误。因此,即使我们有一个能够阶段化处理断言的 JAX API,如何将这些断言降低到 XLA 呢?
您可以设想手动将运行时检查添加到您的函数中,并将表示错误的变量传递出去。
def f_checked(x):
error = x <= 0.
result = jnp.log(x)
return error, result
err, y = jax.jit(f_checked)(0.)
if err:
raise ValueError("must be positive!")
# ValueError: "must be positive!"
错误是函数计算的一个常规值,并且错误是在 f_checked 之外引发的。f_checked 是纯函数,因此我们从构造上就知道它将与 jit、pmap、pjit、scan 以及 JAX 的所有转换一起工作。唯一的问题是这种传递可能会很麻烦!
checkify 会为您执行此重写:包括将错误值传递到函数中,将检查重写为布尔运算并将结果与跟踪的错误值合并,并将最终错误值作为输出返回给经过 checkify 转换的函数。
def f(x):
checkify.check(x > 0., "{} must be positive!", x) # convenient but effectful API
return jnp.log(x)
f_checked = checkify(f)
err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1. must be positive! (check failed at <...>:2 (f))
我们将这种行为称为函数化或解除(discharging)调用检查所引入的副作用。(在上面的“手动”示例中,错误值只是一个布尔值。checkify 的错误值在概念上是相似的,但它们还会跟踪错误消息并暴露 throw 和 get 方法;请参阅 jax.experimental.checkify)。checkify.check 还允许您通过提供格式参数来将运行时值添加到错误消息中。
您现在可以手动为代码添加运行时检查,但 checkify 也可以自动为常见错误添加检查!请考虑以下错误情况:
jnp.arange(3)[5] # out of bounds
jnp.sin(jnp.inf) # NaN generated
jnp.ones((5,)) / jnp.arange(5) # division by zero
默认情况下,checkify 只解除 checkify.check,而不会采取任何措施来捕获上述错误。但是,如果您要求它这样做,checkify 也会自动为您的代码添加检查。
def f(x, i):
y = x[i] # i could be out of bounds.
z = jnp.sin(y) # z could become NaN
return z
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
用于选择要启用哪些自动检查的 API 基于集合(Sets)。有关更多详细信息,请参阅 jax.experimental.checkify。
checkify under JAX transformations.#
如上面的示例所示,经过 checkify 转换的函数可以很好地进行 jit 编译。以下是 checkify 与其他 JAX 转换的一些额外示例。请注意,经过 checkify 转换的函数是纯函数,并且应可轻松地与所有 JAX 转换进行组合!
jit#
您可以安全地将 jax.jit 添加到经过 checkify 转换的函数中,或者对 jit 编译的函数进行 checkify 转换,这两种方式都有效。
def f(x, i):
return x[i]
checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ = checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)
vmap/pmap#
您可以对经过 checkify 转换的函数进行 vmap 和 pmap 转换(或对映射的函数进行 checkify 转换)。映射一个经过 checkify 转换的函数会产生一个映射的错误,该错误可以包含映射维度每个元素的不同错误。
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""
但是,对 vmap 进行 checkify 转换会产生一个单一的(未映射的)错误!
@jax.vmap
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))
pjit#
对经过 checkify 转换的函数进行 pjit 转换*就能正常工作*,您只需要为错误值输出指定一个额外的 out_axis_resources 为 None。
def f(x):
return x / x
f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
f,
in_shardings=PartitionSpec('x', None),
out_shardings=(None, PartitionSpec('x', None)))
with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)
grad#
如果您对 grad 进行 checkify 转换,您的梯度计算也将被插桩。
def f(x):
return x / (1 + jnp.sqrt(x))
grad_f = jax.grad(f)
err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f)
请注意,在 f 中没有乘法运算,但在其梯度计算中有一个乘法运算(而这正是 NaN 生成的地方!)。因此,使用 checkify-of-grad 为前向和后向传递操作添加自动检查。
checkify.checks 只会应用于函数的原始值。如果您想对梯度值使用 check,请使用 custom_vjp。
@jax.custom_vjp
def assert_gradient_negative(x):
return x
def fwd(x):
return assert_gradient_negative(x), None
def bwd(_, grad):
checkify.check(grad < 0, "gradient needs to be negative!")
return (grad,)
assert_gradient_negative.defvjp(fwd, bwd)
jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!
Strengths and limitations of jax.experimental.checkify#
优点#
您可以在任何地方使用它(错误是“普通值”,并且在转换下像其他值一样直观地表现)
自动插桩:您无需对代码进行局部修改。相反,
checkify可以对所有代码进行插桩!
局限性#
添加大量运行时检查可能会很昂贵(例如,对每个原语添加 NaN 检查会为您的计算增加大量操作)。
需要将错误值从函数中传递出来,并手动抛出错误。如果错误没有被显式抛出,您可能会错过错误!
抛出错误值会将该错误值具体化到主机上,这意味着它是一个阻塞操作,这会破坏 JAX 的异步预运行。