jax.tree_util 模块
用于处理树状容器数据结构的实用工具。
本模块提供了一组小型实用函数,用于处理树状数据结构,例如嵌套的元组、列表和字典。我们将这些结构称为 pytrees。它们是树状的,因为它们是递归定义的(任何非 pytree 都是一个 pytree,即叶子,任何 pytree 的 pytree 也是一个 pytree),并且可以递归地进行操作(映射操作不会保留对象标识等价性,并且结构不能包含引用循环)。
被视为 pytree 节点(例如,可以被映射,而不是被视为叶子)的 Python 类型集是可扩展的。存在一个模块级别的类型注册表,类层次结构被忽略。通过注册一个新的 pytree 节点类型,该类型实际上对该文件中的实用函数变得透明。
本模块的主要目的是实现用户定义数据结构与 JAX 转换(例如 jit)之间的互操作性。它无意成为一个通用的树状数据结构处理库。
有关示例,请参阅 JAX pytrees 说明。
函数列表
Partial(func, *args, **kw)
|
functools.partial 的一个版本,适用于 pytrees。 |
all_leaves(iterable[, is_leaf])
|
测试给定可迭代对象中的所有元素是否都是叶子。 |
register_dataclass(nodetype[, data_fields, ...])
|
扩展被视为 pytrees 内部节点的类型集。 |
register_pytree_node(nodetype, flatten_func, ...)
|
扩展被视为 pytrees 内部节点的类型集。 |
register_pytree_node_class(cls)
|
扩展被视为 pytrees 内部节点的类型集。 |
register_pytree_with_keys(nodetype, ...[, ...])
|
扩展被视为 pytrees 内部节点的类型集。 |
register_pytree_with_keys_class(cls)
|
扩展被视为 pytrees 内部节点的类型集。 |
register_static(cls)
|
将 cls 注册为没有叶子的 pytree。 |
tree_flatten_with_path(tree[, is_leaf, ...])
|
是 jax.tree.flatten_with_path() 的别名。 |
tree_leaves_with_path(tree[, is_leaf, ...])
|
是 jax.tree.leaves_with_path() 的别名。 |
tree_map_with_path(f, tree, *rest[, ...])
|
是 jax.tree.map_with_path() 的别名。 |
treedef_children(treedef)
|
返回直接子节点的 treedef 列表 |
treedef_is_leaf(treedef)
|
如果 treedef 代表叶子,则返回 True。 |
treedef_tuple(treedefs)
|
从子 treedefs 的可迭代对象创建元组 treedef。 |
KeyEntry
|
类型变量。 |
KeyPath
|
内置的不可变序列。 |
keystr(keys, *[, simple, separator])
|
用于美观打印键元组的辅助函数。 |
旧版 API
这些 API 现在通过 jax.tree 访问。
tree_all(tree, *[, is_leaf])
|
是 jax.tree.all() 的别名。 |
tree_broadcast(prefix_tree, full_tree[, is_leaf])
|
是 jax.tree.broadcast() 的别名。 |
tree_flatten(tree[, is_leaf])
|
是 jax.tree.flatten() 的别名。 |
tree_leaves(tree[, is_leaf])
|
是 jax.tree.leaves() 的别名。 |
tree_map(f, tree, *rest[, is_leaf])
|
是 jax.tree.map() 的别名。 |
tree_reduce(function, tree[, initializer, ...])
|
是 jax.tree.reduce() 的别名。 |
tree_reduce_associative(operation, tree, *)
|
是 jax.tree.reduce_associative() 的别名。 |
tree_structure(tree[, is_leaf])
|
是 jax.tree.structure() 的别名。 |
tree_transpose(outer_treedef, inner_treedef, ...)
|
是 jax.tree.transpose() 的别名。 |
tree_unflatten(treedef, leaves)
|
是 jax.tree.unflatten() 的别名。 |