jax.tree_util.treedef_children(treedef)

Return a list of treedefs for the immediate children of treedef.

Parameters: treedef (PyTreeDef) – the tree structure whose immediate children are returned.
Return type: list[PyTreeDef]
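A minimal sketch of how the function behaves, assuming a recent JAX release where jax.tree_util.tree_structure and the PyTreeDef.num_leaves property are available. It builds a treedef for a small nested pytree and inspects the children one level below the root:

```python
import jax

# A pytree whose root is a tuple holding a list and a dict.
tree = ([1, 2], {"a": 3})
treedef = jax.tree_util.tree_structure(tree)

# One treedef per immediate child of the root tuple (not per leaf).
children = jax.tree_util.treedef_children(treedef)

print(len(children))           # 2
print(children[0].num_leaves)  # 2  (the list [1, 2])
print(children[1].num_leaves)  # 1  (the dict {"a": 3})

# A leaf has no children, so the result is an empty list.
print(jax.tree_util.treedef_children(jax.tree_util.tree_structure(1)))
```

Only one level is returned: the list child is itself a treedef with two leaves, and descending further requires calling treedef_children on it again.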