jax.tree_util.build_tree

jax.tree_util.build_tree(treedef, xs)
Parameters:
  • treedef (PyTreeDef)
  • xs (Any)

Return type:
  Any
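
The signature does not say what shape `xs` should take. The following is a minimal sketch, assuming that `build_tree` accepts nested iterables whose nesting mirrors `treedef`; this behaviour is an assumption, not something stated on this page. The function may also be deprecated or absent in some JAX releases, so the lookup is guarded. The example tree and variable names are illustrative only. `tree_flatten` and `tree_unflatten` are shown alongside as the flat-list counterpart.

```python
import jax

# An illustrative pytree: a list holding a tuple and a dict.
tree = [(1, 2), {"a": 3, "b": 4}]

# tree_flatten returns the flat list of leaves and a PyTreeDef
# describing the container structure.
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Rebuild the same structure from a flat list of new leaves.
rebuilt_flat = jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])
print(rebuilt_flat)  # [(10, 11), {'a': 12, 'b': 13}]

# Assumption: build_tree takes the new leaves as nested iterables that
# mirror the treedef (one inner iterable per container node), rather
# than as a flat list. The lookup is guarded because the function may
# not be available in every JAX version.
build_tree = getattr(jax.tree_util, "build_tree", None)
rebuilt_nested = None
if build_tree is not None:
    rebuilt_nested = build_tree(treedef, [[10, 11], [12, 13]])
    print(rebuilt_nested)
```

If the assumption holds, both calls produce the same tree. When the new leaves are already available as a flat list, `tree_unflatten` is the more direct choice.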