jax.tree.leaves_with_path#

jax.tree.leaves_with_path(tree, is_leaf=None, is_leaf_takes_path=False)[source]#

获取类似 tree_leaves 的 pytree 的叶节点,并返回每个叶节点的键路径。

参数:
  • tree (Any) – 一个 pytree。 如果它包含自定义类型,建议使用 register_pytree_with_keys 注册。

  • is_leaf (Callable[..., bool] | None)

  • is_leaf_takes_path (bool)

返回:

键-叶节点对的列表,每对包含一个叶节点及其键路径。

返回类型:

list[tuple[tree_util.KeyPath, Any]]

示例

>>> import jax
>>> jax.tree.leaves_with_path([1, {'x': 3}])
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]