jax.tree_util.register_dataclass
- jax.tree_util.register_dataclass(nodetype, data_fields, meta_fields, drop_fields=())
Extends the set of types that are considered internal nodes in pytrees.
This differs from register_pytree_with_keys_class in that the C++ registries use the optimized C++ dataclass builtin instead of the argument functions.
See Extending pytrees for more information about registering pytrees.
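For comparison, registering a class by hand with explicit flatten and unflatten functions might look like the sketch below. The class `ManualStruct` and the helpers `flatten_struct` and `unflatten_struct` are hypothetical names chosen for illustration, and the sketch is not meant to describe how `register_dataclass` is implemented internally.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class ManualStruct:  # hypothetical example class
    x: jax.Array
    op: str


def flatten_struct(s):
    # Children (traced data) first, then auxiliary (static) data.
    return (s.x,), (s.op,)


def unflatten_struct(aux_data, children):
    # Rebuild the object from the static and dynamic parts.
    return ManualStruct(x=children[0], op=aux_data[0])


# Manual registration: the flatten/unflatten callbacks are supplied explicitly.
jax.tree_util.register_pytree_node(ManualStruct, flatten_struct, unflatten_struct)

s = ManualStruct(x=jnp.ones(3), op='add')
leaves, treedef = jax.tree_util.tree_flatten(s)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```

With `register_dataclass`, listing `data_fields` and `meta_fields` takes the place of writing these two callbacks.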
- Parameters:
nodetype (Typ) – a Python type to treat as an internal pytree node. This is assumed to have the semantics of a dataclass: namely, class attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among meta_fields or data_fields.
meta_fields (Sequence[str]) – auxiliary data field names. These fields must contain static, hashable, immutable objects, as these objects are used to generate JIT cache keys. In particular, meta_fields cannot contain jax.Array or numpy.ndarray objects.
data_fields (Sequence[str]) – data field names. These fields must be JAX-compatible objects such as arrays (jax.Array or numpy.ndarray), scalars, or pytrees whose leaves are arrays or scalars. Note that data_fields may be None, as this is recognized by JAX as an empty pytree.
drop_fields (Sequence[str])
- Returns:
The input class nodetype is returned unchanged after being added to JAX’s pytree registry. This return value allows register_dataclass to be partially evaluated and used as a decorator as in the example below.
- Return type:
Typ
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from dataclasses import dataclass
>>> from functools import partial
>>>
>>> @partial(jax.tree_util.register_dataclass,
...          data_fields=['x', 'y'],
...          meta_fields=['op'])
... @dataclass
... class MyStruct:
...   x: jax.Array
...   y: jax.Array
...   op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
Now that this class is registered, it can be used with functions in jax.tree_util:

>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
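The treedef printed above carries the meta field value ('add') while only x and y appear as leaves. The sketch below, which re-declares the same MyStruct so that it runs on its own, illustrates two consequences under that reading: a tree map touches only the data fields, and two instances that differ only in a meta field have different tree structures. The names `doubled`, `m_mul`, and `same_structure` are illustrative.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.tree_util.register_dataclass,
         data_fields=['x', 'y'],
         meta_fields=['op'])
@dataclass
class MyStruct:
    x: jax.Array
    y: jax.Array
    op: str


m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')

# The mapped function sees only the leaves (x and y); 'op' is carried
# along unchanged as part of the tree definition.
doubled = jax.tree.map(lambda leaf: leaf * 2, m)

# Same data, different meta field: the tree structures do not compare equal.
m_mul = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='mul')
same_structure = jax.tree.structure(m) == jax.tree.structure(m_mul)
```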
In particular, this registration allows m to be passed seamlessly through code wrapped in jax.jit() and other JAX transformations:

>>> @jax.jit
... def compiled_func(m):
...   if m.op == 'add':
...     return m.x + m.y
...   else:
...     raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)
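As a further sketch of the "other JAX transformations" mentioned above, the example below differentiates a function with respect to a registered dataclass using jax.grad(), and sets one data field to None to illustrate the empty-pytree behavior described under Parameters. The class `Linear`, its fields, and the function `loss` are hypothetical and exist only for this illustration.

```python
from dataclasses import dataclass
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp


@partial(jax.tree_util.register_dataclass,
         data_fields=['w', 'b'],
         meta_fields=['name'])
@dataclass
class Linear:  # hypothetical example class
    w: jax.Array
    b: Optional[jax.Array]  # may be None, which JAX treats as an empty pytree
    name: str


def loss(params, x):
    y = params.w * x
    if params.b is not None:  # plain Python check; None is not traced
        y = y + params.b
    return jnp.sum(y ** 2)


params = Linear(w=jnp.array([1.0, 2.0]), b=None, name='layer0')

# The None field contributes no leaves, so only w is differentiated.
n_leaves = len(jax.tree.leaves(params))

# The gradient has the same structure as params, including the meta field.
grads = jax.grad(loss)(params, jnp.array([3.0, 4.0]))
```

Here `grads` is itself a `Linear` instance: `grads.w` holds the gradient with respect to `w`, `grads.b` stays None, and `grads.name` is copied from the input.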