jax.named_scope#
- jax.named_scope(name)[源代码]#
一个上下文管理器,它将用户指定的名称添加到 JAX 名称堆栈中。
在为即时编译到 XLA(或其他后端,如 TensorFlow)而暂存计算时,JAX 默认情况下不会保留其遇到的 Python 函数的名称(或其他源元数据)。这使得调试程序的暂存(和/或编译)表示变得复杂,因为对于每个正在执行的操作,上下文信息有限。
named_scope告诉 JAX 暂存给定的函数,并在底层操作上添加附加注解。JAX 在内部会跟踪这些注解,并将其存储在一个名称堆栈中。当使用 XLA 编译暂存的程序时,这些注解会被保留,并会显示在 TensorBoard 中的 TensorFlow Profiler 等调试工具中。使用experimental.jax2tf.convert()将 JAX 程序暂存到 TensorFlow 时,也会保留名称。- 参数:
name (str) – 在命名作用域内创建的所有操作要使用的前缀。
- 产生:
产生
None,但进入一个上下文,其中 name 将被追加到活动名称堆栈中。- 返回类型:
source_info_util.ExtendNameStackContextManager
示例
named_scope可以在编译后的函数内部用作上下文管理器>>> import jax >>> >>> @jax.jit ... def layer(w, x): ... with jax.named_scope("dot_product"): ... logits = w.dot(x) ... with jax.named_scope("activation"): ... return jax.nn.relu(logits)
它也可以用作装饰器
>>> @jax.jit ... @jax.named_scope("layer") ... def layer(w, x): ... logits = w.dot(x) ... return jax.nn.relu(logits)