jax.experimental.pallas.debug_print#

jax.experimental.pallas.debug_print(fmt, *args)[源代码]#

从 Pallas 内核内部打印值。

参数:
  • fmt (str) –

    要包含在输出中的格式字符串。格式字符串的限制取决于后端。

    • 在 GPU 上,使用 Triton 时,fmt 不能包含任何占位符({...}),因为它总是在任何值之前打印。

    • 在 GPU 上,使用实验性的 Mosaic GPU 后端时,fmt 必须包含一个占位符来打印每个值。不支持格式说明符和转换。如果提供单个值,该值可以是一个数组。否则,所有值都必须是标量。

    • 在 TPU 上,如果所有输入都是标量:如果 fmt 包含占位符,所有值都必须是 32 位整数。如果没有占位符,值将打印在格式字符串之后。

    • 在 TPU 上,如果输入是单个向量,则向量将打印在格式字符串之后。格式字符串必须以单个占位符 {} 结尾。

  • *args (jax.typing.ArrayLike) – 要打印的值。