jax.tree_util.tree_map#
- jax.tree_util.tree_map(f, tree, *rest, is_leaf=None)[source]#
的别名
jax.tree.map()
。- 参数:
f (可调用对象[..., 任意类型])
tree (Any)
rest (Any)
is_leaf (Callable[[Any], bool] | None)
- 返回类型:
任意类型
的别名 jax.tree.map()
。
f (可调用对象[..., 任意类型])
tree (Any)
rest (Any)
is_leaf (Callable[[Any], bool] | None)
任意类型