jax.tree.leaves#

jax.tree.leaves(tree, is_leaf=None)[源]#

获取 pytree 的叶子节点。

参数:
  • tree (Any) – 要获取其叶子节点的 pytree

  • is_leaf (Callable[[Any], bool] | None) – 一个可选函数,将在每次展平步骤中调用。它应返回一个布尔值,指示展平是否应遍历当前对象,或者是否应立即停止,并将整个子树视为叶子节点。

返回:

一个树叶子节点的列表。

返回类型:

leaves

示例

>>> import jax
>>> jax.tree.leaves([1, (2, 3), [4, 5]])
[1, 2, 3, 4, 5]