jax.debug.breakpoint#
- jax.debug.breakpoint(*, backend=None, filter_frames=True, num_frames=None, ordered=False, token=None, **kwargs)[source]#
在程序中的某个位置设置断点。
- 参数:
backend (str | None) – 要使用的调试器后端。默认情况下,选择最高优先级的调试器,在没有其他注册调试器的情况下,回退到 CLI 调试器。
filter_frames (bool) – 是否从追溯信息中过滤掉 JAX 内部的堆栈帧。由于一些库(如 Flax)也使用了 JAX 的堆栈帧过滤系统,此选项也可能影响是否过滤掉来自库的堆栈帧。
num_frames (int | None) – 在交互式调试器中可供检查的当前堆栈帧上方的帧数。
ordered (bool) – 一个仅限关键字的参数,用于指示分阶段计算是否会强制此
jax.debug.breakpoint
相对于其他有序的jax.debug.breakpoint
和jax.debug.print
调用的顺序。token – 一个仅限关键字的参数;作为
ordered
的替代方案。如果使用,则应传入一个 JAX 数组(或 JAX 数组的 pytree),断点将在其值计算完成后运行。此参数将原样返回,并应传回计算中。如果返回值在后续计算中未使用,则整个计算将被剪枝,并且此断点将不会运行。
- 返回:
如果传入 token,则其值将原样返回。否则,返回 None。