jax.tree_util.treedef_children


jax.tree_util.treedef_children(treedef)
Return a list of treedefs for the immediate children of treedef.
Parameters:

treedef (PyTreeDef) – a single tree structure, for example one obtained from jax.tree_util.tree_structure.

Return type:

list[PyTreeDef]
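The following is a minimal sketch of how treedef_children behaves, assuming a JAX installation where jax.tree_util.tree_structure and jax.tree_util.treedef_children are available. The pytree x and the recursive helper treedef_depth are illustrative only; treedef_depth is hypothetical and not part of the JAX API.

```python
from jax.tree_util import tree_structure, treedef_children

# A list with three top-level elements: a tuple, a leaf, and a dict.
x = [(1, 2), 3, {"a": 4}]
treedef = tree_structure(x)

# One child treedef per immediate child of the list node.
children = treedef_children(treedef)
print(len(children))  # 3

# Each child equals the structure of the corresponding subtree.
print(children == [tree_structure(v) for v in x])  # True

# The tuple (1, 2) contributes a child treedef with two leaves.
print(children[0].num_leaves)  # 2

# A leaf has no children, so the result is an empty list.
print(treedef_children(tree_structure(3)))  # []


def treedef_depth(td):
    """Hypothetical helper: nesting depth of a treedef.

    Any node without children (a leaf, None, or an empty container)
    is assigned depth 0.
    """
    kids = treedef_children(td)
    if not kids:
        return 0
    return 1 + max(treedef_depth(k) for k in kids)


print(treedef_depth(treedef))  # 2
```

Because treedef_children only descends one level, recursing on its result, as treedef_depth does, is one way to walk a tree structure without touching the leaf values themselves.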