jax.tree_util package

Utilities for working with tree-like container data structures.

This module provides a small set of utility functions for working with tree-like data structures, such as nested tuples, lists, and dicts. We call these structures pytrees. They are trees in that they are defined recursively (any object that is not a pytree node is a pytree, namely a leaf, and any pytree node whose children are pytrees is a pytree) and can be operated on recursively (object identity equivalence is not preserved by mapping operations, and the structures cannot contain reference cycles).

The set of Python types that are considered pytree nodes (i.e. that can be mapped over, rather than treated as leaves) is extensible. There is a single module-level registry of types, and class hierarchy is ignored: registering a class does not register its subclasses. Registering a new pytree node type in effect makes that type transparent to the utility functions in this module.

The primary purpose of this module is to enable interoperability between user-defined data structures and JAX transformations (e.g. jit). It is not meant to be a general-purpose library for handling tree-like data structures.

See the JAX pytrees note for examples.

class jax.tree_util.Partial

A version of functools.partial that works in pytrees.

Use it for partial function evaluation in a way that is compatible with JAX’s transformations, e.g., Partial(func, *args, **kwargs).

(You need to explicitly opt in to this behavior because we didn’t want to give functools.partial different semantics from normal function closures.)
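A minimal sketch of passing a Partial through jax.jit. As we understand it, the bound arguments become pytree leaves while the wrapped function is stored in the treedef, so the Partial can be a traced jit argument. The names scale, apply, and double are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_leaves

# Hypothetical function to partially apply.
def scale(factor, x):
    return factor * x

@jax.jit
def apply(fn, x):
    # `fn` is a pytree, so it can be passed as an ordinary jit argument.
    return fn(x)

double = Partial(scale, 2.0)
result = apply(double, jnp.arange(3.0))
# The bound argument 2.0 is the only leaf; `scale` itself is not a leaf.
leaves = tree_leaves(double)
```

A plain functools.partial could not be passed to apply this way, which is why opting in through Partial is needed.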

jax.tree_util.all_leaves(iterable)

Tests whether all elements in the given iterable are leaves.

>>> from jax.tree_util import all_leaves, tree_leaves
>>> tree = {"a": [1, 2, 3]}
>>> assert all_leaves(tree_leaves(tree))
>>> assert not all_leaves([tree])

This function is useful in advanced cases: for example, if a library allows arbitrary map operations on a flat list of leaves, it may want to check whether the result is still a flat list of leaves.

Parameters

iterable – Iterable of leaves.

Returns

A boolean indicating if all elements in the input are leaves.

jax.tree_util.build_tree(treedef, xs)

jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)

Extends the set of types that are considered internal nodes in pytrees.

See the JAX pytrees note for example usage.

Parameters
  • nodetype (Type[~T]) – a Python type to treat as an internal pytree node.

  • flatten_func (Callable[[~T], Tuple[Sequence[Any], Any]]) – a function to be used during flattening, taking a value of type nodetype and returning a pair, with (1) an iterable for the children to be flattened recursively, and (2) some hashable auxiliary data to be stored in the treedef and to be passed to the unflatten_func.

  • unflatten_func (Callable[[Any, Sequence[Any]], ~T]) – a function taking two arguments: the auxiliary data that was returned by flatten_func and stored in the treedef, and the unflattened children. The function should return an instance of nodetype.
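A minimal sketch of the three arguments working together, assuming a hypothetical Layer class whose weights and bias are children and whose scale is hashable auxiliary data kept in the treedef.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten

# Hypothetical container type used only for illustration.
class Layer:
    def __init__(self, weights, bias, scale):
        self.weights = weights
        self.bias = bias
        self.scale = scale

def layer_flatten(layer):
    # (1) children to be flattened recursively, (2) hashable auxiliary data.
    return (layer.weights, layer.bias), layer.scale

def layer_unflatten(aux_data, children):
    weights, bias = children
    return Layer(weights, bias, aux_data)

register_pytree_node(Layer, layer_flatten, layer_unflatten)

layer = Layer(jnp.ones(3), jnp.zeros(3), scale=2.0)
leaves, treedef = tree_flatten(layer)      # scale is not among the leaves
doubled = tree_unflatten(treedef, [2 * leaf for leaf in leaves])

# Once registered, Layer can be passed straight into a transformed function.
total = jax.jit(lambda l: (l.weights + l.bias).sum() * l.scale)(layer)
```

Keeping scale in the auxiliary data means jit treats it as a static Python value rather than tracing it.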

jax.tree_util.register_pytree_node_class(cls)

Extends the set of types that are considered internal nodes in pytrees.

This function is a thin wrapper around register_pytree_node, and provides a class-oriented interface:

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # Children are (x, y); no auxiliary data is needed.
        return ((self.x, self.y), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)
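A sketch of using a class registered this way with tree_map and jit. Point and its label field are hypothetical; label is returned as auxiliary data, so tree_map leaves it untouched.

```python
import jax
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map

@register_pytree_node_class
class Point:
    """Hypothetical class: x and y are children, label is auxiliary data."""

    def __init__(self, x, y, label="p"):
        self.x = x
        self.y = y
        self.label = label

    def tree_flatten(self):
        # Children first, then hashable auxiliary data.
        return (self.x, self.y), self.label

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        x, y = children
        return cls(x, y, label=aux_data)

p = Point(1.0, 2.0, label="offset")
leaves = tree_leaves(p)                         # only x and y
shifted = tree_map(lambda v: v + 10.0, p)       # label carried through unchanged
norm_sq = jax.jit(lambda q: q.x ** 2 + q.y ** 2)(p)
```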

jax.tree_util.tree_all(tree)

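No description is given for tree_all. As far as we know, it reports whether every leaf of the tree is truthy, much like all(tree_leaves(tree)); the sketch below assumes that behavior and uses Python scalars as leaves.

```python
from jax.tree_util import tree_all

all_true = tree_all({"a": True, "b": [True, 1]})
has_false = tree_all({"a": True, "b": [True, 0]})   # 0 is falsy
empty_result = tree_all([])                          # no leaves, like all([])
```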
jax.tree_util.tree_flatten(tree, is_leaf=None)

Flattens a pytree.

Parameters
  • tree – a pytree to flatten.

  • is_leaf (Optional[Callable[[Any], bool]]) – an optionally specified function that will be called at each flattening step. It should return a boolean: True stops the traversal at the current object and treats the whole subtree as a leaf, while False lets flattening continue into it.

Returns

A pair where the first element is a list of leaf values and the second element is a treedef representing the structure of the flattened tree.
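A short sketch of tree_flatten with and without is_leaf. The example tree is arbitrary; the predicate here treats every tuple as a leaf, so flattening stops at tuples.

```python
from jax.tree_util import tree_flatten

tree = {"a": (1, 2), "b": [3, (4, 5)]}

# Default: recurse into every registered container; dict keys are visited in sorted order.
leaves, treedef = tree_flatten(tree)

# With is_leaf returning True for tuples, each tuple is kept whole as a leaf.
leaves_t, treedef_t = tree_flatten(tree, is_leaf=lambda x: isinstance(x, tuple))
```

Here leaves is [1, 2, 3, 4, 5] while leaves_t is [(1, 2), 3, (4, 5)].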

jax.tree_util.tree_leaves(tree)

Gets the leaves of a pytree.
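A small sketch of leaf order. It relies on two behaviors we believe hold in JAX: dict entries are visited in sorted key order, and None is a pytree node with no children, so it contributes no leaves.

```python
from jax.tree_util import tree_leaves

nested = {"b": [1, 2], "a": (3, None, 4)}
# "a" is visited before "b"; None is skipped because it has no children.
leaves = tree_leaves(nested)
```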

jax.tree_util.tree_map(f, tree)

Maps a function over a pytree to produce a new pytree.

Parameters
  • f (Callable[[Any], Any]) – unary function to be applied at each leaf.

  • tree (Any) – a pytree to be mapped over.

Return type

Any

Returns

A new pytree with the same structure as tree but with the value at each leaf given by f(x) where x is the value at the corresponding leaf in the input tree.
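A minimal sketch of tree_map on a hypothetical parameter dict: f is applied to each leaf and the nesting of dicts and lists is preserved.

```python
from jax.tree_util import tree_map

params = {"w": [1.0, 2.0], "b": 0.5}
scaled = tree_map(lambda x: x * 10, params)   # same structure, every leaf multiplied by 10
```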

jax.tree_util.tree_multimap(f, tree, *rest)

Maps a multi-input function over pytree args to produce a new pytree.

Parameters
  • f (Callable[…, Any]) – function that takes 1 + len(rest) arguments, to be applied at the corresponding leaves of the pytrees.

  • tree (Any) – a pytree to be mapped over, with each leaf providing the first positional argument to f.

  • *rest (Any) – a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.

Return type

Any

Returns

A new pytree with the same structure as tree but with the value at each leaf given by f(x, *xs) where x is the value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rest.
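A sketch of both cases described above: trees of identical structure, and a second tree that has tree as a prefix, so f receives a whole subtree as its second argument. The import fallback is an assumption about newer JAX releases, where tree_multimap was removed and tree_map accepts the extra trees instead.

```python
try:
    from jax.tree_util import tree_multimap
except ImportError:
    # Assumption: newer JAX drops tree_multimap; tree_map(f, tree, *rest) is the replacement.
    from jax.tree_util import tree_map as tree_multimap

# Same structure: f receives one leaf from each tree.
sums = tree_multimap(lambda x, y: x + y, {"a": 1, "b": 2}, {"a": 10, "b": 20})

# [1, 2] is a prefix of [(10, 20), (30, 40)]: y is the tuple at the matching position.
prefix = tree_multimap(lambda x, y: x + sum(y), [1, 2], [(10, 20), (30, 40)])
```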

jax.tree_util.tree_reduce(function, tree, initializer=<object object>)

Return type

~T
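No description is given for tree_reduce. As we understand it, it folds function over the leaves of tree in the same way as functools.reduce, with initializer (when supplied) seeding the fold; the sketch assumes that behavior.

```python
import operator

from jax.tree_util import tree_reduce

tree = {"a": 1, "b": [2, 3]}
total = tree_reduce(operator.add, tree)               # (1 + 2) + 3
total_from_10 = tree_reduce(operator.add, tree, 10)   # ((10 + 1) + 2) + 3
```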

jax.tree_util.tree_structure(tree)

Gets the treedef for a pytree.
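A short sketch of what a treedef captures: two trees with different leaf values but the same containers share a treedef, while a list and a tuple of the same length do not.

```python
from jax.tree_util import tree_structure

s1 = tree_structure({"a": 1, "b": (2, 3)})
s2 = tree_structure({"a": 10.0, "b": (20.0, 30.0)})   # same shape, different leaves

list_def = tree_structure([1, 2])
tuple_def = tree_structure((1, 2))                     # different container type
```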

jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)

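No description is given for tree_transpose. As we understand it, it turns a pytree with structure outer_treedef whose elements share structure inner_treedef into one with structure inner_treedef whose elements have structure outer_treedef. The sketch below converts a list of pairs into a pair of lists.

```python
from jax.tree_util import tree_structure, tree_transpose

pytree = [(1, 2), (3, 4), (5, 6)]      # outer: list of 3, inner: tuple of 2
outer = tree_structure([0, 0, 0])
inner = tree_structure((0, 0))

transposed = tree_transpose(outer, inner, pytree)   # a tuple of two lists
```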
jax.tree_util.tree_unflatten(treedef, leaves)

Reconstructs a pytree from the treedef and the leaves.

The inverse of tree_flatten().

Parameters
  • treedef – the treedef to reconstruct

  • leaves – the list of leaves to use for reconstruction. The list must contain the same number of leaves as the treedef, in flattening order.

Returns

The reconstructed pytree, containing the leaves placed in the structure described by treedef.
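A sketch of the round trip: flattening and unflattening reproduces the original tree, and supplying different leaves in the same order builds a new tree with the same structure.

```python
from jax.tree_util import tree_flatten, tree_unflatten

tree = {"x": [1, 2], "y": 3}
leaves, treedef = tree_flatten(tree)               # leaves in order: [1, 2, 3]

rebuilt = tree_unflatten(treedef, leaves)          # equal to tree
replaced = tree_unflatten(treedef, [10, 20, 30])   # same structure, new leaves
```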

jax.tree_util.treedef_children(treedef)

jax.tree_util.treedef_is_leaf(treedef)

jax.tree_util.treedef_tuple(treedefs)

Makes a tuple treedef from a list of child treedefs.
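A sketch exercising the three treedef helpers together. It assumes treedef_children returns the child treedefs in order and that treedef_is_leaf is True only for the treedef of a single leaf; the example trees are arbitrary.

```python
from jax.tree_util import (
    tree_structure,
    treedef_children,
    treedef_is_leaf,
    treedef_tuple,
)

leaf_def = tree_structure(0)
list_def = tree_structure([0, 0])

# Build a tuple treedef from child treedefs, then take it apart again.
tuple_def = treedef_tuple([leaf_def, list_def])
expected = tree_structure((0, [0, 0]))
children = list(treedef_children(tuple_def))

leaf_is_leaf = treedef_is_leaf(leaf_def)
list_is_leaf = treedef_is_leaf(list_def)
```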