jax.named_call#
- jax.named_call(fun, *, name=None)[source]#
在 JAX 计算进行阶段化(staging out)时,为函数添加用户指定的名称。
当为针对 XLA(或其他后端如 TensorFlow)的即时编译进行计算阶段化时,JAX 会运行您的 Python 程序,但默认情况下不会保留任何函数名称或与之相关的元数据。这会使调试程序的阶段化(和/或已编译)表示变得复杂,因为每个正在执行的操作的上下文信息有限。
named_call 告知 JAX 将给定函数作为具有特定名称的子计算进行阶段化。当阶段化后的程序用 XLA 编译时,这些命名的子计算会被保留,并显示在 TensorFlow Profiler 等 TensorBoard 调试工具中。使用
experimental.jax2tf.convert()
将 JAX 程序阶段化到 TensorFlow 时,名称也会被保留。- 参数:
fun (F) – 要包装的函数。它可以是任何可调用对象(Callable)。
name (str | None) – 可选。用于命名在名称范围(name scope)内创建的所有子计算的前缀。如果未指定,则使用 fun.__name__。
- 返回:
一个包装在名称范围(name_scope)内的 fun 版本。
- 返回类型:
F