jax.tree_util.register_pytree_with_keys#

jax.tree_util.register_pytree_with_keys(nodetype, flatten_with_keys, unflatten_func, flatten_func=None)[source]#

扩展了被视为 pytrees 内部节点的类型集合。

这是 register_pytree_node 的更强大的替代方案,允许您在展平树和树映射时访问每个 pytree 叶子的键路径。

参数:
  • nodetype (type[T]) – 要视为内部 pytree 节点的 Python 类型。

  • flatten_with_keys (Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]]) – 一个在展平期间使用的函数,接受 nodetype 类型的值,并返回一个对,其中(1)一个可迭代对象,用于键路径及其子节点的元组,以及(2)一些可哈希的辅助数据,用于存储在 treedef 中,并传递给 unflatten_func

  • unflatten_func (Callable[[_AuxData, Iterable[Any]], T]) – 一个接受两个参数的函数:由 flatten_func 返回并存储在 treedef 中的辅助数据,以及未展平的子节点。该函数应返回 nodetype 的实例。

  • flatten_func (None | Callable[[T], tuple[Iterable[Any], _AuxData]] | None) – 一个可选函数,类似于 flatten_with_keys,但仅返回子节点和辅助数据。它必须以与 flatten_with_keys 相同的顺序返回子节点,并返回相同的辅助数据。此参数是可选的,仅在调用没有键的函数(如 tree_maptree_flatten)时需要以加快遍历速度。

示例

首先,我们将定义一个自定义类型

>>> class MyContainer:
...   def __init__(self, size):
...     self.x = jnp.zeros(size)
...     self.y = jnp.ones(size)
...     self.size = size

现在使用键感知展平函数注册它

>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
>>> def flatten_with_keys(obj):
...   children = [(GetAttrKey('x'), obj.x),
...               (GetAttrKey('y'), obj.y)]  # children must contain arrays & pytrees
...   aux_data = (obj.size,)  # aux_data must contain static, hashable data.
...   return children, aux_data
...
>>> def unflatten(aux_data, children):
...   # Here we avoid `__init__` because it has extra logic we don't require:
...   obj = object.__new__(MyContainer)
...   obj.x, obj.y = children
...   obj.size, = aux_data
...   return obj
...
>>> jax.tree_util.register_pytree_node(MyContainer, flatten_with_keys, unflatten)

现在这可以与 tree_flatten_with_path() 等函数一起使用

>>> m = MyContainer(4)
>>> leaves, treedef = jax.tree_util.tree_flatten_with_path(m)