jax.tree_util.tree_structure

jax.tree_util.tree_structure(tree, is_leaf=None)

Alias of jax.tree.structure().

Parameters:
tree (Any)
is_leaf (None | Callable[[Any], bool])

Return type:
PyTreeDef
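The sketch below shows how jax.tree_util.tree_structure might be used. The example tree and its values are illustrative. The PyTreeDef.num_leaves property and jax.tree_util.tree_unflatten are used to inspect and refill the returned structure, but this page does not document either of them. The example checks three behaviours: structures compare equal when only the leaf values differ, is_leaf can stop traversal early, and a structure can be refilled with new leaves.

```python
import jax
import jax.numpy as jnp

# Illustrative pytree: a dict holding a tuple, a list, and a nested dict.
tree = {"a": (1, 2), "b": [jnp.ones(3), {"c": 4}]}

# Default flattening: containers are traversed, everything else is a leaf.
treedef = jax.tree_util.tree_structure(tree)
print(treedef.num_leaves)  # 4 leaves: 1, 2, the array, 4

# Same containers with different leaf values give an equal PyTreeDef.
other = {"a": (10, 20), "b": [jnp.zeros(3), {"c": 40}]}
print(jax.tree_util.tree_structure(other) == treedef)  # True

# is_leaf returning True stops traversal, so each tuple becomes one leaf.
tuple_as_leaf = jax.tree_util.tree_structure(
    tree, is_leaf=lambda x: isinstance(x, tuple)
)
print(tuple_as_leaf.num_leaves)  # 3 leaves: (1, 2), the array, 4

# The structure can be refilled with new leaves (in flattening order).
rebuilt = jax.tree_util.tree_unflatten(treedef, [0, 0, 0, 0])
print(rebuilt)  # {'a': (0, 0), 'b': [0, {'c': 0}]}
```

A PyTreeDef records only the container layout, not the leaf values. That is why two trees with matching containers produce equal structures.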