jax.tree_util.register_static#
- jax.tree_util.register_static(cls)[源代码]#
将 cls 注册为一个没有叶节点的 pytree。
实例将被
jax.jit()、jax.pmap()等函数视为静态。这可以作为使用jit的static_argnums和static_argnameskwargs,pmap的static_broadcasted_argnums等来标记输入的静态参数的替代方法。- 参数:
cls (type[H]) – 要注册为静态的类型。必须是可哈希的,如 https://docs.pythonlang.cn/3/glossary.html#term-hashable 中定义。
- 返回:
输入类
cls在添加到 JAX 的 pytree 注册表后将保持不变。这允许register_static用作装饰器。- 返回类型:
type[H]
示例
>>> import jax >>> @jax.tree_util.register_static ... class StaticStr(str): ... pass
现在,这个静态字符串可以直接在
jax.jit()编译的函数中使用,而无需使用static_argnums来标记变量为静态。>>> @jax.jit ... def f(x, y, s): ... return x + y if s == 'add' else x - y ... >>> f(1, 2, StaticStr('add')) Array(3, dtype=int32, weak_type=True)