jax.pure_callback#

jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=Deprecated, vmap_method=None, **kwargs)[source]#

调用纯 Python 回调函数。在 jit()/vmap() 等函数下有效。

欲了解更多说明,请参阅外部回调

pure_callback 允许在 JIT 编译的 JAX 函数中调用 Python 函数。输入 callback 将传递放置在本地 CPU 上的 JAX 数组,并且它也应该返回 CPU 上的 JAX 数组。

该回调函数被视为函数式纯净的,这意味着它没有副作用,并且其输出值仅取决于其参数值。因此,它可以安全地被多次调用(例如当通过 vmap()pmap() 转换时),或者在例如 jit 修饰的函数的输出对其值没有数据依赖时完全不被调用。如果数据依赖允许,纯回调函数也可以被重新排序。

警告

在 JAX 转换的上下文中,Python 异常应被视为副作用:这意味着在 pure_callback 中有意引发错误会破坏 API 约定,并且由此产生的程序的行为是未定义的。

当进行 vmap 转换时,其行为将取决于 vmap_method 的值。

  • 在没有显式 vmap_method 的回调函数上调用 vmap() 将引发 NotImplementedError

  • vmap_method="sequential" 使用 map() 循环遍历批处理参数,为每个批处理元素调用一次 callback

  • vmap_method="sequential_unrolled" 类似于 sequential,但循环是展开的。

  • vmap_method="expand_dims" 调用 callback,其中大小为 1 的新轴作为非批处理输入的主维度添加。

  • vmap_method="broadcast_all" 的行为类似于 expand_dims,但输入会被平铺到预期的批处理形状。

如有必要,可以通过使用 vmap_method="legacy_vectorized" 恢复已移除的 vectorized=True 参数所提供的旧版行为。

当前默认行为是在未指定时使用 vmap_method="sequential",但此行为已弃用,将来,除非显式指定 vmap_method,否则默认将引发 NotImplementedError

参数:
  • callback (Callable[..., Any]) – 在主机上执行的函数。该回调函数被假定为纯函数(即没有副作用的函数):如果传递非纯函数,它可能会以意想不到的方式运行,尤其是在转换下。可调用对象将以数组的 PyTree 作为参数传入,并应返回一个与 result_shape_dtypes 匹配的数组 PyTree。

  • result_shape_dtypes (Any) – 一个 PyTree,其叶子具有 shapedtype 属性,其结构与回调函数在运行时期的预期输出相匹配。jax.ShapeDtypeStruct 经常用于定义叶子值。

  • *args (Any) – 将传递给回调函数的参数

  • sharding (SingleDeviceSharding | None) – 可选的分片,指定应从中调用回调函数的设备。

  • vmap_method (str | None) – 字符串,指定回调函数在 vmap() 下如何转换,如上所述。

  • **kwargs (Any) – 将传递给回调函数的关键字参数

  • vectorized (bool | None | DeprecatedArg)

返回:

一个由 jax.Array 对象组成的 PyTree,其结构与

result_shape_dtypes.

返回类型:

结果

另请参阅

示例

pure_callbackvmap() 下的行为由 vmap_method 参数控制,如上所述。考虑一些明确的示例来演示其语义是很有用的。例如,考虑以下函数

>>> def callback(x, y):
...   print(jnp.shape(x), jnp.shape(y))
...   return x + y
>>> def fun(x, y, *, vmap_method):
...   shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
...   dtype = jnp.result_type(x, y)
...   out_type = jax.ShapeDtypeStruct(shape, dtype)
...   return jax.pure_callback(callback, out_type, x, y,
...                            vmap_method=vmap_method)

使用 vmap_method="expand_dims" 调用它会在 y 中添加一个大小为 1 的新轴

>>> from functools import partial
>>> x = jnp.arange(4)
>>> y = 1.0
>>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)
(4,) (1,)
Array([1., 2., 3., 4.], dtype=float32)

vmap_method="broadcast_all" 会在 y 中添加一个大小为 4 的轴

>>> jax.vmap(partial(fun, vmap_method="broadcast_all"),
...          in_axes=(0, None))(x, y)
(4,) (4,)
Array([1., 2., 3., 4.], dtype=float32)