jax.tree_util.tree_unflatten

jax.tree_util.tree_unflatten(treedef, leaves)

Reconstructs a pytree from the treedef and the leaves.

The inverse of tree_flatten().
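A minimal round-trip sketch: flatten a pytree, then rebuild it from the same leaves and treedef. The pytree below is an arbitrary example chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Example pytree (illustrative): a dict holding a tuple and an array
tree = {"a": (1, 2), "b": jnp.arange(3)}

# tree_flatten splits the pytree into a flat list of leaves and a treedef
leaves, treedef = jax.tree_util.tree_flatten(tree)

# tree_unflatten puts the same leaves back into the same structure
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

Dict keys are visited in sorted order, so `leaves` here is `[1, 2, Array([0, 1, 2])]` and `rebuilt` has the same shape as `tree`.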

Parameters
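  • treedef – the treedef describing the structure to reconstruct, typically obtained from tree_flatten() or tree_structure().

  • leaves – the list of leaves to use for reconstruction. The number of leaves must equal the number of leaves in the treedef (treedef.num_leaves).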
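A short sketch of the leaf-count requirement, assuming a recent JAX version in which a count mismatch raises `ValueError`; the pytree and variable names are illustrative only.

```python
import jax

# Illustrative pytree with three leaves: 1, 2 and 3
tree = [1, (2, 3)]
leaves, treedef = jax.tree_util.tree_flatten(tree)

# treedef.num_leaves reports how many leaves the structure expects
expected = treedef.num_leaves

# One leaf too few does not fit the structure, so unflattening fails
try:
    jax.tree_util.tree_unflatten(treedef, leaves[:-1])
    mismatch_rejected = False
except ValueError:  # assumption: recent JAX raises ValueError here
    mismatch_rejected = True

# A list of the right length fits, even when the values are different
replaced = jax.tree_util.tree_unflatten(treedef, [10, 20, 30])
```

Only the count (and order) of the leaves matters; their values and types are not checked against the originals.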

Returns

The reconstructed pytree, containing the leaves placed in the structure described by treedef.
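A common pattern is to transform the flat leaves and then unflatten them to get a new pytree with the original structure. This sketch uses a hypothetical parameter dict; for simple elementwise updates, jax.tree_util.tree_map does the same thing in one call.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree used only for illustration
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

leaves, treedef = jax.tree_util.tree_flatten(params)

# Apply an elementwise update to each leaf
new_leaves = [2.0 * x + 1.0 for x in leaves]

# Place the updated leaves back into the structure described by treedef
new_params = jax.tree_util.tree_unflatten(treedef, new_leaves)

# Equivalent result via tree_map, for comparison
mapped = jax.tree_util.tree_map(lambda x: 2.0 * x + 1.0, params)
```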