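jax.tree_util.treedef_children
jax.tree_util.treedef_children(treedef)
Return a list of treedefs for the immediate children of treedef.
Parameters: treedef (PyTreeDef) – the tree structure whose immediate children are returned.
Return type: List[PyTreeDef]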
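The following minimal sketch shows what "immediate children" means for a small nested pytree. It builds a treedef with `tree_util.tree_structure` and checks that each returned child matches the structure of the corresponding sub-tree. `tree_depth` is a hypothetical helper written only for illustration and is not part of `jax.tree_util`. It recurses with `treedef_children` and assumes that a leaf treedef has no children.

```python
from jax import tree_util

# A nested pytree: the root is a list with three children.
x = [(1, 2), 3, {"a": 4}]
treedef = tree_util.tree_structure(x)

# One PyTreeDef per immediate child of the root list.
children = tree_util.treedef_children(treedef)
print(len(children))  # 3

# Each child equals the structure of the corresponding sub-tree.
assert children == [tree_util.tree_structure(v) for v in x]

# Hypothetical helper (not part of jax.tree_util): measures nesting
# depth by recursing over treedef_children.
def tree_depth(td) -> int:
    kids = tree_util.treedef_children(td)
    if not kids:
        return 0  # leaves and empty containers have no children
    return 1 + max(tree_depth(k) for k in kids)

print(tree_depth(treedef))  # 2
```

Only one level is returned per call, so walking the whole structure requires recursion, as `tree_depth` does.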