jax.tree_util.tree_all


jax.tree_util.tree_all(tree)

Call all() over the leaves of a tree, returning True only if every leaf is truthy.

Parameters:

tree (Any) – the pytree to evaluate

Returns:

True if every leaf of the tree is truthy, False otherwise

Return type:

bool

Examples

>>> import jax
>>> jax.tree.all([True, {'a': True, 'b': (True, True)}])
True
>>> jax.tree.all([False, (True, False)])
False
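
The following is a minimal sketch of how the result relates to the built-in all(), assuming tree_all is equivalent to applying all() to the flattened leaves returned by jax.tree_util.tree_leaves. The variable names (tree, params, and so on) are illustrative only and do not come from the JAX documentation.

```python
import jax

# An illustrative pytree with numeric and boolean leaves.
tree = [1, {'a': 2.0, 'b': (3, True)}]

# Assumption: tree_all(tree) matches all() over the flattened leaves.
leaves = jax.tree_util.tree_leaves(tree)
result = jax.tree_util.tree_all(tree)
manual = all(leaves)

# A single falsy leaf (here 0) makes the whole result False.
has_zero = jax.tree_util.tree_all([1, (0, 3)])

# To test a predicate on every leaf, map it first, then reduce with tree_all.
params = {'w': [1.0, 2.5], 'b': 0.5}
all_positive = jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda x: x > 0, params)
)

print(result, manual, has_zero, all_positive)
```

Mapping a predicate with tree_map and then reducing with tree_all is a convenient way to check a condition across a whole pytree; the leaves here are plain Python scalars so that each comparison yields an ordinary bool.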