jax.tree_util.register_static#

jax.tree_util.register_static(cls)[源代码]#

cls 注册为一个没有叶节点的 pytree。

实例被 jax.jit()jax.pmap() 等函数视为静态。这可以替代使用 jitstatic_argnumsstatic_argnames kwargs,以及 pmapstatic_broadcasted_argnums 等方式来标记输入为静态。

参数:

cls (type[H]) – 要注册为静态的类型。必须是可哈希的,如 https://docs.pythonlang.cn/3/glossary.html#term-hashable 中所定义。

返回值:

输入类 cls 在添加到 JAX 的 pytree 注册表后返回,并且保持不变。这允许将 register_static 用作装饰器。

返回类型:

type[H]

示例

>>> import jax
>>> @jax.tree_util.register_static
... class StaticStr(str):
...   pass

这个静态字符串现在可以直接在 jax.jit() 编译的函数中使用,而无需使用 static_argnums 标记变量为静态。

>>> @jax.jit
... def f(x, y, s):
...   return x + y if s == 'add' else x - y
...
>>> f(1, 2, StaticStr('add'))
Array(3, dtype=int32, weak_type=True)