jax.tree_util.tree_transpose


jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)

Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).

Parameters:
  • outer_treedef (PyTreeDef) – PyTreeDef representing the outer tree.

  • inner_treedef (PyTreeDef | None) – PyTreeDef representing the inner tree. If None, then it will be inferred from outer_treedef and the structure of pytree_to_transpose.

  • pytree_to_transpose (Any) – the pytree to be transposed. It is expected to have the structure of outer_treedef with an inner_treedef subtree at each leaf.

Returns:

The transposed pytree, which has structure (inner, outer).

Return type:

Any

Examples

>>> import jax
>>> tree = [(1, 2, 3), (4, 5, 6)]
>>> inner_structure = jax.tree.structure(('*', '*', '*'))
>>> outer_structure = jax.tree.structure(['*', '*'])
>>> jax.tree.transpose(outer_structure, inner_structure, tree)
([1, 4], [2, 5], [3, 6])
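
The same pattern applies when the inner structure is a dict, for example to turn a list of records into a record of lists. The following is a minimal sketch; the `records` data and the variable names are illustrative assumptions, not part of the JAX documentation:

```python
import jax

# Hypothetical data: a list (outer structure) of dicts (inner structure).
records = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]

# The '*' placeholders only define the structure; their values are ignored.
outer_structure = jax.tree_util.tree_structure(['*', '*'])
inner_structure = jax.tree_util.tree_structure({'a': '*', 'b': '*'})

columns = jax.tree_util.tree_transpose(outer_structure, inner_structure, records)
print(columns)  # {'a': [1, 3], 'b': [2, 4]}
```

After the call, the dict is the outer container and each value is a list with one entry per original record.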

Inferring the inner structure:

>>> jax.tree.transpose(outer_structure, None, tree)
([1, 4], [2, 5], [3, 6])
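
Transposing a second time with the two treedefs swapped should recover the original tree. A short sketch of that round trip, reusing the tree from the example above (the names `transposed` and `restored` are illustrative):

```python
import jax

tree = [(1, 2, 3), (4, 5, 6)]
outer_structure = jax.tree_util.tree_structure(['*', '*'])
inner_structure = jax.tree_util.tree_structure(('*', '*', '*'))

# First transpose: list of tuples -> tuple of lists.
transposed = jax.tree_util.tree_transpose(outer_structure, inner_structure, tree)

# In the transposed tree the tuple is now the outer structure and the list
# the inner one, so the treedef arguments swap places for the inverse call.
restored = jax.tree_util.tree_transpose(inner_structure, outer_structure, transposed)
print(restored == tree)  # True
```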