jax.tree_util.tree_transpose
- jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
Alias of jax.tree.transpose().
- Parameters:
outer_treedef (PyTreeDef)
inner_treedef (PyTreeDef | None)
pytree_to_transpose (Any)
- Return type:
Any
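- Example:
The following is a minimal sketch of how the three arguments fit together, not an excerpt from the library source. The input values and the `'*'` placeholder leaves are illustrative assumptions; the treedefs are built with `jax.tree_util.tree_structure`. The input is a list (outer structure) of 3-tuples (inner structure), and transposing swaps the two levels so the result is a 3-tuple of lists.

```python
import jax

# A list (outer structure) whose elements are 3-tuples (inner structure).
# The values are arbitrary and chosen only for illustration.
tree = [(1, 2, 3), (4, 5, 6)]

# Describe each level with a treedef; '*' is just a placeholder leaf.
outer_treedef = jax.tree_util.tree_structure(['*', '*'])
inner_treedef = jax.tree_util.tree_structure(('*', '*', '*'))

# Swap the nesting order: list-of-tuples becomes tuple-of-lists.
transposed = jax.tree_util.tree_transpose(outer_treedef, inner_treedef, tree)
print(transposed)  # ([1, 4], [2, 5], [3, 6])
```

The two treedefs must together account for every leaf in `pytree_to_transpose`; here the outer treedef has 2 leaves and the inner has 3, matching the 6 leaves of the input.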