jax.tree.flatten_with_path#
- jax.tree.flatten_with_path(tree, is_leaf=None, is_leaf_takes_path=False)[源代码]#
扁平化一个 pytree,类似
tree_flatten
,但也返回每个叶子的键路径。- 参数:
- 返回:
一个对,第一个元素是键-叶对的列表,每个键-叶对包含一个叶子及其键路径。第二个元素是表示扁平化树的结构的 treedef。
- 返回类型:
tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]
示例
>>> import jax >>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}]) >>> path_vals [((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)] >>> treedef PyTreeDef([*, {'x': *}])