jax.tree_util.register_pytree_node
- jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)
Extends the set of types that are considered internal nodes in pytrees.
See example usage.
- Parameters:
  - nodetype (Type[T]) – a Python type to treat as an internal pytree node.
  - flatten_func (Callable[[T], Tuple[_Children, _AuxData]]) – a function to be used during flattening, taking a value of type nodetype and returning a pair, with (1) an iterable for the children to be flattened recursively, and (2) some hashable auxiliary data to be stored in the treedef and to be passed to the unflatten_func.
  - unflatten_func (Callable[[_AuxData, _Children], T]) – a function taking two arguments: the auxiliary data that was returned by flatten_func and stored in the treedef, and the unflattened children. The function should return an instance of nodetype.

  Here T, _Children, and _AuxData are type variables: _Children is bound to Iterable[Any] and _AuxData is bound to Hashable.
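As an illustration of the three parameters, here is a minimal sketch of registering a custom container. The Point class and the point_flatten and point_unflatten helpers are hypothetical names invented for this example; only register_pytree_node, tree_flatten, tree_unflatten, and tree_map come from jax.tree_util.

```python
from jax import tree_util


class Point:
    """Hypothetical container: x and y are leaves, name is static metadata."""

    def __init__(self, x, y, name):
        self.x = x
        self.y = y
        self.name = name


def point_flatten(p):
    # Children are flattened recursively; aux_data must be hashable
    # because it is stored in the treedef.
    children = (p.x, p.y)
    aux_data = p.name
    return children, aux_data


def point_unflatten(aux_data, children):
    # Receives the aux_data stored in the treedef and the unflattened
    # children, and rebuilds an instance of the registered type.
    x, y = children
    return Point(x, y, aux_data)


# Register Point as an internal pytree node (do this once per type).
tree_util.register_pytree_node(Point, point_flatten, point_unflatten)

p = Point(1.0, 2.0, "unit")

# Flattening yields the leaves plus a treedef that carries the aux_data.
leaves, treedef = tree_util.tree_flatten(p)

# Unflattening reconstructs a Point from the treedef and leaves.
restored = tree_util.tree_unflatten(treedef, leaves)

# tree_map applies the function to each leaf and rebuilds the container.
doubled = tree_util.tree_map(lambda v: v * 2, p)
```

Putting name in the auxiliary data rather than in the children keeps it out of tree_map and other leaf-wise transformations, at the cost of requiring it to be hashable.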