jax.tree_util.register_static#
- jax.tree_util.register_static(cls)[源代码]#
将 cls 注册为一个没有叶节点的 pytree。
实例被
jax.jit()
、jax.pmap()
等函数视为静态。这可以替代使用jit
的static_argnums
和static_argnames
kwargs,以及pmap
的static_broadcasted_argnums
等方式来标记输入为静态。- 参数:
cls (type[H]) – 要注册为静态的类型。必须是可哈希的,如 https://docs.pythonlang.cn/3/glossary.html#term-hashable 中所定义。
- 返回值:
输入类
cls
在添加到 JAX 的 pytree 注册表后返回,并且保持不变。这允许将register_static
用作装饰器。- 返回类型:
type[H]
示例
>>> import jax >>> @jax.tree_util.register_static ... class StaticStr(str): ... pass
这个静态字符串现在可以直接在
jax.jit()
编译的函数中使用,而无需使用static_argnums
标记变量为静态。>>> @jax.jit ... def f(x, y, s): ... return x + y if s == 'add' else x - y ... >>> f(1, 2, StaticStr('add')) Array(3, dtype=int32, weak_type=True)