jax.tree_util.register_pytree_node

jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)

Extends the set of types that are considered internal nodes in pytrees.

See example usage.
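A minimal sketch of the basic pattern, assuming a hypothetical Point class defined only for illustration: once the class is registered, pytree utilities such as jax.tree_util.tree_map and transformations such as jax.jit can see through it to its leaves.

```python
import jax
from jax import tree_util


class Point:
    """Hypothetical 2-D point whose coordinates are treated as pytree leaves."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def _point_flatten(p):
    # Children to flatten recursively, plus aux data (nothing static here, so None).
    return (p.x, p.y), None


def _point_unflatten(aux_data, children):
    # Rebuild a Point from the (already unflattened) children.
    return Point(*children)


tree_util.register_pytree_node(Point, _point_flatten, _point_unflatten)

p = Point(1.0, 2.0)
doubled = tree_util.tree_map(lambda leaf: leaf * 2, p)
print(type(doubled).__name__, doubled.x, doubled.y)  # Point 2.0 4.0

# jit accepts the registered type as an argument and traces its leaves.
total = jax.jit(lambda q: q.x + q.y)(p)
```

Without the registration call, tree_map would treat the whole Point object as a single leaf, and jax.jit would reject it as an argument.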

Parameters:
  • nodetype (type[T]) – a Python type to treat as an internal pytree node.

  • flatten_func (Callable[[T], tuple[_Children, _AuxData]], with _Children bound to Iterable[Any] and _AuxData bound to Hashable) – a function to be used during flattening. It takes a value of type nodetype and returns a pair: (1) an iterable of the children to be flattened recursively, and (2) some hashable auxiliary data to be stored in the treedef and passed to unflatten_func.

  • unflatten_func (Callable[[_AuxData, _Children], T]) – a function taking two arguments: the auxiliary data that was returned by flatten_func and stored in the treedef, and the unflattened children. The function should return an instance of nodetype.
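The following sketch illustrates the two roles of flatten_func's return value and the argument order of unflatten_func. The Labeled class and its label attribute are hypothetical: values holds children that are flattened recursively, while label is hashable auxiliary data that lives in the treedef rather than among the leaves.

```python
from jax import tree_util


class Labeled:
    """Hypothetical container: `values` are children, `label` is static aux data."""

    def __init__(self, values, label):
        self.values = values  # list of leaves or nested pytrees
        self.label = label    # hashable string, stored in the treedef

    def __repr__(self):
        return f"Labeled({self.values!r}, label={self.label!r})"


def _flatten(node):
    # Returns (1) the children to flatten recursively, (2) hashable aux data.
    return tuple(node.values), node.label


def _unflatten(label, children):
    # Receives the aux data first, then the unflattened children.
    return Labeled(list(children), label)


tree_util.register_pytree_node(Labeled, _flatten, _unflatten)

tree = Labeled([1, (2, 3)], label="weights")
leaves, treedef = tree_util.tree_flatten(tree)
print(leaves)   # [1, 2, 3]

# New leaves are slotted back in; the label comes from the treedef.
rebuilt = tree_util.tree_unflatten(treedef, [10, 20, 30])
print(rebuilt)  # Labeled([10, (20, 30)], label='weights')
```

Because label never appears in leaves, transformations leave it untouched, which suits static configuration that should not be traced.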