jax.tree_util.build_tree

jax.tree_util.build_tree(treedef, xs)
Build a pytree with the structure described by treedef from a nested iterable of values.

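Parameters:
  • treedef (PyTreeDef) – the PyTreeDef describing the structure to build.

  • xs (Any) – nested iterables whose nesting and arity match treedef.
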
Return type:

Any
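
Example:

The following is a minimal sketch of how build_tree might be used, contrasted with jax.tree_util.tree_unflatten. The example tree and the replacement values are illustrative assumptions and are not taken from this page. The point it tries to show is that build_tree accepts the new values as nested iterables that mirror the tree, while tree_unflatten accepts them as a flat sequence of leaves.

```python
import jax

# An illustrative pytree: a list holding a tuple and a dict.
tree = [(1, 2), {"a": 3, "b": 4}]
treedef = jax.tree_util.tree_structure(tree)

# build_tree takes the new values as nested iterables, one iterable per
# internal node; the container types are taken from treedef, not from xs.
nested = jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]])

# tree_unflatten takes the same values as a flat sequence of leaves.
flat = jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])

print(nested)
print(flat)
```

Both calls should produce the same pytree, with the tuple and dict containers restored from treedef even though plain lists were passed in.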