jax.debug 模块# 运行时值调试实用工具# 编译后的打印和断点 描述了如何使用 JAX 的运行时值调试功能。 callback(callback, *args[, ordered, partitioned]) 调用一个可分阶段的 Python 回调。 print(fmt, *args[, ordered, partitioned]) 打印值,并在 staged out 的 JAX 函数中工作。 breakpoint(*[, backend, filter_frames, ...]) 在程序中的某个点进入断点。 分片调试实用工具# 启用检查和可视化 staged 函数内部(和外部)的数组分片的函数。 inspect_array_sharding(value, *, callback) 启用在 JIT 编译的函数内部检查数组分片。 visualize_array_sharding(arr, **kwargs) 可视化数组的分片。 visualize_sharding(shape, sharding, *[, ...]) 使用 rich 可视化 Sharding。