jax.tree_util.register_pytree_with_keys
- jax.tree_util.register_pytree_with_keys(nodetype, flatten_with_keys, unflatten_func, flatten_func=None)
Extends the set of types that are considered internal nodes in pytrees.
This is a more powerful alternative to register_pytree_node that lets you access each pytree leaf’s key path when flattening and tree-mapping.
- Parameters:
nodetype (type[T]) – a Python type to treat as an internal pytree node.
flatten_with_keys (Callable[[T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData]]) – a function used during flattening. It takes a value of type nodetype and returns a pair: (1) an iterable of (key, child) tuples, one per child, where each key is a KeyEntry identifying that child within the node; and (2) some hashable auxiliary data, which is stored in the treedef and passed to unflatten_func.
unflatten_func (Callable[[_AuxData, Iterable[Any]], T]) – a function taking two arguments: the auxiliary data that was returned by flatten_with_keys (or flatten_func) and stored in the treedef, and the unflattened children. It should return an instance of nodetype.
flatten_func (None | Callable[[T], tuple[Iterable[Any], _AuxData]]) – an optional function similar to flatten_with_keys, but returning only the children and the auxiliary data. It must return the children in the same order as flatten_with_keys, and the same auxiliary data. This argument is only needed for faster traversal when calling functions that do not use keys, such as tree_map and tree_flatten.
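A minimal sketch of how the three callables fit together, assuming a recent JAX release. The `Point` class and the helpers `point_flatten_with_keys`, `point_unflatten` and `point_flatten` are hypothetical names used only for illustration. The sketch uses `GetAttrKey` as the key entry for each child, then inspects the resulting key paths with `tree_flatten_with_path` and `keystr`, and runs a keyless `tree_map`.

```python
from jax.tree_util import (
    GetAttrKey,
    keystr,
    register_pytree_with_keys,
    tree_flatten_with_path,
    tree_map,
)


class Point:
    """Hypothetical container: `x` and `y` are children, `label` is static aux data."""

    def __init__(self, x, y, label):
        self.x = x
        self.y = y
        self.label = label


def point_flatten_with_keys(p):
    # Pair each child with a KeyEntry; GetAttrKey("x") renders as ".x" in a key path.
    children = ((GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y))
    return children, p.label  # aux data must be hashable


def point_unflatten(label, children):
    # Receives the aux data stored in the treedef plus the children, in order.
    x, y = children
    return Point(x, y, label)


def point_flatten(p):
    # Optional keyless path: same children, same order, same aux data.
    return (p.x, p.y), p.label


register_pytree_with_keys(
    Point, point_flatten_with_keys, point_unflatten, flatten_func=point_flatten
)

tree = {"a": Point(1.0, 2.0, "origin")}

# Key paths are built from the KeyEntry values returned by flatten_with_keys.
leaves_with_paths, treedef = tree_flatten_with_path(tree)
for path, leaf in leaves_with_paths:
    print(keystr(path), leaf)
# ['a'].x 1.0
# ['a'].y 2.0

# A keyless traversal; the label survives because it lives in the treedef.
doubled = tree_map(lambda v: v * 2, tree)
print(doubled["a"].x, doubled["a"].y, doubled["a"].label)
# 2.0 4.0 origin
```

Supplying `flatten_func` is optional here: without it, the registration still works, and keyless functions simply go through `flatten_with_keys` and drop the keys.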