jax.tree_util.build_tree#
- jax.tree_util.build_tree(treedef, xs)[source]#
从嵌套的可迭代结构构建 treedef
已弃用:请改用
jax.tree.unflatten()
。- 参数:
treedef (PyTreeDef) – 要构建的 PyTreeDef 结构。
xs (Any) – 与 treedef 的元数匹配的嵌套可迭代对象
- 返回值:
具有 treedef 定义的结构的对象
- 返回类型:
Any
另请参阅
示例
>>> import jax >>> tree = [(1, 2), {'a': 3, 'b': 4}] >>> treedef = jax.tree.structure(tree) >>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13]) [(10, 11), {'a': 12, 'b': 13}]