jax.tree.map_with_path#
- jax.tree.map_with_path(f, tree, *rest, is_leaf=None, is_leaf_takes_path=False)[source]#
将一个多输入函数映射到 pytree 键路径和参数上,以生成一个新的 pytree。
这是
tree_map
的一个更强大的替代方案,它可以将每个叶子的键路径作为输入参数。- 参数:
- 返回:
一个新的 pytree,具有与
tree
相同的结构,但每个叶子的值由f(kp, x, *xs)
给出,其中kp
是tree
中相应叶子的叶子的键路径,x
是叶子的值,xs
是rest
中相应节点的元组值。- 返回类型:
任意类型
示例
>>> import jax >>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3]) [1, 3, 5]