jax.tree.unflatten

jax.tree.unflatten(treedef, leaves)

Reconstructs a pytree from the treedef and the leaves.

Alias of jax.tree_util.tree_unflatten().

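The inverse of tree_flatten(): given leaves, treedef = jax.tree.flatten(tree), the call jax.tree.unflatten(treedef, leaves) rebuilds tree.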
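As a quick check of this inverse relationship, the sketch below flattens a nested container and rebuilds it from the unchanged leaves. The tree is an arbitrary illustration. Note that None is treated as an empty subtree by default, so it is recorded in the treedef rather than returned as a leaf.

```python
import jax

# Arbitrary example tree: dict keys, a tuple, a list, and a None entry.
tree = {"a": 1, "b": (2, 3), "c": [4, None]}

leaves, treedef = jax.tree.flatten(tree)
print(leaves)  # [1, 2, 3, 4]  (None is part of the structure, not a leaf)

# Feeding the same leaves back in reproduces the original tree.
restored = jax.tree.unflatten(treedef, leaves)
print(restored == tree)  # True
```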

Parameters:
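  • treedef (PyTreeDef) – the tree structure to rebuild, typically obtained from jax.tree.flatten() or jax.tree.structure().

  • leaves (Iterable[Any]) – the leaves to place into the structure, consumed in the same order that jax.tree.flatten() produces them. The iterable must contain exactly treedef.num_leaves items.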
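To see how the leaf count is enforced, this sketch reads the num_leaves attribute of a PyTreeDef and then unflattens with both a matching and a mismatched list. Leaves need not be numbers; any Python objects can fill the slots. The exact error message may differ between JAX versions, so the sketch only inspects the exception type.

```python
import jax

_, treedef = jax.tree.flatten([1, (2, 3)])
print(treedef.num_leaves)  # 3

# Exactly three leaves: they fill the slots in flatten order.
rebuilt = jax.tree.unflatten(treedef, ["x", "y", "z"])
print(rebuilt)  # ['x', ('y', 'z')]

# The wrong number of leaves is rejected rather than padded or truncated.
count_error = None
try:
    jax.tree.unflatten(treedef, [10, 20])
except ValueError as err:
    count_error = err
print(type(count_error).__name__)  # ValueError
```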

Return type:

Any

Returns:

The reconstructed pytree, containing the leaves placed in the structure described by treedef.

Example

>>> import jax
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
>>> newvals = [100, 200, 300, 400, 500]
>>> jax.tree.unflatten(treedef, newvals)
[100, (200, 300), [400, 500]]
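Because unflatten accepts arbitrary leaf values, it pairs naturally with flatten to move between a structured pytree and a flat representation. The hypothetical ravel helper below packs every array leaf into one 1-D vector and returns a function that unpacks such a vector back into the original structure. It assumes at least one leaf and that all leaves share a dtype (jnp.concatenate would otherwise promote them); JAX provides a more complete version as jax.flatten_util.ravel_pytree.

```python
import math

import jax
import jax.numpy as jnp


def ravel(tree):
    """Hypothetical helper: pack array leaves into one 1-D vector.

    Assumes at least one leaf and a shared dtype across leaves.
    """
    leaves, treedef = jax.tree.flatten(tree)
    shapes = [jnp.shape(x) for x in leaves]
    sizes = [math.prod(s) for s in shapes]  # math.prod(()) == 1 for scalars
    flat = jnp.concatenate([jnp.ravel(x) for x in leaves])

    def unravel(vec):
        # Slice the vector back into per-leaf chunks, in flatten order.
        new_leaves, start = [], 0
        for shape, size in zip(shapes, sizes):
            new_leaves.append(jnp.reshape(vec[start:start + size], shape))
            start += size
        # Rebuild the original container structure around the new leaves.
        return jax.tree.unflatten(treedef, new_leaves)

    return flat, unravel


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat, unravel = ravel(params)
print(flat.shape)  # (9,)
rebuilt = unravel(flat)
print(jax.tree.structure(rebuilt) == jax.tree.structure(params))  # True
```

The slice bounds are plain Python integers captured when ravel runs, so the shapes inside unravel stay static and the function can also be traced under jax.jit.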