jax.tree_util.tree_leaves

jax.tree_util.tree_leaves(tree)

Gets the leaves of a pytree as a flat list, in the same order that jax.tree_util.tree_flatten produces them.
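A minimal usage sketch, assuming a working JAX install. The params dictionary and the nested list are hypothetical inputs chosen for illustration, not taken from this page. Two flattening rules matter for reading the results: dictionary entries are visited in sorted key order, and None is an empty subtree, so it contributes no leaves.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree: a dict of arrays.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Dict keys are flattened in sorted order ("b" before "w").
leaves = jax.tree_util.tree_leaves(params)
print([x.shape for x in leaves])  # [(3,), (2, 3)]

# A common use: total number of array elements across the tree.
n_params = sum(x.size for x in leaves)  # 3 + 6 = 9

# Plain Python containers work too; scalars are leaves, None has no leaves,
# and the inner dict is again visited in sorted key order ("x" before "y").
nested = [1, (2, None), {"y": 4, "x": 3}]
print(jax.tree_util.tree_leaves(nested))  # [1, 2, 3, 4]
```

Because the leaf order is deterministic, the list can be paired with the output of jax.tree_util.tree_structure on the same tree to rebuild it later.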