外部回调#
本教程概述了如何使用各种回调函数,这些函数允许 JAX 运行时在主机上执行 Python 代码。JAX 回调的示例如 jax.pure_callback、jax.experimental.io_callback 和 jax.debug.callback。您可以在 JAX 转换(包括 jit()、vmap()、grad())下运行时使用它们。
为什么要使用回调?#
回调例程是在运行时执行代码的**主机端**执行方式。举一个简单的例子,假设您想在计算过程中打印某个变量的*值*。使用简单的 Python print() 语句,看起来是这样的:
import jax
@jax.jit
def f(x):
y = x + 1
print("intermediate value: {}".format(y))
return y * 2
result = f(2)
intermediate value: JitTracer<~int32[]>
打印的不是运行时值,而是跟踪时期的抽象值(如果您不熟悉 JAX 中的*跟踪*,可以在 跟踪 中找到很好的入门指南)。
要打印运行时值,您需要一个回调,例如 jax.debug.print()(您可以在 调试简介 中了解更多关于调试的信息)。
@jax.jit
def f(x):
y = x + 1
jax.debug.print("intermediate value: {}", y)
return y * 2
result = f(2)
intermediate value: 3
这通过将 y 的运行时值作为 CPU jax.Array 传递回主机进程来实现,主机进程可以打印它。
回调的种类#
在 JAX 的早期版本中,只有一种回调可用,实现于 jax.experimental.host_callback()。 host_callback 例程存在一些缺陷,现在已被弃用,取而代之的是为不同情况设计的几种回调:
jax.pure_callback():适用于纯函数;即没有副作用的函数。请参阅 探索 pure_callback。jax.experimental.io_callback():适用于不纯函数;例如,读写磁盘数据的函数。请参阅 探索 io_callback。jax.debug.callback():适用于应反映编译器执行行为的函数。请参阅 探索 debug.callback。
(您之前使用的 jax.debug.print() 函数是 jax.debug.callback() 的包装器。)
从用户角度来看,这三种回调主要区别在于它们允许的转换和编译器优化。
回调函数 |
支持返回值 |
|
|
|
|
保证执行 |
|---|---|---|---|---|---|---|
✅ |
✅ |
✅ |
❌¹ |
✅ |
❌ |
|
✅ |
✅ |
✅/❌² |
❌ |
✅³ |
✅ |
|
❌ |
✅ |
✅ |
✅ |
✅ |
❌ |
¹ jax.pure_callback 可以与 custom_jvp 一起使用,使其与自动微分兼容。
² jax.experimental.io_callback 仅当 ordered=False 时才与 vmap 兼容。
³ 请注意,io_callback 的 scan/while_loop 的 vmap 具有复杂的语义,其行为在未来版本中可能会发生变化。
探索 pure_callback#
jax.pure_callback() 通常是您在需要主机端执行纯函数时应该使用的回调函数;即没有副作用的函数(例如打印值、从磁盘读取数据、更新全局状态等)。
传递给 jax.pure_callback() 的函数不一定必须是纯函数,但 JAX 的转换和高阶函数会将其视为纯函数,这意味着它可能会被静默删除或多次调用。
import jax
import jax.numpy as jnp
import numpy as np
def f_host(x):
# call a numpy (not jax.numpy) operation:
return np.sin(x).astype(x.dtype)
def f(x):
result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
return jax.pure_callback(f_host, result_shape, x, vmap_method='sequential')
x = jnp.arange(5.0)
f(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
由于 pure_callback 可以被删除或复制,因此它与 jit 等转换以及 scan 和 while_loop 等高阶原语开箱即用。
jax.jit(f)(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
def body_fun(_, x):
return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
由于我们在 pure_callback 函数调用中指定了 vmap_method,因此它也将与 vmap 兼容。
jax.vmap(f)(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
然而,由于 JAX 无法内省回调的内容,因此 pure_callback 具有未定义的自动微分语义。
jax.grad(f)(x)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
有关使用 pure_callback 与 jax.custom_jvp() 的示例,请参阅下面的*示例:pure_callback 与 custom_jvp*。
根据设计,传递给 pure_callback 的函数被视为没有副作用:一个后果是,如果函数输出未使用,编译器可能会完全消除回调。
def print_something():
print('printing something')
return np.int32(0)
@jax.jit
def f1():
return jax.pure_callback(print_something, np.int32(0))
f1();
printing something
@jax.jit
def f2():
jax.pure_callback(print_something, np.int32(0))
return 1.0
f2();
在 f1 中,回调的输出在函数的返回值中使用,因此回调被执行,我们看到打印的输出。而在 f2 中,回调的输出未使用,因此编译器注意到这一点并消除了函数调用。这些是回调到无副作用函数的正确语义。
pure_callback 与异常#
在 JAX 转换的上下文中,Python 运行时异常应被视为副作用:这意味着在 pure_callback 中故意引发错误会打破 API 合约,并且由此产生的程序的行为是未定义的。特别是,此类程序停止的方式通常取决于后端,并且该行为的详细信息在未来版本中可能会发生变化。
此外,将不纯函数传递给 pure_callback 可能会导致 jax.jit() 或 jax.vmap() 等转换过程中出现意外行为,因为 pure_callback 的转换规则是在回调函数为纯函数的假设下定义的。这是一个不纯回调在 vmap 下行为意外的简单示例:
import jax
import jax.numpy as jnp
def raise_via_callback(x):
def _raise(x):
raise ValueError(f"value of x is {x}")
return jax.pure_callback(_raise, x, x)
def raise_if_negative(x):
return jax.lax.cond(x < 0, raise_via_callback, lambda x: x, x)
x_batch = jnp.arange(4)
[raise_if_negative(x) for x in x_batch] # does not raise
jax.vmap(raise_if_negative)(x_batch) # ValueError: value of x is 0
为避免此及类似的意外行为,我们建议不要尝试使用 pure_callback 来引发运行时错误。
探索 io_callback#
与 jax.pure_callback() 相比,jax.experimental.io_callback() 明确用于不纯函数,即具有副作用的函数。
例如,这是一个调用全局主机端 numpy 随机生成器的回调。这是一个不纯操作,因为在 numpy 中生成随机数的副作用是更新随机状态(请注意,这仅是为了作为 io_callback 的玩具示例,并不一定是 JAX 中生成随机数的推荐方法!)。
from jax.experimental import io_callback
from functools import partial
global_rng = np.random.default_rng(0)
def host_side_random_like(x):
"""Generate a random array like x using the global_rng state"""
# We have two side-effects here:
# - printing the shape and dtype
# - calling global_rng, thus updating its state
print(f'generating {x.dtype}{list(x.shape)}')
return global_rng.uniform(size=x.shape).astype(x.dtype)
@jax.jit
def numpy_random_like(x):
return io_callback(host_side_random_like, x, x)
x = jnp.zeros(5)
numpy_random_like(x)
generating float32[5]
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ], dtype=float32)
默认情况下,io_callback 与 vmap 兼容。
jax.vmap(numpy_random_like)(x)
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.91275555, 0.60663575, 0.72949654, 0.543625 , 0.9350724 ], dtype=float32)
但请注意,这可能会以任何顺序执行映射的回调。因此,例如,如果您在 GPU 上运行此代码,则映射输出的顺序可能因运行而异。
如果回调顺序的保留很重要,您可以设置 ordered=True,在这种情况下,尝试 vmap 将引发错误。
@jax.jit
def numpy_random_like_ordered(x):
return io_callback(host_side_random_like, x, x, ordered=True)
jax.vmap(numpy_random_like_ordered)(x)
ValueError: Cannot `vmap` ordered IO callback.
另一方面,scan 和 while_loop 与 io_callback 配合使用,而不管是否强制执行排序。
def body_fun(_, x):
return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544], dtype=float32)
与 pure_callback 一样,如果将 io_callback 传递给已微分的变量,则它在自动微分下会失败。
jax.grad(numpy_random_like)(x)
ValueError: IO callbacks do not support JVP.
但是,如果回调不依赖于已微分的变量,它将执行。
@jax.jit
def f(x):
io_callback(lambda: print('hello'), None)
return x
jax.grad(f)(1.0);
hello
与 pure_callback 不同,即使回调的输出在后续计算中未使用,编译器也不会在此情况下删除回调的执行。
探索 debug.callback#
无论是 pure_callback 还是 io_callback,它们都对所调用函数的纯度强制执行一些假设,并在各种方面限制 JAX 转换和编译机制可能执行的操作。 debug.callback 基本上*不*假设任何关于回调函数的行为,因此回调的操作精确地反映了 JAX 在程序过程中所做的事情。此外,debug.callback *无法*将任何值返回给程序。
from jax import debug
def log_value(x):
# This could be an actual logging call; we'll use
# print() for demonstration
print("log:", x)
@jax.jit
def f(x):
debug.callback(log_value, x)
return x
f(1.0);
log: 1.0
调试回调与 vmap 兼容。
x = jnp.arange(5.0)
jax.vmap(f)(x);
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0
也与 grad 和其他自动微分转换兼容。
jax.grad(f)(1.0);
log: 1.0
这使得 debug.callback 在通用调试方面比 pure_callback 或 io_callback 更有用。
示例:pure_callback 与 custom_jvp#
利用 jax.pure_callback() 的一个强大方法是将其与 jax.custom_jvp 结合使用。(有关 jax.custom_jvp() 的更多详细信息,请参阅 JAX 可转换 Python 函数的自定义导数规则)。
假设您想为尚未在 jax.scipy 或 jax.numpy 包装器中提供的 scipy 或 numpy 函数创建 JAX 兼容的包装器。
在这里,我们将考虑为第一类贝塞尔函数创建一个包装器,该函数可在 scipy.special.jv 中找到。您可以首先定义一个简单的 pure_callback()。
import jax
import jax.numpy as jnp
import scipy.special
def jv(v, z):
v, z = jnp.asarray(v), jnp.asarray(z)
# Require the order v to be integer type: this simplifies
# the JVP rule below.
assert jnp.issubdtype(v.dtype, jnp.integer)
# Promote the input to inexact (float/complex).
# Note that jnp.result_type() accounts for the enable_x64 flag.
z = z.astype(jnp.result_type(float, z.dtype))
# Wrap scipy function to return the expected dtype.
_scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)
# Define the expected shape & dtype of output.
result_shape_dtype = jax.ShapeDtypeStruct(
shape=jnp.broadcast_shapes(v.shape, z.shape),
dtype=z.dtype)
# Use vmap_method="broadcast_all" because scipy.special.jv handles broadcasted inputs.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vmap_method="broadcast_all")
这使我们能够从转换后的 JAX 代码(包括通过 jit() 和 vmap() 转换时)调用 scipy.special.jv()。
from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)
print(j1(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
这是使用 jit() 的相同结果。
print(jax.jit(j1)(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
这是使用 vmap() 的相同结果。
print(jax.vmap(j1)(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
但是,如果您调用 grad(),您将收到错误,因为此函数没有定义自动微分规则。
jax.grad(j1)(z)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
让我们为此定义一个自定义梯度规则。查看*第一类贝塞尔函数*的定义,您会发现关于自变量 z 的导数有一个相对简单的递推关系:
关于 \(\nu\) 的梯度更复杂,但由于我们将 v 参数限制为整数类型,因此在此示例中您无需担心其梯度。
您可以使用 jax.custom_jvp() 为回调函数定义此自动微分规则:
jv = jax.custom_jvp(jv)
@jv.defjvp
def _jv_jvp(primals, tangents):
v, z = primals
_, z_dot = tangents # Note: v_dot is always 0 because v is integer.
jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
return jv(v, z), z_dot * djv_dz
现在计算函数的梯度将正常工作。
j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
-0.06447162
此外,由于我们已经根据 jv 本身定义了梯度,JAX 的架构意味着您可以免费获得二阶及更高阶导数。
jax.hessian(j1)(2.0)
Array(-0.4003078, dtype=float32, weak_type=True)
请记住,虽然这一切与 JAX 配合正常工作,但每次调用基于回调的 jv 函数时,都会将输入数据从设备传递到主机,并将 scipy.special.jv() 的输出从主机传递回设备。
在 GPU 或 TPU 等加速器上运行时,这种数据传输和主机同步可能会在每次调用 jv 时产生显著的开销。
但是,如果您在单个 CPU 上运行 JAX(其中“主机”和“设备”位于同一硬件上),JAX 通常会以快速的零拷贝方式进行此数据传输,从而使此模式成为扩展 JAX 功能的一种相对直接的方式。