jax.named_scope#

jax.named_scope(name)[source]#

一个上下文管理器,用于向 JAX 命名栈添加用户指定的名称。

当 JAX 将计算分阶段输出以进行即时编译到 XLA(或其他后端,如 TensorFlow)时,默认情况下不会保留其遇到的 Python 函数的名称(或其他源元数据)。这使得调试程序的分阶段输出(和/或编译后)表示变得复杂,因为每个正在执行的操作的上下文信息有限。

named_scope 告诉 JAX 在底层操作上添加额外的注解来分阶段输出给定函数。JAX 在内部的命名栈中跟踪这些注解。当分阶段输出的程序使用 XLA 编译时,这些注解会被保留,并显示在 TensorFlow Profiler(在 TensorBoard 中)等调试工具中。当使用 experimental.jax2tf.convert() 将 JAX 程序分阶段输出到 TensorFlow 时,名称也会被保留。

参数:

name (str) – 用于命名在命名作用域内创建的所有操作的前缀。

返回:

返回 None,但进入一个上下文,其中 name 将被附加到活跃的命名栈中。

返回类型:

source_info_util.ExtendNameStackContextManager

示例

named_scope 可以在编译函数内部用作上下文管理器

>>> import jax
>>>
>>> @jax.jit
... def layer(w, x):
...   with jax.named_scope("dot_product"):
...     logits = w.dot(x)
...   with jax.named_scope("activation"):
...     return jax.nn.relu(logits)

它也可以用作装饰器

>>> @jax.jit
... @jax.named_scope("layer")
... def layer(w, x):
...   logits = w.dot(x)
...   return jax.nn.relu(logits)