jax.tree_util.build_tree

jax.tree_util.build_tree(treedef, xs)
Build a pytree with the structure described by treedef from a nested iterable structure xs.
Parameters
  • treedef (PyTreeDef) – the tree structure to build.
  • xs (Any) – nested iterables whose arity matches treedef at every level.
Return type

Any
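
Example

The following is a minimal sketch rather than official usage. It assumes that build_tree walks xs as nested iterables and takes whatever object sits at a leaf position as that leaf. The names template, xs, and rebuilt are illustrative. Because build_tree may be deprecated or missing in newer JAX releases (an assumption worth checking against your installed version), the sketch looks it up defensively and falls back to tree_unflatten with a flat list of leaves.

```python
import jax

# Describe the target structure: a tuple holding one leaf and a 2-tuple of leaves.
template = (0, (0, 0))
treedef = jax.tree_util.tree_structure(template)

# Nested iterables with the same arity as the treedef at each level.
# Lists are used here to show that the container types need not match.
xs = [10, [20, 30]]

# build_tree may not exist in every JAX version (assumption), so look it up safely.
build_tree = getattr(jax.tree_util, "build_tree", None)

if build_tree is not None:
    # Rebuild the structure from the nested iterables.
    rebuilt = build_tree(treedef, xs)
else:
    # Fallback: tree_unflatten takes the leaves as a flat sequence instead.
    rebuilt = jax.tree_util.tree_unflatten(treedef, [10, 20, 30])

print(rebuilt)
```

The difference between the two calls is the shape of the input: build_tree expects xs nested like the tree itself, while tree_unflatten expects the leaves flattened in traversal order.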