jax.tree
模块
用于处理树状容器数据结构的实用工具。
jax.tree
命名空间包含来自 jax.tree_util
的实用工具的别名。
函数列表
all (tree, *[, is_leaf])
|
对树的叶子调用 all()。 |
broadcast (prefix_tree, full_tree[, is_leaf])
|
将树的前缀广播到给定树的完整结构中。 |
flatten (tree[, is_leaf])
|
展平一个 pytree。 |
flatten_with_path (tree[, is_leaf, ...])
|
像 tree_flatten 一样展平一个 pytree,但也返回每个叶子的键路径。 |
leaves (tree[, is_leaf])
|
获取 pytree 的叶子。 |
leaves_with_path (tree[, is_leaf, ...])
|
像 tree_leaves 一样获取 pytree 的叶子,并返回每个叶子的键路径。 |
map (f, tree, *rest[, is_leaf])
|
将一个多输入函数映射到 pytree 参数上,以生成一个新的 pytree。 |
map_with_path (f, tree, *rest[, is_leaf, ...])
|
将一个多输入函数映射到 pytree 键路径和参数上,以生成一个新的 pytree。 |
归约 ()
|
对树的叶子调用 reduce()。 |
reduce_associative (operation, tree, *[, ...])
|
使用关联二元运算对 pytree 执行归约。 |
structure (tree[, is_leaf])
|
获取 pytree 的 treedef。 |
transpose (outer_treedef, inner_treedef, ...)
|
将具有树结构(外部,内部)的树转换为具有结构(内部,外部)的树。 |
unflatten (treedef, leaves)
|
从 treedef 和叶子重建一个 pytree。 |