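jax.tree_util.tree_transpose

jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)

Transform a tree having tree structure (outer, inner) into one having structure (inner, outer). In other words, the input is an outer tree whose leaves are each inner trees, and the output is an inner tree whose leaves are each outer trees.

Parameters:
outer_treedef (PyTreeDef) – the structure of the outer tree.
inner_treedef (PyTreeDef) – the structure of each subtree found at a leaf position of the outer tree.
pytree_to_transpose (Any) – the pytree to transpose, whose structure should be outer_treedef with inner_treedef nested at each of its leaves.

Returns: the transposed pytree, with inner_treedef as the outer structure and outer_treedef nested at each of its leaves.

Return type: Any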
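A minimal sketch of typical usage follows. It assumes only the public jax.tree_util functions tree_structure and tree_transpose. The treedefs are built from placeholder trees whose leaves are 0; those zeros are arbitrary illustrative values, since only the structure matters. The second example shows a common use case: turning a list of per-example records into a record of lists.

```python
from jax import tree_util

# Example 1: a list of two 3-tuples, i.e. structure (outer = list of 2, inner = tuple of 3).
tree = [(1, 2, 3), (4, 5, 6)]
outer_treedef = tree_util.tree_structure([0, 0])     # list with 2 leaves
inner_treedef = tree_util.tree_structure((0, 0, 0))  # tuple with 3 leaves

# Result is a tuple (inner) whose entries are lists (outer).
transposed = tree_util.tree_transpose(outer_treedef, inner_treedef, tree)
print(transposed)  # ([1, 4], [2, 5], [3, 6])

# Example 2: a batch of records (list of dicts) becomes a dict of lists.
records = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
list_treedef = tree_util.tree_structure([0, 0])
dict_treedef = tree_util.tree_structure({"a": 0, "b": 0})

columns = tree_util.tree_transpose(list_treedef, dict_treedef, records)
print(columns)  # {'a': [1, 3], 'b': [2, 4]}

# Swapping the roles of the two treedefs transposes back to the original.
restored = tree_util.tree_transpose(dict_treedef, list_treedef, columns)
print(restored == records)  # True
```

Because the leaves themselves are never inspected, the same pattern works when the leaves are arrays, for example to regroup a list of per-step parameter dicts into one dict of per-parameter lists.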