jax.debug.print#
- jax.debug.print(fmt, *args, ordered=False, **kwargs)[源代码]#
打印值,并可在分阶段输出的 JAX 函数中使用。
此函数不适用于 f 字符串,因为格式化是延迟的。因此,请不要使用
jax.debug.print(f"hello {bar}")
,而是使用jax.debug.print("hello {bar}", bar=bar)
。此函数是
jax.debug.callback()
的一个简单的便捷包装器。其实现本质上是def debug_print(fmt: str, *args, **kwargs): jax.debug.callback( lambda *args, **kwargs: print(fmt.format(*args, **kwargs)), *args, **kwargs)
直接调用
jax.debug.callback()
而不是使用此便捷包装器可能很有用。例如,要在日志中进行调试打印,您可以使用jax.debug.callback()
与logging.log
一起使用。