jax.tree.reduce_associative#
- jax.tree.reduce_associative(operation, tree, *, identity=<jax._src.tree_util.Unspecified object>, is_leaf=None)[源代码]#
使用结合律二元运算对 pytree 执行归约。
此函数利用运算的结合律,以并行方式(对数深度)执行归约。
- 参数:
operation (Callable[[T, T], T]) – 结合律二元运算
tree (Any) – 要归约的 pytree
identity (T | tree_util.Unspecified) – 结合律二元运算的单位元素。仅当树为空时才使用。否则是可选的。
is_leaf (Callable[[Any], bool] | None) – 一个可选的函数,将在每个扁平化步骤中调用。它应返回一个布尔值,指示是否应遍历当前对象,或者是否应立即停止,并将整个子树视为叶子。
- 返回:
归约后的值
- 返回类型:
result
示例
>>> import jax >>> import operator >>> jax.tree.reduce_associative(operator.add, [1, (2, 3), [4, 5, 6]]) 21
另请参阅