jax.tree 模块#

用于处理树状容器数据结构的实用程序。

jax.tree 命名空间包含来自 jax.tree_util 的实用程序的别名。

函数列表#

all(tree, *[, is_leaf])

在树的叶子上调用 all()。

broadcast(prefix_tree, full_tree[, is_leaf])

将树前缀广播到给定树的完整结构中。

flatten(tree[, is_leaf])

展平一个 pytree。

flatten_with_path(tree[, is_leaf, ...])

tree_flatten 一样展平一个 pytree,但也会返回每个叶子的键路径。

leaves(tree[, is_leaf])

获取 pytree 的叶子。

leaves_with_path(tree[, is_leaf, ...])

tree_leaves 一样获取 pytree 的叶子,并返回每个叶子的键路径。

map(f, tree, *rest[, is_leaf])

将一个多输入函数映射到 pytree 参数上,以生成一个新的 pytree。

map_with_path(f, tree, *rest[, is_leaf, ...])

将一个多输入函数映射到 pytree 键路径和参数上,以生成一个新的 pytree。

reduce(function, tree[, initializer, is_leaf])

在树的叶子上调用 reduce()。

reduce_associative(operation, tree, *[, ...])

使用一个结合律二元运算对 pytree 执行归约。

structure(tree[, is_leaf])

获取 pytree 的 treedef。

transpose(outer_treedef, inner_treedef, ...)

将具有树结构 (outer, inner) 的树转换为具有结构 (inner, outer) 的树。

unflatten(treedef, leaves)

从 treedef 和叶子重构 pytree。