jax.tree_util.tree_transpose

jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)