jax.tree_util.tree_unflatten

jax.tree_util.tree_unflatten(treedef, leaves)

Reconstructs a pytree from the treedef and the leaves.

This is the inverse of tree_flatten().

Parameters
  • treedef (PyTreeDef) – the treedef describing the structure of the pytree to reconstruct.

  • leaves (Iterable[Any]) – the iterable of leaves to use for reconstruction. It must yield exactly as many leaves as the treedef expects (treedef.num_leaves), in the order in which tree_flatten() produces them.

Return type

Any

Returns

The reconstructed pytree, containing the leaves placed in the structure described by treedef.
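Examples

The following is a minimal sketch of a round trip through tree_flatten() and tree_unflatten(). The variable names (tree, rebuilt, doubled) and the sample values are illustrative assumptions, not part of the API. It also shows that the same treedef can be reused with a different list of leaves, provided the number of leaves matches.

```python
import jax

# An illustrative pytree: a dict holding a scalar and a nested tuple/list.
tree = {"b": (2, [3, 4]), "a": 1}

# Flatten into a list of leaves plus a treedef describing the structure.
# Dict entries are traversed in sorted key order, so "a" comes first.
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Unflattening with the original leaves reproduces the original pytree.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# The treedef can be reused with new leaves, as long as the count matches.
doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])

print(rebuilt)
print(doubled)
```

Because the treedef only records structure, a common pattern is to flatten a pytree, transform the flat list of leaves, and then unflatten the result with the same treedef.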