jax.tree_util.build_tree
- jax.tree_util.build_tree(treedef, xs)
Build a pytree with the structure of treedef from a nested iterable structure.
- Parameters:
treedef (PyTreeDef) – the PyTreeDef structure to build.
xs (Any) – nested iterables matching the arity of the treedef.
- Returns:
object with structure defined by treedef
- Return type:
Any
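To make "matching the arity" concrete, here is a minimal pure-Python sketch of the idea. It is not JAX's implementation. It assumes a hypothetical stand-in `Node` for a PyTreeDef that only understands lists, tuples and dicts, plus hypothetical helpers `structure` and `build_tree_sketch`. In this sketch, each internal node consumes one iterable with exactly one element per child, and each leaf slot takes its element unchanged.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for one internal PyTreeDef node (not JAX's type)."""
    kind: str                                # "list", "tuple" or "dict"
    children: Tuple[Optional["Node"], ...]   # None marks a leaf slot
    keys: Tuple[Any, ...] = ()               # dict keys, in child order


def structure(tree: Any) -> Optional[Node]:
    """Record the container skeleton of a tree made of lists/tuples/dicts."""
    if isinstance(tree, (list, tuple)):
        kind = "list" if isinstance(tree, list) else "tuple"
        return Node(kind, tuple(structure(c) for c in tree))
    if isinstance(tree, dict):
        keys = tuple(sorted(tree))  # order dict children by sorted key
        return Node("dict", tuple(structure(tree[k]) for k in keys), keys)
    return None  # anything else is treated as a leaf


def build_tree_sketch(treedef: Optional[Node], xs: Any) -> Any:
    """Rebuild a tree shaped like `treedef` from nested iterables `xs`."""
    if treedef is None:
        return xs  # leaf slot: the element is used as-is, even if iterable
    items = list(xs)  # internal node: consume one iterable ...
    if len(items) != len(treedef.children):  # ... with one item per child
        raise ValueError(
            f"expected {len(treedef.children)} children, got {len(items)}")
    children = [build_tree_sketch(c, x) for c, x in zip(treedef.children, items)]
    if treedef.kind == "list":
        return children
    if treedef.kind == "tuple":
        return tuple(children)
    return dict(zip(treedef.keys, children))


treedef = structure([(1, 2), {'a': 3, 'b': 4}])
print(build_tree_sketch(treedef, [[10, 11], [12, 13]]))
# [(10, 11), {'a': 12, 'b': 13}]
```

Because leaf slots keep their element whole in this sketch, a leaf can itself receive a list or other container; only internal nodes are iterated. Dict children follow sorted key order here, matching how JAX orders dict children when flattening, which is why the inner `[12, 13]` maps to `'a'` and then `'b'`.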
See also
jax.tree_util.tree_unflatten()
Examples
>>> import jax
>>> tree = [(1, 2), {'a': 3, 'b': 4}]
>>> treedef = jax.tree.structure(tree)
Both build_tree and jax.tree_util.tree_unflatten() can reconstruct the tree from new values, but build_tree takes these values in terms of a nested rather than flat structure:
>>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]])
[(10, 11), {'a': 12, 'b': 13}]
>>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])
[(10, 11), {'a': 12, 'b': 13}]
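The two calls above receive the same four values, grouped differently. As a rough illustration under stated assumptions, the hypothetical `nested_to_flat` below walks a shape description (a stand-in for the treedef, where `LEAF` marks a leaf slot and a tuple lists a node's children) and flattens the nested form into leaf order. Passing its output to jax.tree_util.tree_unflatten() should then rebuild the same tree that build_tree produces from the nested input.

```python
from typing import Any, List

LEAF = "*"  # hypothetical marker for a leaf slot in the shape description


def nested_to_flat(shape: Any, xs: Any) -> List[Any]:
    """Flatten nested iterables `xs` into leaf order, guided by `shape`.

    `shape` is a hypothetical stand-in for a treedef: LEAF for a leaf, or a
    tuple with one entry per child of an internal node.
    """
    if shape == LEAF:
        return [xs]  # leaf slot: keep the value whole
    items = list(xs)
    if len(items) != len(shape):
        raise ValueError(f"expected {len(shape)} children, got {len(items)}")
    flat: List[Any] = []
    for child_shape, x in zip(shape, items):
        flat.extend(nested_to_flat(child_shape, x))
    return flat


# Shape of [(1, 2), {'a': 3, 'b': 4}]: a 2-child list whose children are a
# 2-leaf tuple and a 2-leaf dict.
shape = ((LEAF, LEAF), (LEAF, LEAF))
print(nested_to_flat(shape, [[10, 11], [12, 13]]))  # [10, 11, 12, 13]
```

The flat result is exactly the list given to tree_unflatten in the example, which is the sense in which build_tree takes the same values in nested rather than flat form.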