jax.tree_util.tree_map_with_path

jax.tree_util.tree_map_with_path(f, tree, *rest, is_leaf=None)

Alias: jax.tree.map_with_path().

Parameters:
f (Callable[..., Any])
tree (Any)
rest (Any)
is_leaf (Callable[[Any], bool] | None)

Return type:
Any
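The following is a minimal usage sketch. Like tree_map, tree_map_with_path applies f leaf by leaf, but f receives the leaf's key path as its first argument, followed by the leaf from tree and the matching leaves from each tree in rest. The parameter names (params, grads), the helper sgd_step, and its rule of skipping leaves keyed "b" are hypothetical and chosen only for illustration; jax.tree_util.keystr is used to render a key path as a readable string.

```python
import jax
import jax.numpy as jnp

# Two pytrees with identical structure (illustrative data).
params = {"dense": {"w": jnp.ones(2), "b": jnp.zeros(2)}}
grads = {"dense": {"w": jnp.full(2, 0.5), "b": jnp.full(2, 0.5)}}

# f is called as f(path, leaf_from_tree, *leaves_from_rest).
def sgd_step(path, p, g, lr=0.1):
    # Hypothetical rule (not part of JAX): leave any leaf whose
    # last dict key is "b" unchanged. Dict entries in a key path
    # are DictKey objects, whose .key attribute holds the dict key.
    if path[-1].key == "b":
        return p
    return p - lr * g

# `grads` is passed through *rest, so sgd_step receives its leaves as `g`.
new_params = jax.tree_util.tree_map_with_path(sgd_step, params, grads)

# Replace every leaf with a string form of its own key path.
names = jax.tree_util.tree_map_with_path(
    lambda path, _: jax.tree_util.keystr(path), params
)
print(names)
```

Because the path is available inside f, per-leaf behavior can depend on where a leaf sits in the tree without flattening the tree by hand.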