jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)
Extends the set of types that are considered internal nodes in pytrees.
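The sketch below is a minimal illustration of registering a custom type. It assumes a hypothetical `Point` class with two array-valued fields. Once the type is registered, tree utilities such as `tree_map` and transformations such as `jax.jit` treat `Point` as an internal node and its fields as leaves.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Point:
    """Hypothetical container whose two fields are pytree children."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def point_flatten(p):
    # Children (x, y) are flattened recursively; there is no static aux data.
    return (p.x, p.y), None


def point_unflatten(aux_data, children):
    # Rebuild a Point from the (possibly transformed) children.
    x, y = children
    return Point(x, y)


tree_util.register_pytree_node(Point, point_flatten, point_unflatten)

p = Point(jnp.array(1.0), jnp.array(2.0))

# tree_map visits each leaf and reassembles the result as a Point.
doubled = tree_util.tree_map(lambda leaf: leaf * 2, p)

# jit can accept the registered type directly as an argument.
total = jax.jit(lambda q: q.x + q.y)(p)
```

Because unflattening is also applied to tracers and other placeholder values during transformations, the hypothetical `Point.__init__` above deliberately does no validation of its arguments.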
Parameters:
- nodetype (type[T]) – a Python type to treat as an internal pytree node.
- flatten_func (Callable[[T], tuple[Iterable[Any], Hashable]]) – a function to be used during flattening, taking a value of type nodetype and returning a pair, with (1) an iterable for the children to be flattened recursively, and (2) some hashable auxiliary data to be stored in the treedef and to be passed to the unflatten_func.
- unflatten_func (Callable[[Hashable, Iterable[Any]], T]) – a function taking two arguments: the auxiliary data that was returned by flatten_func and stored in the treedef, and the unflattened children. The function should return an instance of nodetype.
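The following sketch shows how the two halves of the pair returned by flatten_func are used. It assumes a hypothetical `Scaled` class holding one array child and a static string label. The label is returned as auxiliary data, so it is not a leaf: it is kept in the treedef, handed back to unflatten_func as its first argument, and taken into account when treedefs are compared.

```python
import jax.numpy as jnp
from jax import tree_util


class Scaled:
    """Hypothetical node: one array child plus a static, hashable label."""

    def __init__(self, value, label):
        self.value = value
        self.label = label  # static metadata, stored as aux data, not a leaf


def scaled_flatten(node):
    # (1) iterable of children, (2) hashable auxiliary data for the treedef
    return (node.value,), node.label


def scaled_unflatten(aux_data, children):
    # aux_data is the second element that scaled_flatten returned
    (value,) = children
    return Scaled(value, aux_data)


tree_util.register_pytree_node(Scaled, scaled_flatten, scaled_unflatten)

node = Scaled(jnp.arange(3.0), "weights")
leaves, treedef = tree_util.tree_flatten(node)    # only the array is a leaf
rebuilt = tree_util.tree_unflatten(treedef, leaves)

# Same children layout but different aux data yields a different treedef.
other_def = tree_util.tree_structure(Scaled(jnp.arange(3.0), "bias"))
same_def = tree_util.tree_structure(Scaled(jnp.zeros(3), "weights"))
```

Keeping non-array configuration in the auxiliary data rather than among the children means it stays out of tree_map and is treated as static structure, which is why it must be hashable.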