jax.tree_util.tree_structure

jax.tree_util.tree_structure(tree)

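Gets the treedef for a pytree: a PyTreeDef object that records the tree's node structure (containers and their arrangement) with the leaf values abstracted away.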
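The following is a minimal sketch of what the treedef captures. It assumes a recent JAX release, and the example tree and its values are illustrative only. It shows that trees differing only in leaf values share a treedef, that lists and tuples produce different treedefs, and that a treedef can rebuild a tree from a flat list of leaves via jax.tree_util.tree_unflatten.

```python
import jax

# Illustrative nested pytree: dicts, lists and tuples are pytree nodes,
# None is an empty node (it contributes no leaves), numbers are leaves.
tree = {"a": [1, 2], "b": (3.0, None)}

treedef = jax.tree_util.tree_structure(tree)
print(treedef)             # structure only; leaves shown as placeholders
print(treedef.num_leaves)  # 3

# Same shape, different leaf values: the treedefs compare equal.
same_shape = {"a": [10, 20], "b": (30.0, None)}
print(jax.tree_util.tree_structure(same_shape) == treedef)  # True

# The container type is part of the structure: a list and a tuple differ.
print(jax.tree_util.tree_structure([1, 2]) == jax.tree_util.tree_structure((1, 2)))  # False

# A treedef can rebuild a tree from a flat list of leaves
# (dict keys are visited in sorted order, so "a" comes before "b").
rebuilt = jax.tree_util.tree_unflatten(treedef, ["x", "y", "z"])
print(rebuilt)  # {'a': ['x', 'y'], 'b': ('z', None)}
```

Because the treedef ignores leaf values, it is a convenient way to check that two pytrees have compatible shapes before combining them leaf by leaf.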