jax.tree.leaves_with_path#
- jax.tree.leaves_with_path(tree, is_leaf=None, is_leaf_takes_path=False)[source]#
获取类似
tree_leaves
的 pytree 的叶节点,并返回每个叶节点的键路径。- 参数:
- 返回:
键-叶节点对的列表,每对包含一个叶节点及其键路径。
- 返回类型:
示例
>>> import jax >>> jax.tree.leaves_with_path([1, {'x': 3}]) [((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]