jax.tree_util.register_pytree_node_class

jax.tree_util.register_pytree_node_class(cls)

Extends the set of types that are considered internal nodes in pytrees.
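To illustrate what "internal node" means in practice, the sketch below compares an unregistered class with one registered through register_pytree_node_class. The class names Plain and Pair are hypothetical. JAX treats the unregistered instance as a single opaque leaf, while the registered instance is flattened into its children.

```python
from jax import tree_util

class Plain:
    """Not registered: JAX treats each instance as one opaque leaf."""
    def __init__(self, a, b):
        self.a = a
        self.b = b

@tree_util.register_pytree_node_class
class Pair:
    """Registered: JAX treats each instance as an internal node with two children."""
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        # Return (children, aux_data); there is no static data here.
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

plain_leaves = tree_util.tree_leaves(Plain(1.0, 2.0))
pair_leaves = tree_util.tree_leaves(Pair(1.0, 2.0))
print(len(plain_leaves))  # 1: the Plain object itself is the leaf
print(pair_leaves)        # [1.0, 2.0]
```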

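This function is a thin wrapper around register_pytree_node() and provides a class-oriented interface: the registered class supplies its own tree_flatten method and tree_unflatten classmethod, which serve as the flatten and unflatten functions.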
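A rough sketch of the equivalence, assuming the behavior described above rather than reproducing the library's exact source: the registration below passes a class's two methods directly to register_pytree_node. Interval is a hypothetical class used only for illustration.

```python
from jax import tree_util

class Interval:
    def __init__(self, lo, hi):
        self.lo = lo
        self.hi = hi

    def tree_flatten(self):
        return (self.lo, self.hi), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

# Manual registration: roughly what the decorator arranges for you.
tree_util.register_pytree_node(
    Interval,
    lambda obj: obj.tree_flatten(),  # flatten_func(obj) -> (children, aux_data)
    Interval.tree_unflatten,         # unflatten_func(aux_data, children) -> obj
)

leaves, treedef = tree_util.tree_flatten(Interval(0.0, 1.0))
rebuilt = tree_util.tree_unflatten(treedef, [x * 10 for x in leaves])
print(leaves)                  # [0.0, 1.0]
print(rebuilt.lo, rebuilt.hi)  # 0.0 10.0
```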

Parameters:

cls (Typ) – a type to register as a pytree

Returns:

The input class cls is returned unchanged after being added to JAX’s pytree registry. This return value allows register_pytree_node_class to be used as a decorator.
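Because the class is returned unchanged, the call form and the decorator form are interchangeable. The sketch below uses a hypothetical class Box and checks that the returned object is the very same class.

```python
from jax import tree_util

class Box:
    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

# Call form: equivalent to decorating Box with @register_pytree_node_class.
Registered = tree_util.register_pytree_node_class(Box)
print(Registered is Box)                # True: the same class object is returned
print(tree_util.tree_leaves(Box(3.0)))  # [3.0]
```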

Return type:

Typ

See also: jax.tree_util.register_pytree_node()

Examples

Here we’ll define a custom container that will be compatible with jax.jit() and other JAX transformations:

>>> import jax
>>> import jax.numpy as jnp
>>> @jax.tree_util.register_pytree_node_class
... class MyContainer:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten(self):
...     return ((self.x, self.y), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)
...
>>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
>>> def f(m):
...   return m.x + 2 * m.y
>>> jax.jit(f)(m)
Array([0., 2., 4., 6.], dtype=float32)
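The example above passes None as the auxiliary data. The sketch below, using a hypothetical class Labeled, shows the other slot in use: a non-array field (a string label) travels in aux_data, so JAX transformations such as tree_map and jit only see the array child while the label is preserved. Since JAX compares aux_data when checking whether two trees share a structure, it should generally be hashable and support equality.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class Labeled:
    def __init__(self, label, value):
        self.label = label  # static metadata, carried as aux_data
        self.value = value  # array data, the only child

    def tree_flatten(self):
        # Only `value` is a child; `label` is auxiliary data.
        return (self.value,), self.label

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data, *children)

x = Labeled("weights", jnp.arange(3.0))

# tree_map touches only the child array; the label is carried through.
doubled = tree_util.tree_map(lambda v: v * 2, x)
print(doubled.label)  # weights

@jax.jit
def scale(m):
    # Inside the traced function, m.label is still a plain Python string.
    return Labeled(m.label, m.value * 3)

print(scale(x).label)  # weights
```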