jax.tree_util.build_tree

jax.tree_util.build_tree(treedef, xs)

Build a pytree with the structure of treedef from a nested iterable structure of values.

Parameters:
  • treedef (PyTreeDef) – the PyTreeDef structure to build.

  • xs (Any) – nested iterables with the same arity as treedef at every node

Returns:

object with structure defined by treedef

Return type:

Any
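One way to see how xs relates to treedef is to convert the nested input into the flat leaf list that tree_unflatten() expects. The sketch below is illustrative only: build_tree_sketch and _collect_leaves are hypothetical names, not part of JAX, and the sketch is not necessarily how JAX implements build_tree. It walks treedef with jax.tree_util.treedef_children() and assumes each non-leaf entry of xs is an iterable with one item per child node.

```python
import jax
from jax.tree_util import tree_unflatten, treedef_children


def _collect_leaves(treedef, xs, out):
    """Hypothetical helper: append the leaves of nested `xs` to `out` in treedef order."""
    # A treedef with one node and one leaf is a leaf: `xs` is the value itself,
    # even if that value happens to be iterable.
    if treedef.num_nodes == 1 and treedef.num_leaves == 1:
        out.append(xs)
        return
    children = treedef_children(treedef)
    items = list(xs)  # assumption: a node's entry iterates over one item per child
    if len(items) != len(children):
        raise ValueError(f"expected {len(children)} children, got {len(items)}")
    for child, item in zip(children, items):
        _collect_leaves(child, item, out)


def build_tree_sketch(treedef, xs):
    """Hypothetical equivalent of build_tree, expressed via tree_unflatten."""
    leaves = []
    _collect_leaves(treedef, xs, leaves)
    return tree_unflatten(treedef, leaves)


treedef = jax.tree.structure([(1, 2), {'a': 3, 'b': 4}])
result = build_tree_sketch(treedef, [[10, 11], [12, 13]])
# result == [(10, 11), {'a': 12, 'b': 13}]
```

Because treedef_children() lists child nodes in flattening order, the collected leaves come out in the order tree_unflatten() expects, so for the tree used in the Examples section this sketch agrees with build_tree.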

Examples

>>> import jax
>>> tree = [(1, 2), {'a': 3, 'b': 4}]
>>> treedef = jax.tree.structure(tree)

Both build_tree and jax.tree_util.tree_unflatten() can reconstruct the tree from new values, but build_tree takes the values as a nested structure that mirrors treedef, whereas tree_unflatten() takes a flat list of leaves:

>>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]])
[(10, 11), {'a': 12, 'b': 13}]
>>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])
[(10, 11), {'a': 12, 'b': 13}]
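Dict nodes flatten their children in sorted key order, so the nested entry of xs for a dict must list values in sorted key order rather than in the insertion order of the dict that produced the treedef. A small illustration, using a hypothetical parameter dict with keys 'w' and 'b':

```python
import jax

# The dict's children are ordered by sorted key: 'b' before 'w'.
treedef = jax.tree.structure(({'w': 0.0, 'b': 0.0}, [0.0]))

# One nested entry for the dict (values for 'b', then 'w') and one for the list.
params = jax.tree_util.build_tree(treedef, [[1.0, 2.0], [3.0]])
# params == ({'b': 1.0, 'w': 2.0}, [3.0])

# tree_unflatten produces the same tree from the flat leaf order.
flat = jax.tree_util.tree_unflatten(treedef, [1.0, 2.0, 3.0])
# flat == ({'b': 1.0, 'w': 2.0}, [3.0])
```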