jax.tree_util.all_leaves

jax.tree_util.all_leaves(iterable)

Tests whether all elements in the given iterable are leaves.

>>> import jax
>>> from jax.tree_util import all_leaves
>>> tree = {"a": [1, 2, 3]}
>>> assert all_leaves(jax.tree_util.tree_leaves(tree))
>>> assert not all_leaves([tree])

This function is useful in advanced cases. For example, if a library allows arbitrary map operations on a flat list of leaves, it may want to check whether the result is still a flat list of leaves.
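As a minimal sketch of that use case, the helper `map_leaves` below is hypothetical (it is not part of JAX). It flattens a pytree with `jax.tree_util.tree_flatten`, passes the flat list of leaves to a user-supplied function, and uses `all_leaves` to reject results that are no longer flat before rebuilding the tree with `tree_unflatten`. The extra length check against `treedef.num_leaves` is an assumption about what such a library would also want to enforce.

```python
from jax.tree_util import all_leaves, tree_flatten, tree_unflatten


def map_leaves(fn, tree):
    """Hypothetical helper: apply fn to the flat list of leaves of tree."""
    leaves, treedef = tree_flatten(tree)
    new_leaves = list(fn(leaves))
    # The user function may have returned containers (tuples, dicts, ...),
    # which are pytree nodes rather than leaves.
    if not all_leaves(new_leaves):
        raise ValueError("fn must return a flat list of leaves")
    # Rebuilding the tree also requires the same number of leaves.
    if len(new_leaves) != treedef.num_leaves:
        raise ValueError("fn must preserve the number of leaves")
    return tree_unflatten(treedef, new_leaves)


params = {"w": [1.0, 2.0], "b": 0.5}

# Doubling each leaf keeps the list flat, so the tree can be rebuilt.
doubled = map_leaves(lambda ls: [2 * x for x in ls], params)
print(doubled)

# Wrapping each leaf in a tuple produces pytree nodes, which all_leaves rejects.
try:
    map_leaves(lambda ls: [(x, x) for x in ls], params)
except ValueError as err:
    print(err)  # fn must return a flat list of leaves
```

Checking with `all_leaves` before calling `tree_unflatten` lets the library report a clear error at the point where the user function broke the flat-list contract, instead of silently building a tree whose leaves are themselves containers.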

Parameters

iterable – Iterable of leaves.

Returns

A boolean indicating whether all elements in the input are leaves.
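To illustrate what "leaf" means here, the function below is a rough pure-Python approximation for the default pytree registry. It is a hypothetical sketch (`all_leaves_sketch` is not a JAX function, and JAX's own implementation may differ): an element counts as a leaf if `jax.tree_util.tree_structure` reports exactly one node and exactly one leaf for it, that is, the element itself.

```python
import jax.numpy as jnp
from jax.tree_util import all_leaves, tree_structure


def all_leaves_sketch(iterable):
    """Hypothetical approximation of all_leaves for the default registry."""
    for x in iterable:
        treedef = tree_structure(x)
        # A leaf flattens to a single node that is also its only leaf.
        # Containers (tuple, list, dict, ...) contribute additional nodes.
        if not (treedef.num_nodes == 1 and treedef.num_leaves == 1):
            return False
    return True


examples = [
    [1, 2.0, jnp.ones(3)],  # Python scalars and arrays are leaves
    [(1, 2)],               # a tuple is a pytree node, not a leaf
    [{"a": [1, 2, 3]}],     # nested containers are not leaves
    [],                     # vacuously True: nothing to check
]
for xs in examples:
    # The sketch should agree with the real function on these inputs.
    assert all_leaves_sketch(xs) == all_leaves(xs)
```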