jax.tree_util.tree_transpose
- jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
Alias of jax.tree.transpose().
- Parameters:
outer_treedef (PyTreeDef)
inner_treedef (PyTreeDef | None)
pytree_to_transpose (Any)
- Return type:
Any
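
The following is a minimal sketch of how the three parameters fit together, assuming a pytree that is a list of two 3-tuples. The variable names (`tree`, `outer_treedef`, `inner_treedef`, `transposed`) and the `'*'` placeholder leaves are illustrative choices, not part of the API.

```python
import jax

# Illustrative input: an outer list of length 2 whose elements are
# inner tuples of length 3.
tree = [(1, 2, 3), (4, 5, 6)]

# Describe the outer and inner structures with placeholder leaves.
outer_treedef = jax.tree_util.tree_structure(['*', '*'])
inner_treedef = jax.tree_util.tree_structure(('*', '*', '*'))

# Swap the nesting order: the result is a tuple of 3 lists of length 2.
transposed = jax.tree_util.tree_transpose(outer_treedef, inner_treedef, tree)
print(transposed)  # ([1, 4], [2, 5], [3, 6])
```

The outer treedef describes the container that is currently outermost, and the inner treedef describes the structure found at each of its leaves; the result has the inner structure on the outside.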