jax.tree_util.register_dataclass(nodetype, data_fields, meta_fields)

Extends the set of types that are considered internal nodes in pytrees.

This differs from register_pytree_with_keys_class in that the C++ registry flattens and unflattens instances with its optimized built-in dataclass handling, instead of calling Python flatten and unflatten functions passed in as arguments.
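To make the contrast concrete, here is a minimal sketch that registers two otherwise identical dataclasses: one with register_dataclass, and one with the callback-based register_pytree_node (used here instead of register_pytree_with_keys_class only to keep the sketch short). The class names PointA and PointB and the helpers _flatten and _unflatten are hypothetical. Both versions should behave the same under pytree utilities; only the mechanism differs.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.tree_util import register_dataclass, register_pytree_node


@dataclass
class PointA:
  x: jax.Array
  y: jax.Array
  name: str


# Built-in path: JAX reads the listed fields itself; no Python callbacks.
register_dataclass(PointA, data_fields=['x', 'y'], meta_fields=['name'])


@dataclass
class PointB:
  x: jax.Array
  y: jax.Array
  name: str


# Callback path: user-supplied Python functions describe the node.
def _flatten(p):
  children = (p.x, p.y)  # subtrees that JAX will traverse
  aux_data = p.name      # static metadata stored in the treedef
  return children, aux_data


def _unflatten(aux_data, children):
  x, y = children
  return PointB(x=x, y=y, name=aux_data)


register_pytree_node(PointB, _flatten, _unflatten)

a = PointA(x=jnp.ones(2), y=jnp.zeros(2), name='a')
b = PointB(x=jnp.ones(2), y=jnp.zeros(2), name='b')

# Both flatten to the same two leaves; the name travels as metadata.
print(len(jax.tree.leaves(a)), len(jax.tree.leaves(b)))  # 2 2
```

The callback form is more flexible (for example, it can compute derived children), while register_dataclass avoids per-call Python overhead for the common case where the fields map directly onto pytree children.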

See Extending pytrees for more information about registering pytrees.

Parameters:

  • nodetype (Typ) – a Python type to treat as an internal pytree node. This is assumed to have the semantics of a dataclass: namely, class attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among meta_fields or data_fields.

  • meta_fields (Sequence[str]) – auxiliary data field names. These fields must contain static, hashable, immutable objects, as these objects are used to generate JIT cache keys. In particular, meta_fields cannot contain jax.Array or numpy.ndarray objects.

  • data_fields (Sequence[str]) – data field names. These fields must contain JAX-compatible objects such as arrays (jax.Array or numpy.ndarray), scalars, or pytrees whose leaves are arrays or scalars. Note that the value of a data field may be None, since JAX treats None as an empty pytree.
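The two kinds of fields behave differently once flattened, as this sketch with a hypothetical Layer class shows: a data field holding None contributes no leaves, while a meta field is stored in the tree structure (treedef), so two instances that differ only in a meta value have different structures.

```python
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp


@dataclass
class Layer:
  weight: jax.Array
  bias: Optional[jax.Array]  # data field that may hold None
  activation: str            # meta field: static and hashable


jax.tree_util.register_dataclass(
    Layer, data_fields=['weight', 'bias'], meta_fields=['activation'])

with_bias = Layer(weight=jnp.ones((2, 2)), bias=jnp.zeros(2), activation='relu')
no_bias = Layer(weight=jnp.ones((2, 2)), bias=None, activation='relu')

# None is an empty pytree, so it adds no leaves.
print(len(jax.tree.leaves(with_bias)))  # 2
print(len(jax.tree.leaves(no_bias)))    # 1

# Meta values live in the treedef, so changing one changes the structure.
tanh_layer = Layer(weight=jnp.ones((2, 2)), bias=jnp.zeros(2), activation='tanh')
print(jax.tree.structure(with_bias) == jax.tree.structure(tanh_layer))  # False
```

This is also why meta fields must be hashable: the treedef, including those values, takes part in jax.jit cache keys.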


Returns:

The input class nodetype is returned unchanged after being added to JAX’s pytree registry. This return value allows register_dataclass to be partially applied (for example with functools.partial) and used as a decorator, as in the example below.

Return type:

Typ

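Because the class is returned unchanged, register_dataclass can also be called directly, without a decorator. A minimal sketch, using a hypothetical Interval class:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Interval:
  lo: jax.Array
  hi: jax.Array
  closed: bool  # meta field


# Plain call instead of a decorator: the same class object comes back.
registered = jax.tree_util.register_dataclass(
    Interval, data_fields=['lo', 'hi'], meta_fields=['closed'])
print(registered is Interval)  # True

iv = Interval(lo=jnp.zeros(()), hi=jnp.ones(()), closed=True)
print(len(jax.tree.leaves(iv)))  # 2
```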
Example:

>>> import jax
>>> import jax.numpy as jnp
>>> from dataclasses import dataclass
>>> from functools import partial
>>> @partial(jax.tree_util.register_dataclass,
...          data_fields=['x', 'y'],
...          meta_fields=['op'])
... @dataclass
... class MyStruct:
...   x: jax.Array
...   y: jax.Array
...   op: str
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

Now that this class is registered, it can be used with pytree utilities such as those in jax.tree and jax.tree_util:

>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
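Other jax.tree utilities follow the same rule: they act on the data fields and carry the meta fields through unchanged. The sketch below redefines a MyStruct-like class under the hypothetical name Pair so that it runs on its own.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Pair:
  x: jax.Array
  y: jax.Array
  op: str


jax.tree_util.register_dataclass(Pair, data_fields=['x', 'y'], meta_fields=['op'])

p = Pair(x=jnp.ones(3), y=jnp.arange(3.0), op='add')

# tree.map visits only x and y; op is copied into the result as-is.
doubled = jax.tree.map(lambda leaf: 2 * leaf, p)
print(doubled.op)          # add
print(doubled.y.tolist())  # [0.0, 2.0, 4.0]

# Trees with identical structure (including op) combine leaf by leaf.
summed = jax.tree.map(jnp.add, p, doubled)
```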

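In particular, this registration allows m to be passed seamlessly through code wrapped in jax.jit() and other JAX transformations. Because op is a meta field, it reaches the traced function as an ordinary Python string, so the Python if statement below is evaluated at trace time, while the data fields x and y are traced as arrays: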

>>> @jax.jit
... def compiled_func(m):
...   if m.op == 'add':
...     return m.x + m.y
...   else:
...     raise ValueError(f"{m.op=}")
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)
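The same holds for differentiation. In this sketch, with a hypothetical Params class and a toy squared-output loss, jax.grad returns gradients packaged in the same dataclass: the data fields hold gradient arrays and the meta field is copied over, not differentiated.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Params:
  w: jax.Array
  b: jax.Array
  name: str  # meta field: static, never differentiated


jax.tree_util.register_dataclass(
    Params, data_fields=['w', 'b'], meta_fields=['name'])


def loss(params, x):
  # Squared output of an elementwise linear model, summed to a scalar.
  pred = params.w * x + params.b
  return jnp.sum(pred ** 2)


params = Params(w=jnp.array([1.0, 2.0]), b=jnp.array([0.5, 0.5]), name='linear')
x = jnp.array([1.0, -1.0])

# Gradient with respect to the first argument, returned as a Params instance.
grads = jax.grad(loss)(params, x)
print(grads.name)        # linear
print(grads.w.tolist())  # [3.0, 3.0]
print(grads.b.tolist())  # [3.0, -3.0]
```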