jax.debug 模块#

运行时值调试实用程序#

编译后的打印和断点 描述了如何使用 JAX 的运行时值调试功能。

callback(callback, *args[, ordered, partitioned])

调用一个可分阶段的 Python 回调。

print(fmt, *args[, ordered, partitioned, ...])

打印值并在分阶段的 JAX 函数中工作。

breakpoint(*[, backend, filter_frames, ...])

在程序中的某个点进入断点。

分片调试实用程序#

允许在(以及程序外的)已编译函数中检查和可视化数组分片的功能。

inspect_array_sharding(value, *, callback)

允许在 JIT 编译的函数中检查数组分片。

visualize_array_sharding(arr, **kwargs)

可视化数组的分片。

visualize_sharding(shape, sharding, *[, ...])

使用 rich 可视化 Sharding