jax.tree_util.treedef_tuple

jax.tree_util.treedef_tuple(treedefs)

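Makes a tuple treedef from an iterable of child treedefs. The result is the treedef of a tuple whose i-th element has the structure described by the i-th child treedef.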

Parameters:

treedefs (Iterable[PyTreeDef])

Return type:

PyTreeDef
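
Example:

The following is a minimal sketch of how the function might be used, not an excerpt from the JAX documentation. The pytrees `left` and `right` and all variable names are illustrative assumptions. It builds two child treedefs with `jax.tree_util.tree_structure`, combines them with `treedef_tuple`, and checks that the result matches the structure of the corresponding tuple.

```python
import jax

# Two small pytrees (hypothetical example data); their structures
# serve as the child treedefs.
left = {"a": 1, "b": 2}
right = [3, 4, 5]

left_def = jax.tree_util.tree_structure(left)
right_def = jax.tree_util.tree_structure(right)

# Combine the child treedefs into the treedef of a 2-tuple.
combined = jax.tree_util.treedef_tuple([left_def, right_def])

# This should equal the structure obtained from the tuple directly.
expected = jax.tree_util.tree_structure((left, right))
assert combined == expected

# The combined treedef can rebuild a tuple pytree from a flat leaf list.
rebuilt = jax.tree_util.tree_unflatten(combined, [1, 2, 3, 4, 5])
```

Building the treedef this way can be convenient when the child structures are already available and the full pytree is not, for example when assembling the output structure of a function from the structures of its parts.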