jax.tree_util.register_pytree_with_keys


jax.tree_util.register_pytree_with_keys(nodetype, flatten_with_keys, unflatten_func, flatten_func=None)

Extends the set of types that are considered internal nodes in pytrees.

This is a more powerful alternative to register_pytree_node that allows you to access each pytree leaf’s key path when flattening and tree-mapping.
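To illustrate the difference from register_pytree_node, the sketch below registers two otherwise identical classes, one with each function, and compares the key paths reported by jax.tree_util.tree_flatten_with_path. The class names PlainPoint and KeyedPoint are hypothetical. With register_pytree_node, JAX can only label children by their position (FlattenedIndexKey), whereas flatten_with_keys lets each child carry a descriptive key such as GetAttrKey.

```python
from jax.tree_util import (
    FlattenedIndexKey,
    GetAttrKey,
    keystr,
    register_pytree_node,
    register_pytree_with_keys,
    tree_flatten_with_path,
)


class PlainPoint:
    """Hypothetical 2D point registered WITHOUT keys."""
    def __init__(self, x, y):
        self.x, self.y = x, y


class KeyedPoint:
    """Hypothetical 2D point registered WITH keys."""
    def __init__(self, x, y):
        self.x, self.y = x, y


# register_pytree_node: children are returned without keys, so key-aware
# utilities can only identify them by their flattened position.
register_pytree_node(
    PlainPoint,
    lambda p: ((p.x, p.y), None),                 # (children, aux_data)
    lambda aux, children: PlainPoint(*children),
)

# register_pytree_with_keys: each child is paired with a key entry.
register_pytree_with_keys(
    KeyedPoint,
    lambda p: (((GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y)), None),
    lambda aux, children: KeyedPoint(*children),
)

plain_leaves, _ = tree_flatten_with_path(PlainPoint(1.0, 2.0))
keyed_leaves, _ = tree_flatten_with_path(KeyedPoint(1.0, 2.0))

plain_keys = [path for path, _ in plain_leaves]
keyed_paths = [keystr(path) for path, _ in keyed_leaves]

print(plain_keys)   # positional FlattenedIndexKey entries
print(keyed_paths)  # ['.x', '.y']
```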

Parameters:
  • nodetype (type[T]) – a Python type to treat as an internal pytree node.

  • flatten_with_keys (Callable[[T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData]]) – a function to be used during flattening, taking a value of type nodetype and returning a pair, with (1) an iterable for tuples of each key path and its child, and (2) some hashable auxiliary data to be stored in the treedef and to be passed to the unflatten_func.

  • unflatten_func (Callable[[_AuxData, Iterable[Any]], T]) – a function taking two arguments: the auxiliary data that was returned by flatten_with_keys (or flatten_func) and stored in the treedef, and an iterable of the node's children. The function should return an instance of nodetype.

  • flatten_func (None | Callable[[T], tuple[Iterable[Any], _AuxData]]) – an optional function similar to flatten_with_keys that returns only the children and the auxiliary data, without keys. It must return the children in the same order as flatten_with_keys and the same auxiliary data. It is needed only for faster traversal in functions that do not use keys, such as tree_map and tree_flatten.
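The sketch below passes all four arguments for a hypothetical Layer class (the class and the helper names layer_flatten_with_keys, layer_unflatten, and layer_flatten are illustrative, not part of JAX). The static name attribute is carried as hashable auxiliary data, and layer_flatten returns the same children in the same order, plus the same aux data, as layer_flatten_with_keys, as the parameter description requires.

```python
from jax.tree_util import (
    GetAttrKey,
    keystr,
    register_pytree_with_keys,
    tree_flatten,
    tree_map,
    tree_map_with_path,
    tree_unflatten,
)


class Layer:
    """Hypothetical container: two leaf children plus a static name."""
    def __init__(self, weight, bias, name):
        self.weight = weight
        self.bias = bias
        self.name = name  # static metadata, stored as aux data


def layer_flatten_with_keys(layer):
    # Pair each child with a key entry; aux data must be hashable.
    children = ((GetAttrKey("weight"), layer.weight),
                (GetAttrKey("bias"), layer.bias))
    return children, layer.name


def layer_unflatten(name, children):
    # Receives the aux data and the children, in flattening order.
    weight, bias = children
    return Layer(weight, bias, name)


def layer_flatten(layer):
    # Same child order and same aux data as layer_flatten_with_keys, no keys.
    return (layer.weight, layer.bias), layer.name


register_pytree_with_keys(
    Layer,
    layer_flatten_with_keys,
    layer_unflatten,
    flatten_func=layer_flatten,
)

layer = Layer(1.0, 0.5, "dense")

# Key-aware traversal: each leaf is replaced by its key path string.
labels = tree_map_with_path(lambda path, leaf: keystr(path), layer)
print(labels.weight, labels.bias)            # .weight .bias

# Key-free traversal, which may use the faster layer_flatten.
doubled = tree_map(lambda leaf: leaf * 2, layer)
print(doubled.weight, doubled.bias, doubled.name)  # 2.0 1.0 dense

# Round trip through flatten/unflatten preserves the aux data.
leaves, treedef = tree_flatten(layer)
restored = tree_unflatten(treedef, leaves)
print(leaves, restored.name)                 # [1.0, 0.5] dense
```

Because both flattening functions agree on child order and aux data, the key-aware and key-free traversals see the same structure; only the key-aware one can report where each leaf lives.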