jax.tree.all

jax.tree.all(tree, *, is_leaf=None)

Call Python's built-in all() over the leaves of a tree.
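
As a rough mental model, this behaves like flattening the tree and passing the leaves to Python's built-in all(). The sketch below illustrates that equivalence with a hypothetical helper, tree_all_sketch, which is not part of JAX; it only uses jax.tree.leaves and jax.tree.all.

```python
import jax

def tree_all_sketch(tree, is_leaf=None):
    # Hypothetical helper: flatten the pytree into its leaves,
    # then apply Python's built-in all() to that list.
    return all(jax.tree.leaves(tree, is_leaf=is_leaf))

tree = [True, {'a': True, 'b': (True, False)}]
print(tree_all_sketch(tree), jax.tree.all(tree))  # False False
```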

Parameters:
  • tree (Any) – the pytree to evaluate

  • is_leaf (Callable[[Any], bool] | None) – an optional function that is called at each flattening step. It should return a boolean: True stops flattening at the current object and treats the whole subtree as a single leaf, while False lets flattening continue into the object.
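
A short illustration of is_leaf, assuming the default pytree rules where None is an empty node with no leaves. Marking None as a leaf makes it take part in the check, and since bool(None) is False, the result changes.

```python
import jax

tree = [True, None, (True, True)]

# By default, None is an empty pytree node: it contributes no leaves.
print(jax.tree.all(tree))  # True

# With is_leaf, None is kept as a leaf; bool(None) is False.
print(jax.tree.all(tree, is_leaf=lambda x: x is None))  # False
```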

Returns:

True if every leaf is truthy (or the tree has no leaves), otherwise False.

Return type:

bool
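
Because the check is Python's all(), leaves need not be booleans; each one is tested for truthiness, and a tree with no leaves gives True, just as all([]) does. A few cases, assuming ordinary Python scalars and strings as leaves:

```python
import jax

# Non-boolean leaves are tested with Python truthiness.
print(jax.tree.all({'x': 1, 'y': (2.5, 'text')}))  # True
print(jax.tree.all([1, 0]))                         # False (0 is falsy)

# Trees with no leaves return True, matching all([]).
print(jax.tree.all({}))                             # True
print(jax.tree.all([(), None]))                     # True
```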

Examples

>>> import jax
>>> jax.tree.all([True, {'a': True, 'b': (True, True)}])
True
>>> jax.tree.all([False, (True, False)])
False
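
With array leaves, converting a multi-element array to bool is ambiguous and raises an error, so one possible pattern is to reduce each leaf to a scalar boolean with jnp.all before calling jax.tree.all. The params dictionary below is an illustrative assumption. This relies on concrete (non-traced) values, since all() converts each leaf with bool().

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.array([1.0, 2.0, 3.0]), 'b': jnp.array(0.5)}

# Reduce each array leaf to a 0-d boolean array first,
# then combine the per-leaf results across the tree.
positive = jax.tree.map(lambda x: jnp.all(x > 0), params)
print(jax.tree.all(positive))  # True
```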