jax.debug.callback#

jax.debug.callback(callback, *args, ordered=False, partitioned=False, **kwargs)[source]#

调用可分阶段的 Python 回调。

有关更多解释,请参阅外部回调

jax.debug.callback 使您能够传入一个 Python 函数,该函数可以在分阶段的 JAX 程序内部调用。jax.debug.callback 遵循现有的 JAX 转换操作语义,因此不知道副作用。这意味着在存在高阶原语和转换的情况下,效果可能会被删除、复制或可能重新排序。

我们希望这种行为是因为我们希望 jax.debug.callback 是“无害的”,即我们希望这些原语尽可能少地更改 JAX 计算,同时尽可能多地揭示有关它们的信息,例如计算的哪些部分被复制或删除。

参数:
  • callback (Callable[..., None]) – 一个返回 None 的 Python 可调用对象。

  • *args (Any) – 回调的位置参数。

  • ordered (bool) – 一个仅关键字参数,用于指示分阶段计算是否将强制执行此回调相对于其他有序回调的顺序。

  • partitioned (bool) – 如果为 True,则仅打印本地分片;此选项避免了操作数的全收集。如果为 False,则使用逻辑操作数打印;此选项首先需要操作数的全收集。

  • **kwargs (Any) – 回调的关键字参数。

返回值:

None

返回类型:

None

另请参阅