jax.tree_util.tree_unflatten# jax.tree_util.tree_unflatten(treedef, leaves)[源代码]# 别名:jax.tree.unflatten()。 参数: treedef (PyTreeDef) leaves (Iterable[Leaf]) 返回类型: 任意类型