jax.tree.reduce#

jax.tree.reduce(function, tree, initializer=<jax._src.tree_util.Unspecified object>, is_leaf=None)[源代码]#

对树的叶子进行 reduce() 调用。

参数:
  • function (Callable[[T, Any], T]) – 归约函数

  • tree (Any) – 要归约的 pytree

  • initializer (T | tree_util.Unspecified) – 可选的初始值

  • is_leaf (Callable[[Any], bool] | None) – 一个可选的函数,将在每个展平步骤中调用。它应该返回一个布尔值,指示是否应遍历当前对象,或者是否应立即停止,并将整个子树视为一个叶子节点。

返回:

归约后的值。

返回类型:

result

示例

>>> import jax
>>> import operator
>>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]])
21

注意事项

提示:您可以通过先使用 jax.tree.map() 将叶子映射为 None 来排除它们。这样之后它们就不会被计为叶子。