jax.tree.unflatten#
- jax.tree.unflatten(treedef, leaves)[源代码]#
从 treedef 和 leaves 重构 pytree。
与
tree_flatten()
相反。- 参数:
treedef (tree_util.PyTreeDef) – 用于重构的 treedef。
leaves (Iterable[tree_util.Leaf]) – 用于重构的 leaves 可迭代对象。该可迭代对象必须与 treedef 的 leaves 匹配。
- 返回:
重构后的 pytree,其中
leaves
按照treedef
描述的结构放置。- 返回类型:
任意类型
示例
>>> import jax >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) >>> newvals = [100, 200, 300, 400, 500] >>> jax.tree.unflatten(treedef, newvals) [100, (200, 300), [400, 500]]