jax.tree_util.all_leaves

jax.tree_util.all_leaves(iterable, is_leaf=None)

Tests whether all elements in the given iterable are leaves.

>>> import jax
>>> from jax.tree_util import all_leaves
>>> tree = {"a": [1, 2, 3]}
>>> assert all_leaves(jax.tree_util.tree_leaves(tree))
>>> assert not all_leaves([tree])

This function is useful in advanced cases. For example, if a library allows arbitrary map operations on a flat iterable of leaves, it may want to check whether the result is still a flat iterable of leaves.
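The library scenario above might look roughly like the following sketch. The helper `map_flat_leaves` is hypothetical (not part of JAX): it flattens a pytree, applies a user function to each leaf, and uses `all_leaves` to reject results that are containers, because `tree_unflatten` would otherwise silently build a tree with a different structure than the input.

```python
from jax.tree_util import all_leaves, tree_flatten, tree_unflatten


def map_flat_leaves(fn, tree):
    """Hypothetical helper: apply `fn` to every leaf of `tree` and rebuild it.

    Raises ValueError if `fn` returns a pytree container (e.g. a tuple) for any
    leaf, since the rebuilt tree would then have a different structure.
    """
    leaves, treedef = tree_flatten(tree)
    new_leaves = [fn(leaf) for leaf in leaves]
    # Check that the mapped values are still a flat iterable of leaves.
    if not all_leaves(new_leaves):
        raise ValueError("fn must return leaves, not pytree containers")
    return tree_unflatten(treedef, new_leaves)


tree = {"a": [1, 2, 3]}
print(map_flat_leaves(lambda x: x * 10, tree))  # {'a': [10, 20, 30]}

try:
    map_flat_leaves(lambda x: (x, x), tree)  # tuples are pytree nodes
except ValueError as err:
    print("rejected:", err)
```

Raising an error is just one possible design choice; a library could instead fall back to a slower, structure-aware path when the check fails.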

Parameters:
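  • iterable (Iterable[Any]) – Iterable whose elements are checked.

  • is_leaf (Callable[[Any], bool] | None) – Optional function with the same meaning as in tree_flatten(): it is called at each flattening step, and if it returns True the current object is treated as a leaf instead of being traversed further. If None (the default), the default pytree registry decides which objects are leaves.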
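A short sketch of how `is_leaf` changes the result, assuming a JAX version whose `all_leaves` accepts the `is_leaf` argument. Tuples are pytree nodes by default, so a list of tuples is not a flat iterable of leaves unless `is_leaf` tells flattening to stop at each tuple.

```python
from jax.tree_util import all_leaves

pairs = [(1, 2), (3, 4)]

# By default, tuples are pytree containers, so the elements are not leaves.
print(all_leaves(pairs))  # False

# Treating every tuple as a leaf stops flattening at the tuples themselves.
print(all_leaves(pairs, is_leaf=lambda x: isinstance(x, tuple)))  # True
```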

Return type:

bool

Returns:

A boolean indicating whether every element in the input is a leaf.
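The following sketch illustrates the return value for a few common kinds of elements under the default pytree registry. Scalars, strings, and arrays are not registered pytree nodes, so they count as leaves, while containers and `None` (which JAX treats as an empty pytree node) do not. An empty iterable vacuously contains only leaves.

```python
import jax.numpy as jnp
from jax.tree_util import all_leaves

# Scalars, strings, and arrays are leaves.
print(all_leaves([1, 2.0, "label", jnp.zeros(3)]))  # True

# A tuple is a pytree container, not a leaf.
print(all_leaves([1, (2, 3)]))  # False

# None is an empty pytree node rather than a leaf.
print(all_leaves([None]))  # False

# No elements means there is nothing that fails the check.
print(all_leaves([]))  # True
```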