jax.tree.transpose#

jax.tree.transpose(outer_treedef, inner_treedef, pytree_to_transpose)[源]#

将具有 (outer, inner) 树结构的树转换为具有 (inner, outer) 结构的树。

参数:
  • outer_treedef (tree_util.PyTreeDef) – 表示外部树的 PyTreeDef。

  • inner_treedef (tree_util.PyTreeDef | None) – 表示内部树的 PyTreeDef。如果为 None,则将从 outer_treedef 和 pytree_to_transpose 的结构中推断出来。

  • pytree_to_transpose (Any) – 要转置的 pytree。

返回:

转置后的 pytree。

返回类型:

transposed_pytree

示例

>>> import jax
>>> tree = [(1, 2, 3), (4, 5, 6)]
>>> inner_structure = jax.tree.structure(('*', '*', '*'))
>>> outer_structure = jax.tree.structure(['*', '*'])
>>> jax.tree.transpose(outer_structure, inner_structure, tree)
([1, 4], [2, 5], [3, 6])

推断内部结构

>>> jax.tree.transpose(outer_structure, None, tree)
([1, 4], [2, 5], [3, 6])