jax.tree.map_with_path#

jax.tree.map_with_path(f, tree, *rest, is_leaf=None, is_leaf_takes_path=False)[源代码]#

将一个多输入函数映射到 pytree 键路径和参数上,以生成一个新的 pytree。

这是比 tree_map 更强大的替代方案,它可以将每个叶子的键路径作为输入参数。

参数:
  • f (Callable[..., Any]) – 函数,接受 2 + len(rest) 个参数,即键路径和 pytree 的每个相应叶子。

  • tree (Any) – 要映射的 pytree,其中每个叶子的键路径作为第一个位置参数,叶子本身作为 f 的第二个参数。

  • *rest (Any) – 一个 pytree 元组,其中每个 pytree 的结构与 tree 相同,或者以 tree 作为前缀。

  • is_leaf (Callable[..., bool] | None)

  • is_leaf_takes_path (bool)

返回:

一个新的 pytree,其结构与 tree 相同,但在每个叶子的值由 f(kp, x, *xs) 给出,其中 kptree 中对应叶子的键路径,x 是叶子值,xsrest 中对应节点的值元组。

返回类型:

任意类型

示例

>>> import jax
>>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3])
[1, 3, 5]