jax.tree_util.register_dataclass

jax.tree_util.register_dataclass(nodetype, data_fields, meta_fields)

Extends the set of types that are considered internal nodes in pytrees.
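A rough illustration of what "internal node" means in practice: before registration, an instance of a plain dataclass is treated as a single opaque leaf; after registration, its data fields become the leaves. `Pair` is a hypothetical class chosen for this sketch.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Pair:  # hypothetical example class
  a: jax.Array
  b: jax.Array


p = Pair(a=jnp.zeros(2), b=jnp.ones(2))

# Unregistered: the whole Pair instance is one leaf.
leaves_before = jax.tree_util.tree_leaves(p)

jax.tree_util.register_dataclass(Pair, data_fields=['a', 'b'], meta_fields=[])

# Registered: Pair is an internal node, and its two arrays are the leaves.
leaves_after = jax.tree_util.tree_leaves(p)
```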

This differs from register_pytree_with_keys_class in that JAX's C++ pytree registries flatten and unflatten the type with an optimized built-in dataclass handler, rather than by calling user-supplied flatten and unflatten functions.
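For comparison, here is a sketch of the kind of hand-written registration that register_dataclass lets you skip: with register_pytree_with_keys_class, the class itself must provide `tree_flatten_with_keys` and `tree_unflatten` methods that JAX calls back into. `ManualStruct` is a hypothetical class mirroring the `MyStruct` example further down.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import GetAttrKey, register_pytree_with_keys_class


@register_pytree_with_keys_class
class ManualStruct:  # hypothetical class registered with explicit callbacks
  def __init__(self, x, y, op):
    self.x = x
    self.y = y
    self.op = op

  def tree_flatten_with_keys(self):
    # Children are (key, value) pairs; `op` is returned as static aux data.
    children = ((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y))
    return children, (self.op,)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # Rebuild the object from the aux data and the (key-less) children.
    (op,) = aux_data
    x, y = children
    return cls(x=x, y=y, op=op)


s = ManualStruct(jnp.ones(3), jnp.arange(3), 'add')
leaves, treedef = jax.tree_util.tree_flatten(s)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

With register_dataclass, none of these methods are needed: listing the field names is enough, and the flattening logic runs in C++.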

See Extending pytrees for more information about registering pytrees.

Parameters:
  • nodetype (Typ) – a Python type to treat as an internal pytree node. This is assumed to have the semantics of a dataclass: namely, class attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among meta_fields or data_fields.

  • data_fields (Sequence[str]) – data field names. These fields must be JAX-compatible objects such as arrays (jax.Array or numpy.ndarray), scalars, or pytrees whose leaves are arrays or scalars. Note that the value of a data field may be None, since JAX treats None as an empty pytree.

  • meta_fields (Sequence[str]) – auxiliary data field names. These fields must contain static, hashable, immutable objects, as these objects are used to generate JIT cache keys. In particular, meta_fields cannot contain jax.Array or numpy.ndarray objects.
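A minimal sketch of how the two kinds of fields behave, assuming a hypothetical `Params` class: a data field holding None contributes no leaves, while meta field values are stored in the tree structure (treedef). Because jit keys its cache on that structure, changing a meta value causes a retrace, which the `traces` list records via a Python side effect that only runs during tracing.

```python
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp


@dataclass
class Params:  # hypothetical example class
  w: jax.Array
  b: Optional[jax.Array]  # data field whose value may be None
  name: str               # static, hashable metadata


jax.tree_util.register_dataclass(
    Params, data_fields=['w', 'b'], meta_fields=['name'])

p = Params(w=jnp.ones(2), b=None, name='layer')
q = Params(w=jnp.ones(2), b=None, name='other')

# b=None is an empty pytree, so only `w` shows up as a leaf.
leaves, treedef_p = jax.tree_util.tree_flatten(p)
_, treedef_q = jax.tree_util.tree_flatten(q)

# Meta values are part of the treedef, so these two structures differ.
same_structure = treedef_p == treedef_q

traces = []


@jax.jit
def scale(params):
  traces.append(params.name)  # runs only while tracing, not on cache hits
  return params.w * 2


scale(p)
scale(p)  # same structure and shapes: reuses the cached trace
scale(q)  # different `name`: new cache key, so JAX traces again
```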

Returns:

The input class nodetype is returned unchanged after being added to JAX’s pytree registry. This return value allows register_dataclass to be partially applied (for example with functools.partial) and used as a decorator, as in the example below.

Return type:

Typ
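Because the class comes back unchanged, register_dataclass can also be called as an ordinary function instead of a decorator. A small sketch with a hypothetical `Point` class:

```python
from dataclasses import dataclass

import jax


@dataclass
class Point:  # hypothetical example class
  x: float
  y: float


# Direct call: the registry is updated and the same class object is returned.
registered = jax.tree_util.register_dataclass(
    Point, data_fields=['x', 'y'], meta_fields=[])

point_leaves = jax.tree_util.tree_leaves(Point(1.0, 2.0))
```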

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from dataclasses import dataclass
>>> from functools import partial
>>>
>>> @partial(jax.tree_util.register_dataclass,
...          data_fields=['x', 'y'],
...          meta_fields=['op'])
... @dataclass
... class MyStruct:
...   x: jax.Array
...   y: jax.Array
...   op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

Now that this class is registered, it can be used with the pytree utilities in jax.tree and jax.tree_util:

>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
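Other pytree utilities follow the same split between data and meta fields. In this sketch, which repeats the `MyStruct` definition so it runs on its own, `tree_map` transforms only `x` and `y` while carrying `op` through unchanged, and `tree_flatten_with_path` labels each leaf by the attribute it came from.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.tree_util.register_dataclass,
         data_fields=['x', 'y'],
         meta_fields=['op'])
@dataclass
class MyStruct:
  x: jax.Array
  y: jax.Array
  op: str


m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')

# The function is applied to each data leaf; the meta field is copied as-is.
doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, m)

# Each leaf's key path is the attribute access that reaches it.
path_leaves, _ = jax.tree_util.tree_flatten_with_path(m)
paths = [jax.tree_util.keystr(path) for path, _ in path_leaves]
```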

In particular, this registration allows m to be passed seamlessly through code wrapped in jax.jit() and other JAX transformations:

>>> @jax.jit
... def compiled_func(m):
...   if m.op == 'add':
...     return m.x + m.y
...   else:
...     raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)
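The same holds for other transformations. A sketch using jax.grad and jax.vmap on a `MyStruct` (redefined here so the block is self-contained): the gradient comes back as a `MyStruct` with the same static `op`, and vmap maps over the leading axis of every data field. The `loss` function is hypothetical, and the data fields are floating point because jax.grad does not accept integer inputs.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.tree_util.register_dataclass,
         data_fields=['x', 'y'],
         meta_fields=['op'])
@dataclass
class MyStruct:
  x: jax.Array
  y: jax.Array
  op: str


def loss(s):
  # `s.op` is static metadata, so plain Python branching on it is allowed.
  if s.op == 'mul':
    return jnp.sum(s.x * s.y)
  return jnp.sum(s.x + s.y)


s = MyStruct(x=jnp.array([1., 2., 3.]), y=jnp.array([4., 5., 6.]), op='mul')

# d/dx sum(x * y) = y and d/dy sum(x * y) = x, packed into a MyStruct.
g = jax.grad(loss)(s)

# Batch of 4 structs stacked along axis 0; `op` is shared by all of them.
batch = MyStruct(x=jnp.ones((4, 3)), y=jnp.ones((4, 3)), op='mul')
per_example = jax.vmap(loss)(batch)
```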