jax.tree_util.treedef_is_leaf

jax.tree_util.treedef_is_leaf(treedef)

Return True if the treedef represents a leaf.

Parameters:

treedef (PyTreeDef) – the tree structure to check.

Returns:

True if treedef is a leaf (i.e. has a single node); False otherwise.
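The "single node" wording can be made concrete with the num_nodes and num_leaves properties of PyTreeDef. The following is a minimal sketch, assuming treedef_is_leaf agrees with a num_nodes == 1 check as the description above suggests; describe is a hypothetical helper introduced only for illustration.

```python
import jax
import jax.numpy as jnp


def describe(tree):
    """Hypothetical helper: return (num_nodes, num_leaves, is_leaf) for `tree`."""
    treedef = jax.tree.structure(tree)
    return treedef.num_nodes, treedef.num_leaves, jax.tree_util.treedef_is_leaf(treedef)


# A scalar and an array are single leaf nodes; the list and dict are internal
# nodes whose children also count toward num_nodes.
for tree in [1, jnp.zeros(3), [1, 2], {"a": 1, "b": (2, 3)}]:
    print(describe(tree))
```

Internal nodes such as the list, dict, and nested tuple count toward num_nodes, so any container with children has more than one node and is not reported as a leaf.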

Return type:

bool

Examples

>>> import jax
>>> tree1 = jax.tree.structure(1)
>>> jax.tree_util.treedef_is_leaf(tree1)
True
>>> tree2 = jax.tree.structure([1, 2])
>>> jax.tree_util.treedef_is_leaf(tree2)
False
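Because the check counts nodes rather than leaves, a structure made of a single childless node, such as None or an empty container, also has exactly one node. The sketch below assumes that behavior; is_single_leaf is a hypothetical helper (not part of jax.tree_util) that additionally requires exactly one leaf. It also shows that a container matched by the is_leaf callback of jax.tree.structure is flattened as a single leaf.

```python
import jax


def is_single_leaf(treedef):
    """Hypothetical helper: True only if treedef is one node that is also one leaf."""
    return jax.tree_util.treedef_is_leaf(treedef) and treedef.num_leaves == 1


examples = {"scalar": 1, "None": None, "empty tuple": (), "empty list": []}
for name, tree in examples.items():
    treedef = jax.tree.structure(tree)
    # None and empty containers are single nodes with zero leaves.
    print(name, jax.tree_util.treedef_is_leaf(treedef), is_single_leaf(treedef))

# Treat the whole list as one leaf via the is_leaf callback.
as_leaf = jax.tree.structure([1, 2], is_leaf=lambda x: isinstance(x, list))
print(jax.tree_util.treedef_is_leaf(as_leaf))
```

Checking num_leaves == 1 as well is one way to exclude None and empty containers when only genuine leaves should match.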