jax.tree_util.tree_transpose

jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)

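Transform a tree with structure (outer, inner) into one with structure (inner, outer). Here (outer, inner) means a tree whose structure is outer_treedef, with each of its leaves replaced by a subtree whose structure is inner_treedef.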
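A minimal sketch of the basic case. It assumes the two treedefs are built with jax.tree_util.tree_structure from placeholder trees; only their structure matters, so the placeholder leaf values are arbitrary. A list of two 3-tuples becomes a 3-tuple of two-element lists.

```python
from jax.tree_util import tree_structure, tree_transpose

# Outer structure: a list with 2 entries. Inner structure: a 3-tuple.
tree = [(1, 2, 3), (4, 5, 6)]

# Placeholder trees: the leaf values (0) are ignored, only the shape is kept.
outer_treedef = tree_structure([0, 0])
inner_treedef = tree_structure((0, 0, 0))

# (outer, inner) -> (inner, outer): a tuple of lists.
transposed = tree_transpose(outer_treedef, inner_treedef, tree)
print(transposed)  # ([1, 4], [2, 5], [3, 6])
```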

Parameters
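  • outer_treedef (PyTreeDef) – the structure of the outer level; each of its leaves holds one subtree with structure inner_treedef.

  • inner_treedef (PyTreeDef) – the structure shared by every inner subtree.

  • pytree_to_transpose (Any) – the pytree to transpose, whose structure is outer_treedef composed with inner_treedef.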

Return type

Any
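A common pattern, sketched here as an illustration rather than prescribed usage: turning a list of per-example records into a dict of per-field lists, then transposing back by swapping the two treedef arguments. The variable names (records, batched, restored) are illustrative only.

```python
from jax.tree_util import tree_structure, tree_transpose

# Outer level: a list of 3 records. Inner level: a dict with keys "x" and "y".
records = [{"x": 1, "y": 2}, {"x": 3, "y": 4}, {"x": 5, "y": 6}]

outer_treedef = tree_structure([0] * len(records))  # list with 3 leaves
inner_treedef = tree_structure(records[0])           # dict with keys "x", "y"

# (outer, inner) -> (inner, outer): one list per dict key.
batched = tree_transpose(outer_treedef, inner_treedef, records)
print(batched)  # {'x': [1, 3, 5], 'y': [2, 4, 6]}

# Swapping the treedef arguments undoes the transpose.
restored = tree_transpose(inner_treedef, outer_treedef, batched)
print(restored == records)  # True
```

Because the result is an ordinary pytree, the same call works when the leaves are arrays instead of Python integers.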