jax.tree_util.treedef_tuple

jax.tree_util.treedef_tuple(treedefs)

Makes a tuple treedef from an iterable of child treedefs.

Parameters:
    treedefs (Iterable[PyTreeDef])

Return type:
    PyTreeDef
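The following is a minimal sketch of typical use, assuming a recent JAX release where `treedef_tuple`, `tree_structure`, `tree_unflatten` and `treedef_children` are available in `jax.tree_util`. The pytrees themselves (a small dict and a list) are illustrative only. It builds the structure of a 2-tuple from the structures of its two children and checks that this matches the structure taken directly from an actual tuple.

```python
from jax import tree_util

# Structures of two independent (illustrative) pytrees.
td_dict = tree_util.tree_structure({"x": 1, "y": 2})
td_list = tree_util.tree_structure([1, 2, 3])

# Combine them into the treedef of a 2-tuple whose children have those structures.
td_pair = tree_util.treedef_tuple([td_dict, td_list])

# The result matches the structure obtained directly from a tuple of the pytrees.
assert td_pair == tree_util.tree_structure(({"x": 1, "y": 2}, [1, 2, 3]))
print(td_pair.num_leaves)  # 5

# Rebuild a concrete tuple by filling the combined structure with new leaves.
rebuilt = tree_util.tree_unflatten(td_pair, list(range(5)))
print(rebuilt)  # ({'x': 0, 'y': 1}, [2, 3, 4])
```

Because the tuple's leaves are the concatenation of each child's leaves in order, the combined treedef can be used to unflatten a single flat list into a tuple of independently structured pytrees without constructing a placeholder tuple first.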