jax.tree.structure#

jax.tree.structure(tree, is_leaf=None)[源]#

获取 pytree 的结构定义。

参数:
  • tree (Any) – 要从中获取叶子的 pytree。

  • is_leaf (None | Callable[[Any], bool]) – 一个可选的函数,将在每个展平步骤中调用。它应该返回一个布尔值,指示是否应该遍历当前对象的展平,或者是否应该立即停止,并将整个子树视为一个叶子。

返回:

一个 PyTreeDef,表示树的结构。

返回类型:

pytreedef

示例

>>> import jax
>>> jax.tree.structure([1, (2, 3), [4, 5]])
PyTreeDef([*, (*, *), [*, *]])