jax.tree.all#

jax.tree.all(tree, *, is_leaf=None)[源代码]#

遍历树的叶子节点并调用 all()。

参数:
  • tree (Any) – 要评估的 pytree

  • is_leaf (Callable[[Any], bool] | None) – 一个可选的函数,将在每个展平步骤中调用。它应该返回一个布尔值,指示是否应遍历当前对象,或者是否应立即停止,并将整个子树视为一个叶子节点。

返回:

布尔值 True 或 False

返回类型:

result

示例

>>> import jax
>>> jax.tree.all([True, {'a': True, 'b': (True, True)}])
True
>>> jax.tree.all([False, (True, False)])
False