jax.tree.reduce_associative#

jax.tree.reduce_associative(operation, tree, *, identity=<jax._src.tree_util.Unspecified object>, is_leaf=None)[源代码]#

使用结合律二元运算对 pytree 执行归约。

此函数利用运算的结合律,以并行方式(对数深度)执行归约。

参数:
  • operation (Callable[[T, T], T]) – 结合律二元运算

  • tree (Any) – 要归约的 pytree

  • identity (T | tree_util.Unspecified) – 结合律二元运算的单位元素。仅当树为空时才使用。否则是可选的。

  • is_leaf (Callable[[Any], bool] | None) – 一个可选的函数,将在每个扁平化步骤中调用。它应返回一个布尔值,指示是否应遍历当前对象,或者是否应立即停止,并将整个子树视为叶子。

返回:

归约后的值

返回类型:

result

示例

>>> import jax
>>> import operator
>>> jax.tree.reduce_associative(operator.add, [1, (2, 3), [4, 5, 6]])
21

另请参阅