jax.pure_callback#

jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=Deprecated, vmap_method=None, **kwargs)[source]#

调用纯 Python 回调。在 jit()/vmap() 等下工作。

更多解释,请参阅 外部回调

pure_callback 允许在 JIT 编译的 JAX 函数中调用 Python 函数。输入 callback 将被传递放置在本地 CPU 上的 JAX 数组,并且它也应该返回 CPU 上的 JAX 数组。

回调被视为功能纯粹的,意味着它没有副作用,并且其输出值仅取决于其参数值。因此,多次调用它是安全的(例如,当被 vmap()pmap() 转换时),或者当例如 jit-装饰函数的输出对其值没有数据依赖性时,根本不被调用。如果数据依赖性允许,纯回调也可能被重新排序。

警告

在 JAX 转换的上下文中,Python 异常应被视为副作用:这意味着在 pure_callback 中有意引发错误会破坏 API 契约,并且结果程序的行为是未定义的。

vmap-ed 时,行为将取决于 vmap_method 的值。

  • 在没有显式 vmap_method 的回调上调用 vmap() 会引发 NotImplementedError

  • vmap_method="sequential" 使用 map() 循环遍历批处理参数,为每个批处理元素调用一次 callback

  • vmap_method="sequential_unrolled" 类似于 sequential,但循环是展开的。

  • vmap_method="expand_dims" 调用 callback,并将大小为 1 的新轴作为前导维度添加到未批处理的输入中。

  • vmap_method="broadcast_all" 的行为类似于 expand_dims,但输入被平铺到预期的批处理形状。

如有必要,可以使用 vmap_method="legacy_vectorized" 恢复已移除的 vectorized=True 参数提供的旧版行为。

当前默认行为是在未指定时使用 vmap_method="sequential",但此行为已弃用,并且在未来,除非显式指定 vmap_method,否则默认行为将是引发 NotImplementedError

参数:
  • callback (Callable[..., Any]) – 在主机上执行的函数。回调被假定为纯函数(即,没有副作用的函数):如果传递了不纯函数,则它可能会以意外的方式运行,尤其是在转换下。可调用对象将被传递数组的 PyTrees 作为参数,并且应返回与 result_shape_dtypes 匹配的数组的 PyTree。

  • result_shape_dtypes (Any) – pytree,其叶子具有 shapedtype 属性,其结构与运行时回调函数的预期输出匹配。jax.ShapeDtypeStruct 通常用于定义叶子值。

  • *args (Any) – 要传递给回调函数的参数

  • sharding (SingleDeviceSharding | None | None) – 可选分片,指定应从中调用回调的设备。

  • vmap_method (str | None | None) – 字符串,指定回调如何在 vmap() 下转换,如上所述。

  • **kwargs (Any) – 要传递给回调函数的关键字参数

  • vectorized (bool | None | DeprecatedArg)

返回:

一个 jax.Array 对象的 pytree,其结构与

result_shape_dtypes 的结构匹配.

返回类型:

result

另请参阅

示例

pure_callbackvmap() 下的行为由 vmap_method 参数控制,如上所述。考虑一些明确的示例来演示语义是有用的。例如,考虑以下函数

>>> def callback(x, y):
...   print(jnp.shape(x), jnp.shape(y))
...   return x + y
>>> def fun(x, y, *, vmap_method):
...   shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
...   dtype = jnp.result_type(x, y)
...   out_type = jax.ShapeDtypeStruct(shape, dtype)
...   return jax.pure_callback(callback, out_type, x, y,
...                            vmap_method=vmap_method)

使用 vmap_method="expand_dims" 调用会向 y 添加大小为 1 的新轴

>>> from functools import partial
>>> x = jnp.arange(4)
>>> y = 1.0
>>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)
(4,) (1,)
Array([1., 2., 3., 4.], dtype=float32)

然而,vmap_method="broadcast_all" 会向 y 添加大小为 4 的轴

>>> jax.vmap(partial(fun, vmap_method="broadcast_all"),
...          in_axes=(0, None))(x, y)
(4,) (4,)
Array([1., 2., 3., 4.], dtype=float32)