jax.tree.map_with_path#
- jax.tree.map_with_path(f, tree, *rest, is_leaf=None)[源代码]#
将多输入函数映射到 pytree 键路径和参数,以生成新的 pytree。
这是
tree_map
的更强大的替代方案,它可以将每个叶节点的键路径作为输入参数。- 参数:
f (Callable[..., Any]) – 接受
2 + len(rest)
个参数的函数,即键路径和 pytree 的每个对应叶节点。tree (Any) – 要映射的 pytree,其中每个叶节点的键路径作为第一个位置参数,叶节点本身作为
f
的第二个参数。*rest (Any) – pytree 的元组,每个 pytree 具有与
tree
相同的结构,或者以tree
作为前缀。is_leaf (Callable[[Any], bool] | None | None)
- 返回:
一个新的 pytree,其结构与
tree
相同,但每个叶节点的值由f(kp, x, *xs)
给出,其中kp
是tree
中相应叶节点的叶节点键路径,x
是叶节点值,xs
是rest
中相应节点处的值元组。- 返回类型:
Any
示例
>>> import jax >>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3]) [1, 3, 5]