jax.tree_util.treedef_tuple

jax.tree_util.treedef_tuple(treedefs)

Makes a tuple treedef from a list of child treedefs. The result describes a tuple whose i-th element has the structure treedefs[i].
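The example below is a minimal sketch. It assumes a JAX release that provides `tree_structure`, `tree_unflatten` and `treedef_children` in `jax.tree_util`. The inputs `x` and `y` and the leaf values are illustrative only. It combines the structures of two separate pytrees into the structure of the pair `(x, y)`, then uses that combined treedef to rebuild a pair from a flat list of leaves.

```python
from jax import tree_util

# Two independent pytrees with different structures (illustrative values).
x = [1, 2, 3]
y = {"a": 4, "b": 5}

x_def = tree_util.tree_structure(x)
y_def = tree_util.tree_structure(y)

# Build the treedef of a 2-tuple whose children have x's and y's structures.
xy_def = tree_util.treedef_tuple([x_def, y_def])

# It matches the structure obtained directly from the tuple (x, y).
print(xy_def == tree_util.tree_structure((x, y)))  # True
print(xy_def.num_leaves)  # 5

# The combined treedef can rebuild a pair from a flat list of leaves.
# Dict leaves are ordered by sorted key, so "a" comes before "b".
rebuilt = tree_util.tree_unflatten(xy_def, [10, 20, 30, 40, 50])
# rebuilt == ([10, 20, 30], {"a": 40, "b": 50})

# treedef_children recovers the child treedefs that were combined.
children = tree_util.treedef_children(xy_def)
```

Building the tuple treedef from existing child treedefs avoids constructing a dummy `(x, y)` value just to read off its structure. This is convenient when the child structures are already known, for example from earlier calls to `tree_flatten`.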