
JAX pytrees

Date: October 2019

This is primarily JAX internal documentation. End users should not need to understand it to use JAX, except when registering new user-defined container types with JAX. Some of these details may change.

Python has a lot of container data types (list, tuple, dict, namedtuple, etc.), and users sometimes define their own. To keep the JAX internals simpler while supporting lots of container types, we canonicalize nested containers into flat lists of numeric or array types at the boundary (and also in control flow primitives). That way grad, jit, vmap etc., can handle user functions that accept and return these containers, while all the other parts of the system can operate on functions that only take (multiple) array arguments and always return a flat list of arrays.
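
As a minimal sketch of what this canonicalization means for a user, the example below differentiates a function whose parameters live in a dict. The names `loss`, `params`, and `x` are hypothetical and chosen only for illustration; the point is that `jax.grad` accepts the container and returns gradients in the same container structure.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss function taking a dict of parameters (a pytree).
def loss(params, x):
    return jnp.sum(params["w"] * x + params["b"])

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
x = jnp.array([3.0, 4.0])

# grad differentiates with respect to the first argument, leaf by leaf,
# and hands the result back as a dict with the same keys.
grads = jax.grad(loss)(params, x)
print(grads)
```

Internally the dict is flattened to a list of arrays before tracing and rebuilt afterwards, so the user function never sees the flat form.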

We refer to a recursive structured value whose leaves are basic types as a pytree. When JAX flattens a pytree it will produce a list of leaves and a treedef object that encodes the structure of the original value. The treedef can then be used to construct a matching structured value after transforming the leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we handle them assuming referential transparency and that they can’t contain reference cycles.

Here is a simple example:

from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax import numpy as np

# The structured value to be transformed
value_structured = [1., (2., 3.)]

# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print("value_flat={}\nvalue_tree={}".format(value_flat, value_tree))

# Transform the flat value list using an element-wise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))

# Reconstruct the structured output, using the original treedef
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print("transformed_flat={}\ntransformed_structured={}".format(
    transformed_flat, transformed_structured))
value_flat=[1.0, 2.0, 3.0]
value_tree=PyTreeDef(list, [*,PyTreeDef(tuple, [*,*])])
transformed_flat=[2.0, 4.0, 6.0]
transformed_structured=[2.0, (4.0, 6.0)]
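
The tree-like (rather than DAG-like) treatment described earlier can be seen in a small sketch. The variable names are hypothetical; the example assumes only `tree_flatten` and `tree_unflatten`. A list that appears twice in a tuple is flattened twice, and the rebuilt value no longer shares the object.

```python
from jax.tree_util import tree_flatten, tree_unflatten

shared = [1.0, 2.0]
nested = (shared, shared)   # the same list object referenced twice

leaves, treedef = tree_flatten(nested)
rebuilt = tree_unflatten(treedef, leaves)

# The shared list contributes its leaves once per occurrence.
print(leaves)
# The input shares one object; the rebuilt value holds two separate lists.
shares_identity_before = nested[0] is nested[1]
shares_identity_after = rebuilt[0] is rebuilt[1]
print(shares_identity_before, shares_identity_after)
```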

Pytree containers can be lists, tuples, dicts, namedtuples, None, or OrderedDicts. Other types of values, including numeric and ndarray values, are treated as leaves:

from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])

example_containers = [
    (1., [2., 3.]),
    (1., {'b': 2., 'a': 3.}),
    None,
    np.zeros(2),
    Point(1., 2.)
]

def show_example(structured):
  flat, tree = tree_flatten(structured)
  unflattened = tree_unflatten(tree, flat)
  print("structured={}\n  flat={}\n  tree={}\n  unflattened={}".format(
      structured, flat, tree, unflattened))

for structured in example_containers:
  show_example(structured)
structured=(1.0, [2.0, 3.0])
  flat=[1.0, 2.0, 3.0]
  tree=PyTreeDef(tuple, [*,PyTreeDef(list, [*,*])])
  unflattened=(1.0, [2.0, 3.0])
structured=(1.0, {'b': 2.0, 'a': 3.0})
  flat=[1.0, 3.0, 2.0]
  tree=PyTreeDef(tuple, [*,PyTreeDef(dict[['a', 'b']], [*,*])])
  unflattened=(1.0, {'a': 3.0, 'b': 2.0})
structured=None
  flat=[]
  tree=PyTreeDef(None, [])
  unflattened=None
structured=[0. 0.]
  flat=[DeviceArray([0., 0.], dtype=float32)]
  tree=*
  unflattened=[0. 0.]
structured=Point(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(namedtuple[<class '__main__.Point'>], [*,*])
  unflattened=Point(x=1.0, y=2.0)
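
The output above shows dict leaves coming out in sorted key order. The sketch below contrasts that with an OrderedDict, which to our understanding keeps its insertion order, and with None, which is a node with no leaves. The variable names are hypothetical.

```python
from collections import OrderedDict
from jax.tree_util import tree_flatten

# Plain dict: leaves are ordered by sorted key ('a' before 'b').
dict_leaves, dict_def = tree_flatten({'b': 2., 'a': 3.})

# OrderedDict: leaves follow insertion order ('b' before 'a').
ordered_leaves, ordered_def = tree_flatten(OrderedDict([('b', 2.), ('a', 3.)]))

# None is treated as a container with no children, so it has no leaves.
none_leaves, none_def = tree_flatten(None)

print(dict_leaves, ordered_leaves, none_leaves)
```

The leaf order matters whenever flat lists from two different values are combined, so it is worth checking for the container types in use.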

Pytrees are extensible

By default, any part of a structured value that is not recognized as an internal pytree node is treated as a leaf (and such containers cannot be passed to JAX-traceable functions):

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Special(x={}, y={})".format(self.x, self.y)

show_example(Special(1., 2.))
structured=Special(x=1.0, y=2.0)
  flat=[Special(x=1.0, y=2.0)]
  tree=*
  unflattened=Special(x=1.0, y=2.0)

The set of Python types that are considered internal pytree nodes is extensible, through a global registry of types. Values of registered types are traversed recursively:

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: the value of registered type to flatten.
  Returns:
    a pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, e.g., for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: the opaque data that was specified during flattening of the
      current treedef.
    children: the unflattened children

  Returns:
    a re-constructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)

show_example(RegisteredSpecial(1., 2.))
structured=RegisteredSpecial(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(<class '__main__.RegisteredSpecial'>[None], [*,*])
  unflattened=RegisteredSpecial(x=1.0, y=2.0)
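
Once a type is registered, it can be passed to and returned from JAX-traceable functions. The following is a minimal sketch under that assumption; the class `Pair` and the function `scale` are hypothetical names introduced only for this illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

# Hypothetical user-defined container with two array-valued fields.
class Pair:
    def __init__(self, a, b):
        self.a = a
        self.b = b

# Register Pair: children are (a, b); no auxiliary data is needed.
register_pytree_node(
    Pair,
    lambda p: ((p.a, p.b), None),
    lambda aux_data, children: Pair(*children),
)

@jax.jit
def scale(p, factor):
    # p is rebuilt from traced leaves, and the returned Pair is flattened again.
    return Pair(p.a * factor, p.b * factor)

out = scale(Pair(jnp.array(1.0), jnp.array([2.0, 3.0])), 2.0)
print(out.a, out.b)
```

Note that the unflattening function may be called with tracers as children, so a constructor that validates or converts its arguments can fail under `jit`; the constructor here deliberately does neither.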

JAX sometimes needs to compare treedefs for equality. Therefore, care must be taken to ensure that the auxiliary data specified in the flattening recipe supports a meaningful equality comparison.
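
A small sketch of how auxiliary data takes part in treedef comparison: the class `Tagged` is hypothetical, and its string `tag` is stored as auxiliary data. Strings compare by value, so two values with the same tag produce equal treedefs, and different tags produce unequal ones.

```python
from jax.tree_util import register_pytree_node, tree_structure

# Hypothetical container: one leaf plus a static tag kept as auxiliary data.
class Tagged:
    def __init__(self, value, tag):
        self.value = value
        self.tag = tag

register_pytree_node(
    Tagged,
    lambda t: ((t.value,), t.tag),                  # children, aux_data
    lambda tag, children: Tagged(children[0], tag),
)

# Leaf values differ, auxiliary data matches: same structure.
same = tree_structure(Tagged(1.0, "a")) == tree_structure(Tagged(2.0, "a"))
# Auxiliary data differs: different structure.
different = tree_structure(Tagged(1.0, "a")) == tree_structure(Tagged(1.0, "b"))
print(same, different)
```

Auxiliary data whose `==` is ambiguous or identity-based (for example, an array or an object without `__eq__`) would make this comparison unreliable, which is the situation the paragraph above warns about.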

The full set of functions for operating on pytrees is in the tree_util module.
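
As a brief sketch of a few of those functions, the example below uses `tree_map`, `tree_leaves`, and `tree_structure` from `jax.tree_util` on the same value as the first example. It is an illustration of common usage, not a complete tour of the module.

```python
from jax.tree_util import tree_map, tree_leaves, tree_structure

tree = [1., (2., 3.)]

# tree_map applies a function to every leaf and keeps the structure,
# combining the flatten / transform / unflatten steps into one call.
doubled = tree_map(lambda v: v * 2., tree)

# tree_leaves and tree_structure return the two halves of tree_flatten.
leaves = tree_leaves(tree)
same_structure = tree_structure(doubled) == tree_structure(tree)

print(doubled, leaves, same_structure)
```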