jax.tree_util.build_tree

jax.tree_util.build_tree(treedef, xs)
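The sketch below assumes a recent JAX release. It shows how `build_tree` is typically used. `treedef` is a `PyTreeDef`, here obtained with `jax.tree_util.tree_structure`. `xs` is a nested iterable whose nesting mirrors `treedef` node by node: one iterable per internal node and one value per leaf. The result has the structure of `treedef` with its leaves taken from `xs`. The example tree and the values 10 to 13 are illustrative only. For dict nodes, children are matched in sorted-key order, which is how JAX orders dict entries in a pytree.

```python
import jax

# Illustrative pytree: a list holding a tuple and a dict.
tree = [(1, 2), {"a": 3, "b": 4}]
treedef = jax.tree_util.tree_structure(tree)

# xs mirrors treedef: the root list has two children, the tuple has two
# leaves, and the dict's two leaves are listed in sorted-key order ("a", "b").
xs = [[10, 11], [12, 13]]
rebuilt = jax.tree_util.build_tree(treedef, xs)
print(rebuilt)  # [(10, 11), {'a': 12, 'b': 13}]

# For comparison, tree_unflatten builds the same tree from a flat list of leaves.
flat = jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])
print(flat == rebuilt)  # True
```

The two calls differ in what they expect. `tree_unflatten` takes a flat sequence of leaves. `build_tree` takes a nested structure that must match the treedef's shape at every node.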