jax.tree_util.register_pytree_node#
- jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func, flatten_with_keys_func=None)[源代码]#
扩展被视为 pytrees 内部节点的类型集。
请参阅 示例用法。
- 参数:
nodetype (type[T]) – 要注册为 pytree 的 Python 类型。
flatten_func (Callable[[T], tuple[_Children, _AuxData]]) – 在展平过程中使用的函数,接收一个
nodetype类型的对象,并返回一个对:(1) 一个可迭代对象,用于递归展平子节点,(2) 一些可哈希的辅助数据,用于存储在 treedef 中并传递给unflatten_func。unflatten_func (Callable[[_AuxData, _Children], T]) – 一个接受两个参数的函数:由
flatten_func返回并存储在 treedef 中的辅助数据,以及展平后的子节点。该函数应返回一个nodetype的实例。flatten_with_keys_func (Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None)
- 返回类型:
无
另请参阅
register_static():注册静态 pytree 的简化 API。register_dataclass():注册 dataclass 的简化 API。
示例
首先,我们将定义一个自定义类型
>>> class MyContainer: ... def __init__(self, size): ... self.x = jnp.zeros(size) ... self.y = jnp.ones(size) ... self.size = size
如果我们尝试在 JIT 编译的函数中使用它,我们会收到一个错误,因为 JAX 尚未知道如何处理此类型
>>> m = MyContainer(size=5) >>> def f(m): ... return m.x + m.y + jnp.arange(m.size) >>> jax.jit(f)(m) Traceback (most recent call last): ... TypeError: Cannot interpret value of type <class 'jax.tree_util.MyContainer'> as an abstract array; it does not have a dtype attribute
为了让 JAX 识别我们的对象,我们必须将其注册为 pytree
>>> def flatten_func(obj): ... children = (obj.x, obj.y) # children must contain arrays & pytrees ... aux_data = (obj.size,) # aux_data must contain static, hashable data. ... return (children, aux_data) ... >>> def unflatten_func(aux_data, children): ... # Here we avoid `__init__` because it has extra logic we don't require: ... obj = object.__new__(MyContainer) ... obj.x, obj.y = children ... obj.size, = aux_data ... return obj ... >>> jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)
有了这个定义,我们就可以在 JIT 编译的函数中使用此类型的实例了。
>>> jax.jit(f)(m) Array([1., 2., 3., 4., 5.], dtype=float32)