jax.tree_util.treedef_is_leaf
- jax.tree_util.treedef_is_leaf(treedef)
Return True if the treedef represents a leaf.
- Parameters:
treedef (PyTreeDef) – tree structure to check
- Returns:
True if treedef is a leaf (i.e. has a single node); False otherwise.
- Return type:
bool
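The "single node" wording can be checked against the `num_nodes` property of `PyTreeDef`. This sketch assumes that `treedef_is_leaf` agrees with `treedef.num_nodes == 1` for ordinary inputs, and it asserts that on a few sample trees. The sample trees are only illustrative and are not part of the API.

```python
import jax
import jax.numpy as jnp

# Illustrative pytrees: two bare leaves (a scalar, an array) and two containers.
samples = {
    "scalar": 1.0,
    "array": jnp.zeros((3,)),
    "list": [1, 2],
    "dict": {"a": 1, "b": (2, 3)},
}

for name, tree in samples.items():
    treedef = jax.tree.structure(tree)
    is_leaf = jax.tree_util.treedef_is_leaf(treedef)
    # Assumed equivalence: a leaf treedef consists of exactly one node.
    single_node = treedef.num_nodes == 1
    print(f"{name:7s} is_leaf={is_leaf} "
          f"num_nodes={treedef.num_nodes} num_leaves={treedef.num_leaves}")
    assert is_leaf == single_node
```

A container never counts as a leaf, because its root node has children. For example, the dict above has five nodes: the dict, the leaf `1`, the tuple, and the two leaves `2` and `3`.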
Examples
>>> import jax
>>> tree1 = jax.tree.structure(1)
>>> jax.tree_util.treedef_is_leaf(tree1)
True
>>> tree2 = jax.tree.structure([1, 2])
>>> jax.tree_util.treedef_is_leaf(tree2)
False
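Whether a treedef is a leaf depends on how it was built, not only on the Python value. This sketch assumes a JAX version in which `jax.tree.structure` accepts an `is_leaf` predicate. It shows that the same tuple gives a container treedef by default but a leaf treedef when the predicate marks tuples as leaves.

```python
import jax

pair = (1, 2)

# Default flattening: the tuple is a container node with two leaf children.
default_def = jax.tree.structure(pair)

# With a custom predicate, the whole tuple is treated as a single leaf.
tuple_as_leaf = jax.tree.structure(pair, is_leaf=lambda x: isinstance(x, tuple))

print(jax.tree_util.treedef_is_leaf(default_def))    # False
print(jax.tree_util.treedef_is_leaf(tuple_as_leaf))  # True
```

When a treedef comes from a flattening call that used a custom `is_leaf`, a result of `True` means "one node under that predicate," not necessarily that the original value was a scalar.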