jax.tree_util.treedef_is_leaf

jax.tree_util.treedef_is_leaf(treedef)
Parameters:

treedef (PyTreeDef) – the tree structure to check

Return type:

bool
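
Example:

The following is a minimal sketch of how this function might be used, not an excerpt from the library's own documentation. It assumes the PyTreeDef is built with `jax.tree_util.tree_structure`; the input values and variable names are illustrative only. A bare scalar is expected to produce a treedef that is a leaf, while a list of scalars is a container and is not.

```python
import jax

# Build treedefs from two illustrative inputs (values are arbitrary).
scalar_def = jax.tree_util.tree_structure(1.0)        # a single leaf
list_def = jax.tree_util.tree_structure([1.0, 2.0])   # a list holding two leaves

# Ask whether each treedef describes a single leaf.
scalar_is_leaf = jax.tree_util.treedef_is_leaf(scalar_def)
list_is_leaf = jax.tree_util.treedef_is_leaf(list_def)

print(scalar_is_leaf, list_is_leaf)
```

A check like this can be useful when walking a tree structure by hand, for instance to stop recursing once a treedef no longer has children.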