jax.tree_util.register_pytree_node_class


jax.tree_util.register_pytree_node_class(cls)

Extends the set of types that are considered internal nodes in pytrees.
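As a minimal sketch of what "internal node" means here (the class name Pair is a hypothetical example), an instance of an unregistered class is treated as a single opaque leaf. Once the class is registered, the same instance is flattened into its children. The function is called directly below instead of being used as a decorator:

```python
import jax


class Pair:  # hypothetical example class
  def __init__(self, a, b):
    self.a = a
    self.b = b

  def tree_flatten(self):
    # Return (children, aux_data); here there is no static auxiliary data.
    return ((self.a, self.b), None)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)


p = Pair(1.0, 2.0)

# Not yet registered: the whole object is one leaf.
leaves_before = jax.tree_util.tree_leaves(p)

# Register the class by calling the function directly.
jax.tree_util.register_pytree_node_class(Pair)

# Registered: Pair is now an internal node and its children are the leaves.
leaves_after = jax.tree_util.tree_leaves(p)
```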

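This function is a thin wrapper around register_pytree_node that provides a class-oriented interface: instead of passing separate flatten and unflatten functions, the class defines a tree_flatten method and a tree_unflatten classmethod. For example: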

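```python
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Special:
  def __init__(self, x, y):
    self.x = x
    self.y = y

  # Return (children, aux_data): children are traversed as subtrees,
  # aux_data is static metadata (None here).
  def tree_flatten(self):
    return ((self.x, self.y), None)

  # Rebuild an instance from the aux_data and children produced above.
  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)
```
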
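A short usage sketch, assuming the Special class from the example above (redefined here so the snippet runs on its own): once registered, instances can pass through pytree utilities such as jax.tree_util.tree_map and through jax.jit, and are rebuilt via tree_unflatten. The helper add_fields is a hypothetical function for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Special:
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def tree_flatten(self):
    return ((self.x, self.y), None)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)


s = Special(jnp.array(1.0), jnp.array(2.0))

# tree_map applies the function to each leaf and rebuilds a new Special.
doubled = jax.tree_util.tree_map(lambda v: v * 2, s)


# Hypothetical jitted function that takes and returns Special instances.
@jax.jit
def add_fields(sp):
  return Special(sp.x + sp.y, sp.y)


out = add_fields(s)
```

Keeping __init__ free of validation logic is a deliberate choice here: JAX may reconstruct instances from tracers or other placeholder objects when it calls tree_unflatten.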
Parameters:

cls (U, a TypeVar bound to type[Any]): the class to register. It must define a tree_flatten method and a tree_unflatten classmethod.

Return type:

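U (a TypeVar bound to type[Any]): the same class cls, returned unchanged, which is what allows the function to be used as a decorator.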
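To illustrate the return value and the relationship to register_pytree_node, here is a sketch with two hypothetical classes, A and B. The class-oriented call returns its argument unchanged. The register_pytree_node call for B is an approximately equivalent registration written by hand, not a copy of the library's internal implementation.

```python
import jax
from jax.tree_util import register_pytree_node, register_pytree_node_class


class A:  # hypothetical, registered via the class-oriented interface
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def tree_flatten(self):
    return ((self.x, self.y), None)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)


# Returns the class itself, so decorator and direct-call forms agree.
RegisteredA = register_pytree_node_class(A)


class B:  # hypothetical, registered with explicit functions
  def __init__(self, x, y):
    self.x = x
    self.y = y


# Function-oriented form: pass the flatten and unflatten functions explicitly.
register_pytree_node(
    B,
    lambda b: ((b.x, b.y), None),             # flatten -> (children, aux_data)
    lambda aux_data, children: B(*children),  # unflatten
)

leaves_b, treedef_b = jax.tree_util.tree_flatten(B(3, 4))
rebuilt_b = jax.tree_util.tree_unflatten(treedef_b, leaves_b)
```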