jax.debug.print#
- jax.debug.print(fmt, *args, ordered=False, partitioned=False, **kwargs)[source]#
打印值,并在 staged out JAX 函数中工作。
此函数不适用于 f-strings,因为格式化被延迟。因此,不要使用
jax.debug.print(f"hello {bar}")
,而是编写jax.debug.print("hello {bar}", bar=bar)
。此函数是
jax.debug.callback()
的一个简单的便捷包装器。实现本质上是def debug_print(fmt: str, *args, **kwargs): jax.debug.callback( lambda *args, **kwargs: print(fmt.format(*args, **kwargs)), *args, **kwargs)
直接调用
jax.debug.callback()
而不是这个便捷包装器可能很有用。例如,要在日志中获取调试打印,您可以将jax.debug.callback()
与logging.log
一起使用。- 参数:
fmt (str) – 格式字符串,例如
"hello {x}"
,它将用于格式化输入参数,就像str.format
一样。请参阅关于 字符串格式化 和 格式字符串语法 的 Python 文档。*args – 要格式化的位置参数列表,如同传递给
fmt.format
一样。ordered (bool) – 仅关键字参数,用于指示 staged out 计算是否将强制执行此
jax.debug.print
相对于其他 orderedjax.debug.print
调用的顺序。partitioned (bool) – 如果为 True,则仅打印本地 shards;此选项避免了操作数的所有收集。如果为 False,则使用逻辑操作数打印;此选项首先需要操作数的所有收集。
**kwargs – 要格式化的其他关键字参数,如同传递给
fmt.format
一样。
- 返回类型:
None