jax.tree_util.register_pytree_with_keys_class
jax.tree_util.register_pytree_with_keys_class(cls)
Extends the set of types that are considered internal nodes in pytrees.
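This function is similar to register_pytree_node_class, but requires a class that defines how it can be flattened with keys. The class must provide a tree_flatten_with_keys method, which returns the (key, child) pairs together with auxiliary data, and a tree_unflatten classmethod. It is a thin wrapper around register_pytree_with_keys and provides a class-oriented interface.
- Parameters: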
cls (Typ) – a type to register as a pytree
- Returns:
The input class cls is returned unchanged after being added to JAX’s pytree registry. This return value allows register_pytree_with_keys_class to be used as a decorator.
- Return type:
Typ
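Because the class is returned unchanged, calling the function on an already-defined class has the same effect as using it with decorator syntax. The sketch below shows that plain call and, for comparison, registers a structurally identical class directly through register_pytree_with_keys, roughly what the thin wrapper does on your behalf. The classes Point and PointFn and the helper functions are hypothetical, and the direct registration is written by hand for illustration rather than copied from JAX's implementation.

```python
import jax
from jax.tree_util import (
    GetAttrKey,
    register_pytree_with_keys,
    register_pytree_with_keys_class,
)


class Point:
    """Hypothetical two-field container, registered with a plain call below."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten_with_keys(self):
        # Pair each child with the key that locates it inside this node.
        children = ((GetAttrKey("x"), self.x), (GetAttrKey("y"), self.y))
        return children, None  # no static auxiliary data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Plain call instead of @-syntax: the very same class object comes back.
Registered = register_pytree_with_keys_class(Point)


class PointFn:
    """Hypothetical equivalent, registered with the functional API instead."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def point_fn_flatten_with_keys(p):
    # Same (key, child) pairs and aux data as Point.tree_flatten_with_keys.
    return ((GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y)), None


def point_fn_unflatten(aux_data, children):
    return PointFn(*children)


register_pytree_with_keys(PointFn, point_fn_flatten_with_keys, point_fn_unflatten)

# Both registrations expose the same leaves for the same field values.
leaves_a = jax.tree_util.tree_leaves(Point(1, 2))
leaves_b = jax.tree_util.tree_leaves(PointFn(1, 2))
```

The class-based form keeps the flattening logic next to the data it describes, while the functional form is useful for types you cannot modify.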
See also
register_static(): simpler API for registering a static pytree.
register_dataclass(): simpler API for registering a dataclass.
Examples
>>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
>>> @register_pytree_with_keys_class
... class Special:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten_with_keys(self):
...     return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)
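The example above only defines and registers the class. A minimal sketch of what the keys are used for follows, assuming a JAX version that provides jax.tree_util.tree_flatten_with_path and jax.tree_util.keystr: flattening with paths reports each leaf alongside the GetAttrKey entries returned by tree_flatten_with_keys, and tree_map rebuilds Special instances through tree_unflatten. The surrounding dictionary and the doubling function are illustrative choices only.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import GetAttrKey, register_pytree_with_keys_class


@register_pytree_with_keys_class
class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten_with_keys(self):
        return (((GetAttrKey("x"), self.x), (GetAttrKey("y"), self.y)), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Nest a Special inside a dict so each path has more than one key.
tree = {"a": Special(jnp.ones(2), 3.0)}

# Each leaf comes back paired with the sequence of keys that reaches it.
path_leaves, treedef = jax.tree_util.tree_flatten_with_path(tree)
paths = [jax.tree_util.keystr(path) for path, _ in path_leaves]

# tree_map flattens, transforms every leaf, then calls Special.tree_unflatten.
doubled = jax.tree_util.tree_map(lambda v: v * 2, tree)
```

Here the dictionary key contributes ['a'] and each GetAttrKey contributes .x or .y to the printed path, which is what makes error messages and path-based utilities readable for custom classes.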