jax.pure_callback#
- jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=Deprecated, vmap_method=None, **kwargs)[source]#
调用纯 Python 回调函数。在
jit()
/vmap()
等函数下有效。欲了解更多说明,请参阅外部回调。
pure_callback
允许在 JIT 编译的 JAX 函数中调用 Python 函数。输入callback
将传递放置在本地 CPU 上的 JAX 数组,并且它也应该返回 CPU 上的 JAX 数组。该回调函数被视为函数式纯净的,这意味着它没有副作用,并且其输出值仅取决于其参数值。因此,它可以安全地被多次调用(例如当通过
vmap()
或pmap()
转换时),或者在例如 jit 修饰的函数的输出对其值没有数据依赖时完全不被调用。如果数据依赖允许,纯回调函数也可以被重新排序。警告
在 JAX 转换的上下文中,Python 异常应被视为副作用:这意味着在 pure_callback 中有意引发错误会破坏 API 约定,并且由此产生的程序的行为是未定义的。
当进行 vmap 转换时,其行为将取决于
vmap_method
的值。在没有显式
vmap_method
的回调函数上调用vmap()
将引发NotImplementedError
。vmap_method="sequential"
使用map()
循环遍历批处理参数,为每个批处理元素调用一次callback
。vmap_method="sequential_unrolled"
类似于sequential
,但循环是展开的。vmap_method="expand_dims"
调用callback
,其中大小为1
的新轴作为非批处理输入的主维度添加。vmap_method="broadcast_all"
的行为类似于expand_dims
,但输入会被平铺到预期的批处理形状。
如有必要,可以通过使用
vmap_method="legacy_vectorized"
恢复已移除的vectorized=True
参数所提供的旧版行为。当前默认行为是在未指定时使用
vmap_method="sequential"
,但此行为已弃用,将来,除非显式指定vmap_method
,否则默认将引发NotImplementedError
。- 参数:
callback (Callable[..., Any]) – 在主机上执行的函数。该回调函数被假定为纯函数(即没有副作用的函数):如果传递非纯函数,它可能会以意想不到的方式运行,尤其是在转换下。可调用对象将以数组的 PyTree 作为参数传入,并应返回一个与
result_shape_dtypes
匹配的数组 PyTree。result_shape_dtypes (Any) – 一个 PyTree,其叶子具有
shape
和dtype
属性,其结构与回调函数在运行时期的预期输出相匹配。jax.ShapeDtypeStruct
经常用于定义叶子值。*args (Any) – 将传递给回调函数的参数
sharding (SingleDeviceSharding | None) – 可选的分片,指定应从中调用回调函数的设备。
**kwargs (Any) – 将传递给回调函数的关键字参数
vectorized (bool | None | DeprecatedArg)
- 返回:
- 一个由
jax.Array
对象组成的 PyTree,其结构与 result_shape_dtypes
.
- 一个由
- 返回类型:
结果
另请参阅
jax.experimental.io_callback()
: 为非纯函数设计的回调函数。jax.debug.callback()
: 为通用调试设计的回调函数。jax.debug.print()
: 为打印设计的通用回调函数。
示例
pure_callback
在vmap()
下的行为由vmap_method
参数控制,如上所述。考虑一些明确的示例来演示其语义是很有用的。例如,考虑以下函数>>> def callback(x, y): ... print(jnp.shape(x), jnp.shape(y)) ... return x + y
>>> def fun(x, y, *, vmap_method): ... shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y)) ... dtype = jnp.result_type(x, y) ... out_type = jax.ShapeDtypeStruct(shape, dtype) ... return jax.pure_callback(callback, out_type, x, y, ... vmap_method=vmap_method)
使用
vmap_method="expand_dims"
调用它会在y
中添加一个大小为1
的新轴>>> from functools import partial >>> x = jnp.arange(4) >>> y = 1.0 >>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y) (4,) (1,) Array([1., 2., 3., 4.], dtype=float32)
而
vmap_method="broadcast_all"
会在y
中添加一个大小为4
的轴>>> jax.vmap(partial(fun, vmap_method="broadcast_all"), ... in_axes=(0, None))(x, y) (4,) (4,) Array([1., 2., 3., 4.], dtype=float32)