jax.debug.callback#
- jax.debug.callback(callback, *args, ordered=False, partitioned=False, **kwargs)[source]#
调用一个可分阶段的 Python 回调。
欲了解更多信息,请参阅外部回调。
jax.debug.callback
允许您传入一个可在 JAX 阶段化程序中调用的 Python 函数。jax.debug.callback
遵循现有的 JAX 转换的纯操作语义,因此不感知副作用。这意味着在存在高阶原语和转换的情况下,效果可能会被丢弃、复制或潜在地重新排序。我们希望这种行为,因为我们希望
jax.debug.callback
是“无害的”,也就是说,我们希望这些原语尽可能少地改变 JAX 计算,同时尽可能多地揭示有关它们的信息,例如计算的哪些部分被复制或丢弃。- 参数:
- 返回:
无
- 返回类型:
无
另请参阅
jax.experimental.io_callback()
:专为非纯函数设计的回调。jax.pure_callback()
:专为纯函数设计的回调。jax.debug.print()
: 为打印设计的通用回调函数。