jax.debug.print#
- jax.debug.print(fmt, *args, ordered=False, partitioned=False, **kwargs)[源]#
打印值并在分阶段的 JAX 函数中工作。
此函数不适用于 f-strings,因为格式化是延迟执行的。因此,请不要使用
jax.debug.print(f"hello {bar}")
,而应编写jax.debug.print("hello {bar}", bar=bar)
。此函数是对
jax.debug.callback()
的轻量级便利封装。其实现本质上是def debug_print(fmt: str, *args, **kwargs): jax.debug.callback( lambda *args, **kwargs: print(fmt.format(*args, **kwargs)), *args, **kwargs)
直接调用
jax.debug.callback()
而非此便利封装可能会很有用。例如,要在日志中进行调试打印,您可以将jax.debug.callback()
与logging.log
结合使用。- 参数:
fmt (str) – 一个格式字符串,例如
"hello {x}"
,将用于格式化输入参数,类似于str.format
。请参阅 Python 文档中关于字符串格式化和格式字符串语法的内容。*args – 一个要格式化的位置参数列表,如同传递给
fmt.format
。ordered (bool) – 一个仅限关键字的参数,用于指示分阶段计算是否会强制此
jax.debug.print
相对于其他有序的jax.debug.print
调用之间的顺序。partitioned (bool) – 如果为 True,则仅打印局部分片;此选项可避免对操作数进行全收集。如果为 False,则使用逻辑操作数进行打印;此选项首先需要对操作数进行全收集。
**kwargs – 其他要格式化的关键字参数,如同传递给
fmt.format
。
- 返回类型:
无