jax.tree.flatten#
- jax.tree.flatten(tree, is_leaf=None)[源代码]#
展平一个 Pytree。
展平的顺序(即输出列表元素的顺序)是确定的,对应于从左到右的深度优先树遍历。
- 参数:
tree (Any) – 要展平的 pytree。
is_leaf (Callable[[Any], bool] | None) – 一个可选指定的函数,将在每次展平步骤时调用。它应该返回一个布尔值,true 表示停止遍历并将整个子树视为一个叶子,false 表示应该遍历当前对象进行展平。
- 返回:
一个对,其中第一个元素是叶子值的列表,第二个元素是表示展平树结构的 treedef。
- 返回类型:
示例
>>> import jax >>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]]) >>> vals [1, 2, 3, 4, 5] >>> treedef PyTreeDef([*, (*, *), [*, *]])