jax.tree_util.register_pytree_node_class#
- jax.tree_util.register_pytree_node_class(cls)[源代码]#
扩展了被视为 pytree 中内部节点的类型集合。
此函数是对
register_pytree_node
的一个简单的封装,并提供了一个面向类的接口。- 参数:
cls (Typ) – 要注册为 pytree 的类型
- 返回:
输入类
cls
在添加到 JAX 的 pytree 注册表后,将保持不变地返回。此返回值允许将register_pytree_node_class
用作装饰器。- 返回类型:
Typ
另请参阅
register_static()
: 用于注册静态 pytree 的更简单 API。register_dataclass()
: 用于注册数据类的更简单 API。
示例
在这里,我们将定义一个自定义容器,该容器将与
jax.jit()
和其他 JAX 转换兼容>>> import jax >>> @jax.tree_util.register_pytree_node_class ... class MyContainer: ... def __init__(self, x, y): ... self.x = x ... self.y = y ... def tree_flatten(self): ... return ((self.x, self.y), None) ... @classmethod ... def tree_unflatten(cls, aux_data, children): ... return cls(*children) ... >>> m = MyContainer(jnp.zeros(4), jnp.arange(4)) >>> def f(m): ... return m.x + 2 * m.y >>> jax.jit(f)(m) Array([0., 2., 4., 6.], dtype=float32)