jax.tree_util.tree_structure

jax.tree_util.tree_structure(tree, is_leaf=None)

Alias of jax.tree.structure(). Gets the treedef (structure) of a pytree.

Parameters:
  tree (Any)
  is_leaf (None | Callable[[Any], bool])

Return type: PyTreeDef
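The following is a minimal usage sketch. It assumes a recent JAX release in which the jax.tree module, and with it jax.tree.structure, is available. The example tree and the tuple-based is_leaf predicate are arbitrary choices for illustration. The sketch shows three things. The alias returns the same PyTreeDef as jax.tree.structure. Passing is_leaf stops recursion at matching nodes. A treedef can be used to rebuild a tree from new leaves.

```python
import jax

# An arbitrary nested pytree: a dict holding a scalar and a tuple.
tree = {"a": 1, "b": (2, 3)}

# Only the structure is kept; leaf values become placeholders in the treedef.
treedef = jax.tree_util.tree_structure(tree)

# tree_structure is an alias, so it agrees with jax.tree.structure.
assert treedef == jax.tree.structure(tree)

# By default every non-container value is a leaf: 1, 2 and 3.
print(treedef.num_leaves)

# is_leaf is checked on each node; here any tuple is treated as one leaf,
# so the leaves become 1 and (2, 3).
coarse = jax.tree_util.tree_structure(tree, is_leaf=lambda x: isinstance(x, tuple))
print(coarse.num_leaves)

# The treedef can rebuild a tree from leaves given in flatten order
# (dict keys are flattened in sorted order: "a", then "b").
rebuilt = jax.tree_util.tree_unflatten(treedef, [10, 20, 30])
print(rebuilt)
```

Because the treedef stores no leaf values, it can be compared across trees to check that they share the same shape. It can also be reused with tree_unflatten to build new trees of that shape.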