jax.tree_util.treedef_tuple

jax.tree_util.treedef_tuple(treedefs)

Makes a tuple treedef from an iterable of child treedefs.

Parameters

treedefs (Iterable[PyTreeDef]) – the treedefs of the tuple's children, in order.

Return type

PyTreeDef
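The following is a minimal usage sketch. It assumes a JAX version in which `jax.tree_util.treedef_tuple` is available alongside `tree_structure` and `tree_unflatten`. The sample trees (`[1, 2]` and `{"x": 3}`) are arbitrary illustrations. The sketch builds a treedef for each tree, combines them with `treedef_tuple`, and checks that the result matches the structure of a tuple holding those two trees.

```python
from jax import tree_util

# Treedefs for two illustrative child trees: a list with two leaves
# and a dict with one leaf.
a_def = tree_util.tree_structure([1, 2])
b_def = tree_util.tree_structure({"x": 3})

# Combine the child treedefs into a single tuple treedef.
tup_def = tree_util.treedef_tuple([a_def, b_def])

# It describes the same structure as a tuple containing both trees.
assert tup_def == tree_util.tree_structure(([0, 0], {"x": 0}))
print(tup_def.num_leaves)  # 3

# Rebuild a concrete tuple from a flat list of leaves.
rebuilt = tree_util.tree_unflatten(tup_def, [10, 20, 30])
print(rebuilt)  # ([10, 20], {'x': 30})
```

This is convenient when the child structures are already available as treedefs, for example from earlier calls to `tree_flatten`. You can then describe their tuple without constructing a placeholder tree first.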