jax.tree_util.all_leaves

jax.tree_util.all_leaves(iterable, is_leaf=None)

Tests whether all elements in the given iterable are leaves.

This function is useful in advanced cases. For example, a library that allows arbitrary map operations on a flat iterable of leaves may want to check that the result is still a flat iterable of leaves.
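The use case described above can be sketched as follows. This is a minimal illustration, not part of JAX: `checked_leaf_map` is a hypothetical helper name, and the choice to raise `ValueError` is an assumption made for the example.

```python
from jax.tree_util import all_leaves, tree_leaves


def checked_leaf_map(fn, leaves):
    """Hypothetical helper: map fn over flat leaves and verify the result stays flat."""
    leaves = list(leaves)
    if not all_leaves(leaves):
        raise ValueError("input is not a flat iterable of leaves")
    out = [fn(x) for x in leaves]
    # If fn returned a container (tuple, list, dict, ...), the output is no longer flat.
    if not all_leaves(out):
        raise ValueError("fn returned a container; result is no longer flat")
    return out


tree = {"a": [1, 2, 3], "b": 4}
leaves = tree_leaves(tree)
doubled = checked_leaf_map(lambda x: 2 * x, leaves)
```

A mapped function that returns a tuple, such as `lambda x: (x, x)`, would trip the second check, because tuples are pytree containers rather than leaves.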

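Parameters:

iterable – Iterable of candidate leaves to test.
is_leaf – Optional function (default None) called on each value during flattening; if it returns True, that value is treated as a leaf rather than traversed.

Returns: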

A boolean indicating if all elements in the input are leaves.

Return type:

bool

Examples

>>> import jax
>>> from jax.tree_util import all_leaves
>>> tree = {"a": [1, 2, 3]}
>>> assert all_leaves(jax.tree_util.tree_leaves(tree))
>>> assert not all_leaves([tree])
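
The `is_leaf` argument in the signature can change the outcome. The following is a small sketch, assuming `is_leaf` follows the usual `jax.tree_util` convention of a predicate that marks a value as a leaf; the variable names are illustrative only.

```python
from jax.tree_util import all_leaves

pairs = [(1, 2), (3, 4)]

# By default tuples are pytree containers, so this list is not a flat list of leaves.
default_result = all_leaves(pairs)

# Treating every tuple as a leaf makes the same list count as flat.
custom_result = all_leaves(pairs, is_leaf=lambda x: isinstance(x, tuple))
```

Here `default_result` is falsy and `custom_result` is truthy, since the predicate stops traversal at each tuple.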