Working with Pytrees

Author: Vladimir Mikulik

Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as pytrees, but you can sometimes see them called nests, or just trees.

JAX has built-in support for such objects, both in its library functions and through the functions in jax.tree_util (the most common of which are also available as jax.tree_*). This section explains how to use them, gives some useful snippets, and points out common gotchas.

What is a pytree?

As defined in the JAX pytree docs:

a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.

Some example pytrees:

import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Let's see how many leaves they have:
for pytree in example_trees:
  leaves = jax.tree_leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7fded60bb8c0>]   has 3 leaves: [1, 'a', <object object at 0x7fded60bb8c0>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
DeviceArray([1, 2, 3], dtype=int32)           has 1 leaves: [DeviceArray([1, 2, 3], dtype=int32)]

We’ve also introduced our first jax.tree_* function, jax.tree_leaves, which extracts the flattened leaves from a tree.
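jax.tree_leaves throws the structure away. Its counterpart jax.tree_util.tree_flatten returns the leaves together with a treedef that records the structure, and jax.tree_util.tree_unflatten rebuilds a tree from a treedef and a list of leaves. A minimal sketch of that round trip, written with the jax.tree_util spellings:

```python
import jax

tree = [1, {'k1': 2, 'k2': (3, 4)}, 5]

# Split the tree into a flat list of leaves and a description of its structure.
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)  # [1, 2, 3, 4, 5]

# Rebuild a tree with the same structure from a new list of leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, [x * 10 for x in leaves])
print(rebuilt)  # [10, {'k1': 20, 'k2': (30, 40)}, 50]
```

Keeping leaves and structure separate is what lets tree functions work on any nesting: they operate on the flat list and then put the structure back.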

Why pytrees?

In machine learning, some places where you commonly find pytrees are:

  • Model parameters

  • Dataset entries

  • RL agent observations

They also often arise naturally when working in bulk with datasets (e.g., lists of lists of dicts).

Common pytree functions

The most commonly used pytree functions are jax.tree_map and jax.tree_multimap. They work analogously to Python’s native map, but on entire pytrees.

For functions with one argument, use jax.tree_map:

list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree_map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
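The mapped result keeps the input's structure, so container types and dict keys survive the map. A small sketch with a dict of arrays standing in for a batch of data (the names batch, images and labels are illustrative only):

```python
import jax
import jax.numpy as jnp

batch = {'images': jnp.ones((4, 8)), 'labels': jnp.arange(4)}

# Apply the same slicing to every array; the dict and its keys are preserved.
first_two = jax.tree_util.tree_map(lambda x: x[:2], batch)
print(jax.tree_util.tree_map(lambda x: x.shape, first_two))
# {'images': (2, 8), 'labels': (2,)}
```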

To use functions with more than one argument, use jax.tree_multimap:

another_list_of_lists = list_of_lists
jax.tree_multimap(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

For tree_multimap, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.
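When the structures differ, the map fails instead of silently pairing up mismatched leaves. The sketch below uses jax.tree_util.tree_map, which in current JAX releases also accepts several trees (the role tree_multimap plays in this tutorial). try_add is a hypothetical helper, and we assume the mismatch surfaces as a ValueError:

```python
import jax

def try_add(a, b):
  """Hypothetical helper: adds two trees leaf by leaf, or returns None on a structure mismatch."""
  try:
    return jax.tree_util.tree_map(lambda x, y: x + y, a, b)
  except ValueError as err:  # assumption: JAX reports structure mismatches as ValueError
    print('structure mismatch:', err)
    return None

print(try_add([1, 2], [10, 20]))      # [11, 22]
print(try_add([1, 2], [10, 20, 30]))  # different list lengths: prints the error, then None
print(try_add({'a': 1}, {'b': 1}))    # different dict keys: also a mismatch
```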

Example: ML model parameters

A simple example of training an MLP shows some of the ways in which pytree operations are useful:

import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
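    # One dict per layer: scaled normal weights and a bias vector of length n_out.
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )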
  return params

params = init_mlp_params([1, 128, 128, 1])

We can use jax.tree_map to check that the shapes of our parameters are what we expect:

jax.tree_map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]
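Because every array in params is a leaf, whole-model reductions are one-liners over jax.tree_util.tree_leaves. A sketch that counts parameters; count_params is a hypothetical helper name, and the initialiser repeats the structure of init_mlp_params so the block runs on its own:

```python
import jax
import numpy as np

def init_mlp_params(layer_widths):
  # Same structure as above: one dict of weights and biases per layer.
  return [dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in),
               biases=np.ones(shape=(n_out,)))
          for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:])]

def count_params(params):
  """Hypothetical helper: total number of scalars across all leaves."""
  return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

params = init_mlp_params([1, 128, 128, 1])
print(count_params(params))  # 16897 = (128 + 16384 + 128) weights + (128 + 128 + 1) biases
```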

Now, let’s train our MLP:

def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)


LEARNING_RATE = 0.0001  # step size for plain SGD

@jax.jit
def update(params, x, y):

  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of the many JAX functions that has
  # built-in support for pytrees.

  # This is handy, because we can apply the SGD update using tree utils:
  return jax.tree_multimap(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )

import matplotlib.pyplot as plt

xs = np.random.normal(size=(128, 1))
ys = xs ** 2

for _ in range(1000):
  params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.scatter(xs, forward(params, xs), label='Model prediction')
plt.legend()
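To check that the tree-wide SGD step in update does what its comment says, here is a tiny pytree whose gradient is easy to work out by hand. The loss and learning rate are illustrative choices, not the ones used in the MLP example:

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}

def loss_fn(params):
  # d/dw = 2w and d/db = 2b, so the gradients can be checked by hand.
  return jnp.sum(params['w'] ** 2) + params['b'] ** 2

lr = 0.1
grads = jax.grad(loss_fn)(params)  # a dict with the same keys as params
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
# w: [1, 2] - 0.1 * [2, 4] = [0.8, 1.6];  b: 0.5 - 0.1 * 1 = 0.4
print(new_params)
```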

Custom pytree nodes

So far, we’ve only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:

class MyContainer:
  """A named container."""

  def __init__(self, name: str, a: int, b: int, c: int):
    self.name = name
    self.a = a
    self.b = b
    self.c = c

jax.tree_leaves([
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
[<__main__.MyContainer at 0x7fdec166ce50>,
 <__main__.MyContainer at 0x7fded89ba490>]

Accordingly, if we try to use a tree map expecting our leaves to be the elements inside the container, we will get an error:

jax.tree_map(lambda x: x + 1, [
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
TypeError                                 Traceback (most recent call last)
<ipython-input-10-d6b45a2ec2b9> in <module>()
      1 jax.tree_map(lambda x: x + 1, [
      2     MyContainer('Alice', 1, 2, 3),
----> 3     MyContainer('Bob', 4, 5, 6)
      4 ])

/usr/local/lib/python3.7/dist-packages/jax/ in tree_map(f, tree, is_leaf)
    184   """
    185   leaves, treedef = tree_flatten(tree, is_leaf)
--> 186   return treedef.unflatten(map(f, leaves))
    188 def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any,

<ipython-input-10-d6b45a2ec2b9> in <lambda>(x)
----> 1 jax.tree_map(lambda x: x + 1, [
      2     MyContainer('Alice', 1, 2, 3),
      3     MyContainer('Bob', 4, 5, 6)
      4 ])

TypeError: unsupported operand type(s) for +: 'MyContainer' and 'int'

To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:

from typing import Tuple, Iterable

def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:
  """Returns an iterable over container contents, and aux data."""
  flat_contents = [container.a, container.b, container.c]

  # we don't want the name to appear as a child, so it is auxiliary data.
  # auxiliary data is usually a description of the structure of a node,
  # e.g., the keys of a dict -- anything that isn't a node's children.
  aux_data = container.name
  return flat_contents, aux_data

def unflatten_MyContainer(
    aux_data: str, flat_contents: Iterable[int]) -> MyContainer:
  """Converts aux data and the flat contents into a MyContainer."""
  return MyContainer(aux_data, *flat_contents)

jax.tree_util.register_pytree_node(
    MyContainer, flatten_MyContainer, unflatten_MyContainer)

jax.tree_leaves([
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
[1, 2, 3, 4, 5, 6]
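Once registered, the container behaves like a built-in node, so the tree map that failed earlier now maps over a, b and c while carrying the name along as auxiliary data. A self-contained sketch that repeats the registration:

```python
import jax

class MyContainer:
  """A named container."""

  def __init__(self, name, a, b, c):
    self.name = name
    self.a = a
    self.b = b
    self.c = c

def flatten_MyContainer(container):
  # Children first, then auxiliary data (the name, which should not be a leaf).
  return [container.a, container.b, container.c], container.name

def unflatten_MyContainer(aux_data, flat_contents):
  return MyContainer(aux_data, *flat_contents)

jax.tree_util.register_pytree_node(
    MyContainer, flatten_MyContainer, unflatten_MyContainer)

result = jax.tree_util.tree_map(lambda x: x + 1, [
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
print([(c.name, c.a, c.b, c.c) for c in result])
# [('Alice', 2, 3, 4), ('Bob', 5, 6, 7)]
```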

Modern Python comes equipped with helpful tools to make defining containers easier. Some of these will work with JAX out-of-the-box, but others require more care. For instance:

from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# Since `tuple` is already registered with JAX, and NamedTuple is a subclass,
# this will work out-of-the-box:
jax.tree_leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]

Notice that the name field now appears as a leaf, as all tuple elements are children. That’s the price we pay for not having to register the class the hard way.
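A standard-library dataclass is one of the tools that needs more care: JAX does not register dataclasses automatically, so, like MyContainer before registration, an instance is treated as a single leaf. A quick check (MyDataContainer is a hypothetical name):

```python
import dataclasses
import jax

@dataclasses.dataclass
class MyDataContainer:
  name: str
  a: int
  b: int
  c: int

leaves = jax.tree_util.tree_leaves([MyDataContainer('Alice', 1, 2, 3)])
# The instance itself is the only leaf; its fields are not visited.
print(len(leaves), type(leaves[0]).__name__)  # 1 MyDataContainer
```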

Common pytree gotchas and patterns


Mistaking nodes for leaves

A common problem to look out for is accidentally introducing tree nodes instead of leaves:

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another tree with ones instead of zeros
shapes = jax.tree_map(lambda x: x.shape, a_tree)
jax.tree_map(jnp.ones, shapes)
[(DeviceArray([1., 1.], dtype=float32),
  DeviceArray([1., 1., 1.], dtype=float32)),
 (DeviceArray([1., 1., 1.], dtype=float32),
  DeviceArray([1., 1., 1., 1.], dtype=float32))]

What happened is that the shape of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling jnp.ones on e.g. (2, 3), it’s called on 2 and 3.

The solution will depend on the specifics, but there are two broadly applicable options:

  • rewrite the code to avoid the intermediate tree_map.

  • convert the tuple into an np.array or jnp.array, which makes the entire sequence a leaf.
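Both options can be sketched on the a_tree example; which one fits depends on the surrounding code:

```python
import jax
import jax.numpy as jnp
import numpy as np

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Option 1: skip the intermediate tree of shapes and build the ones directly.
ones_direct = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape, x.dtype), a_tree)

# Option 2: store each shape as a NumPy array, which is a single leaf,
# then turn it back into a tuple of ints when calling jnp.ones.
shapes = jax.tree_util.tree_map(lambda x: np.array(x.shape), a_tree)
ones_from_shapes = jax.tree_util.tree_map(
    lambda s: jnp.ones(tuple(int(d) for d in s)), shapes)

print([x.shape for x in ones_direct])       # [(2, 3), (3, 4)]
print([x.shape for x in ones_from_shapes])  # [(2, 3), (3, 4)]
```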

Handling of None

jax.tree_util treats None as a node without children, not as a leaf:

jax.tree_leaves([None, None, None])
[]
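One consequence is that a tree map never calls its function on None: the None entries are simply carried over into the result. If you do need None as a leaf, the is_leaf argument of tree_flatten can mark it as one. A short sketch:

```python
import jax

tree = {'a': 1, 'b': None}

# The function is never called on None, so `None * 2` is never evaluated.
doubled = jax.tree_util.tree_map(lambda x: x * 2, tree)
print(doubled)  # {'a': 2, 'b': None}

# With is_leaf, None becomes a leaf in its own right.
leaves, _ = jax.tree_util.tree_flatten([None, 1, None], is_leaf=lambda x: x is None)
print(leaves)  # [None, 1, None]
```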

Transposing trees

If you would like to transpose a pytree, i.e. turn a list of trees into a tree of lists, you can do so using jax.tree_multimap:

def tree_transpose(list_of_trees):
  """Convert a list of trees of identical structure into a single tree of lists."""
  return jax.tree_multimap(lambda *xs: list(xs), *list_of_trees)

# Convert a dataset from row-major to column-major:
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}

For more complicated transposes, JAX provides jax.tree_transpose, which is more verbose but allows you to specify the structure of the inner and outer pytrees for more flexibility:

jax.tree_transpose(
  outer_treedef = jax.tree_structure([0 for e in episode_steps]),
  inner_treedef = jax.tree_structure(episode_steps[0]),
  pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}
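A self-contained check that the two approaches agree, written with the jax.tree_util spellings, plus a common variation (our own addition, not part of the text above) that stacks the matching leaves into arrays instead of lists, which is handy when the per-step values should become a batch:

```python
import jax
import jax.numpy as jnp

episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]

# Approach 1: map over all trees at once and collect matching leaves in a list.
via_map = jax.tree_util.tree_map(lambda *xs: list(xs), *episode_steps)

# Approach 2: tree_transpose with explicit outer and inner structures.
via_transpose = jax.tree_util.tree_transpose(
    outer_treedef=jax.tree_util.tree_structure([0 for _ in episode_steps]),
    inner_treedef=jax.tree_util.tree_structure(episode_steps[0]),
    pytree_to_transpose=episode_steps)

print(via_map)        # {'obs': [3, 4], 't': [1, 2]}
print(via_transpose)  # {'obs': [3, 4], 't': [1, 2]}

# Variation: stack the matching leaves into arrays instead of lists.
stacked = jax.tree_util.tree_map(
    lambda *xs: jnp.stack([jnp.asarray(x) for x in xs]), *episode_steps)
print(stacked['obs'].tolist())  # [3, 4]
```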

More Information

For more information on pytrees in JAX and the operations that are available, see the Pytrees section in the JAX documentation.