jax.tree_util.tree_all

jax.tree_util.tree_all(tree)
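The sketch below is a minimal illustration. It assumes that `tree_all` behaves like Python's built-in `all()` applied to `jax.tree_util.tree_leaves(tree)`, so each leaf is judged by its ordinary truthiness. The names `flags`, `params` and `positive` are hypothetical and exist only for illustration. Because a multi-element array has no single truth value, the second half reduces each array leaf to a scalar `bool` with `jax.tree_util.tree_map` before calling `tree_all`.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested pytree of Python scalars (dict -> list/tuple -> leaves).
flags = {"a": True, "b": [1, 2], "c": (3.0,)}

# Every leaf is truthy, so the reduction over the leaves is True.
print(jax.tree_util.tree_all(flags))  # True

# A single falsy leaf (here the integer 0) makes the result False.
print(jax.tree_util.tree_all({"a": True, "b": [1, 0]}))  # False

# Hypothetical parameter tree with array leaves. Calling bool() on a
# multi-element array is ambiguous, so reduce each leaf to one bool first.
params = {"w": jnp.ones((2, 3)), "b": jnp.array([0.5, -1.0])}
positive = jax.tree_util.tree_map(lambda x: bool(jnp.all(x > 0)), params)

print(positive["w"], positive["b"])  # True False
print(jax.tree_util.tree_all(positive))  # False
```

Mapping a per-leaf predicate and then calling `tree_all` keeps the check independent of the tree's structure, so the same two lines work for any nesting of dicts, lists and tuples.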