jax.tree_util module

Utilities for working with tree-like container data structures.

This module provides a small set of utility functions for working with tree-like data structures, such as nested tuples, lists, and dicts. We call these structures pytrees. They are trees in that they are defined recursively: any object that is not a registered pytree node is a pytree (namely a leaf), and any pytree node whose children are pytrees is itself a pytree. They can also be operated on recursively, with two caveats: object identity is not preserved by mapping operations, and the structures cannot contain reference cycles.
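A minimal sketch of the recursive definition using tree_leaves. The class Opaque is hypothetical and stands in for any type that has not been registered as a pytree node:

```python
import jax

class Opaque:
    """Hypothetical class that is NOT registered as a pytree node."""

obj = Opaque()
tree = {"a": [1, 2], "b": (3, None), "c": obj}

# Dict children are visited in sorted key order. None is treated as a
# pytree node with no children, so it contributes no leaves. The
# unregistered Opaque instance is not traversed; it is a single leaf.
leaves = jax.tree_util.tree_leaves(tree)
print(leaves[:3])         # [1, 2, 3]
print(leaves[3] is obj)   # True
```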

The set of Python types that are considered pytree nodes (i.e. that can be mapped over rather than treated as leaves) is extensible. There is a single module-level registry of types, and class hierarchy is ignored: registering a type does not register its subclasses. By registering a new pytree node type, that type in effect becomes transparent to the utility functions in this module.
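To illustrate that class hierarchy is ignored, the sketch below compares a plain dict with a dict subclass. MyDict is a hypothetical class introduced only for this example; because it has not been registered itself, it is treated as a single leaf even though its base class is a pytree node:

```python
import jax

class MyDict(dict):
    """Hypothetical dict subclass; it is NOT registered as a pytree node."""

plain = {"a": 1, "b": 2}
sub = MyDict(a=1, b=2)

# The built-in dict is a registered node, so its values are the leaves.
plain_leaves = jax.tree_util.tree_leaves(plain)
print(plain_leaves)  # [1, 2]

# The subclass inherits nothing from dict's registration, so the whole
# object comes back as one leaf.
sub_leaves = jax.tree_util.tree_leaves(sub)
print(len(sub_leaves), sub_leaves[0] is sub)  # 1 True
```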

The primary purpose of this module is to enable interoperability between user-defined data structures and JAX transformations (e.g. jit). It is not meant to be a general-purpose library for handling tree-like data structures.
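As a small sketch of that interoperability, a jitted function can take a dict of arrays directly as an argument. The function affine and the parameter names w and b are hypothetical:

```python
import jax
import jax.numpy as jnp

@jax.jit
def affine(params, x):
    # params is a dict pytree. jit flattens it into its array leaves for
    # tracing and hands the function a dict with the same structure.
    return params["w"] * x + params["b"]

params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
y = affine(params, jnp.array(3.0))
print(y)  # 2.0 * 3.0 + 1.0 = 7.0
```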

See the JAX pytrees note for examples.

List of Functions

Partial(func, *args, **kw)

A version of functools.partial that works in pytrees.
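A minimal sketch of why Partial is useful: a plain functools.partial object is not a valid argument to a jitted function, whereas a Partial is a pytree whose bound arguments are leaves. The helper apply is hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

@jax.jit
def apply(fn, x):
    # fn is a Partial: the wrapped function is kept as static metadata and
    # the bound arguments are pytree leaves, so it can cross the jit boundary.
    return fn(x)

add_one = Partial(jnp.add, 1.0)   # behaves like functools.partial(jnp.add, 1.0)
result = apply(add_one, jnp.array(2.0))
print(result)  # jnp.add(1.0, 2.0) = 3.0
```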

all_leaves(iterable[, is_leaf])

Tests whether all elements in the given iterable are leaves.
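A short sketch of all_leaves on two lists, one containing only leaves and one containing a nested list node:

```python
from jax.tree_util import all_leaves

# Numbers and strings are not registered pytree nodes, so each is a leaf.
flat_ok = all_leaves([1, 2.0, "x"])
# [2, 3] is a list, which is a pytree node rather than a leaf.
nested = all_leaves([1, [2, 3]])
print(flat_ok, nested)  # True False
```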

build_tree(treedef, xs)

Builds an object with the structure of treedef from the nested iterables xs.

register_pytree_node(nodetype, flatten_func, ...)

Extends the set of types that are considered internal nodes in pytrees.
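A minimal sketch of registering a custom container with register_pytree_node. The Point class and the flatten_point / unflatten_point helpers are hypothetical names for illustration. The flatten function returns a (children, aux_data) pair, and the unflatten function receives aux_data and the children:

```python
import jax

class Point:
    """Hypothetical 2-D point class used only for illustration."""
    def __init__(self, x, y):
        self.x = x
        self.y = y

def flatten_point(p):
    # children are traversed as sub-pytrees; aux_data carries static
    # metadata, which this class does not need.
    return (p.x, p.y), None

def unflatten_point(aux_data, children):
    # Rebuild a Point from the (possibly transformed) children.
    return Point(*children)

jax.tree_util.register_pytree_node(Point, flatten_point, unflatten_point)

p = Point(1.0, 2.0)
leaves = jax.tree_util.tree_leaves(p)                 # [1.0, 2.0]
scaled = jax.tree_util.tree_map(lambda v: v * 10, p)  # a new Point(10.0, 20.0)
print(leaves, scaled.x, scaled.y)
```

Once registered, Point is traversed by every function in this module, so tree_map returns a new Point rather than treating the object as an opaque leaf.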


Extends the set of types that are considered internal nodes in pytrees.


param tree

tree_flatten(tree[, is_leaf])

Flattens a pytree.
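A short sketch of tree_flatten on a nested dict/tuple structure. It returns the leaves in traversal order together with a treedef that records the structure:

```python
from jax.tree_util import tree_flatten

tree = {"b": (2, 3), "a": 1}
leaves, treedef = tree_flatten(tree)

# Dict keys are visited in sorted order, so "a" comes before "b".
print(leaves)              # [1, 2, 3]
# The treedef records the dict/tuple structure but not the leaf values.
print(treedef.num_leaves)  # 3
```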

tree_leaves(tree[, is_leaf])

Gets the leaves of a pytree.
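A sketch of tree_leaves with and without the optional is_leaf predicate. The predicate used here (treat every tuple as a leaf) is just an example choice:

```python
from jax.tree_util import tree_leaves

tree = {"a": (1, 2), "b": [3]}

default = tree_leaves(tree)
print(default)  # [1, 2, 3]

# is_leaf stops the recursion at any node for which it returns True,
# so the tuple is returned whole instead of being traversed.
tuples_whole = tree_leaves(tree, is_leaf=lambda x: isinstance(x, tuple))
print(tuples_whole)  # [(1, 2), 3]
```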

tree_map(f, tree, *rest[, is_leaf])

Maps a multi-input function over pytree args to produce a new pytree.
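A sketch of the multi-input form of tree_map, in the style of a parameter update. The names params and grads and the step size 0.5 are illustrative; every tree passed in *rest is expected to have the same structure as the first tree:

```python
from jax.tree_util import tree_map

params = {"w": 1.0, "b": 2.0}
grads = {"w": 0.5, "b": 0.25}

# f receives the corresponding leaf from each tree, matched by position
# in the shared structure (here, by dict key).
updated = tree_map(lambda p, g: p - 0.5 * g, params, grads)
print(updated)  # {'b': 1.875, 'w': 0.75}
```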


param function

tree_structure(tree[, is_leaf])

Gets the treedef for a pytree.
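A short sketch showing that a treedef depends only on the container structure, not on the leaf values:

```python
from jax.tree_util import tree_structure

s1 = tree_structure((1, [2, 3]))
s2 = tree_structure(("a", [5.0, 6.0]))  # same containers, different leaves
s3 = tree_structure((1, (2, 3)))        # tuple instead of list inside

print(s1 == s2)        # True
print(s1 == s3)        # False
print(s1.num_leaves)   # 3
```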

tree_transpose(outer_treedef, inner_treedef, ...)

Transforms a tree having tree structure (outer, inner) into one having structure (inner, outer).
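A sketch of tree_transpose that turns a list of dicts into a dict of lists. The outer and inner treedefs are obtained here with tree_structure on placeholder trees; the variable names are illustrative:

```python
from jax.tree_util import tree_structure, tree_transpose

# Outer structure: a list of two entries. Inner structure: a dict {a, b}.
per_example = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
outer = tree_structure([0, 0])
inner = tree_structure({"a": 0, "b": 0})

# Structure (outer, inner) = list of dicts becomes (inner, outer) = dict of lists.
by_key = tree_transpose(outer, inner, per_example)
print(by_key)  # {'a': [1, 3], 'b': [2, 4]}
```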

tree_unflatten(treedef, leaves)

Reconstructs a pytree from the treedef and the leaves.
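A short sketch of the usual flatten, transform, unflatten round trip:

```python
from jax.tree_util import tree_flatten, tree_unflatten

tree = {"a": 1, "b": (2, 3)}
leaves, treedef = tree_flatten(tree)

# Rebuild the same structure around new leaf values.
doubled = tree_unflatten(treedef, [2 * x for x in leaves])
print(doubled)  # {'a': 2, 'b': (4, 6)}

# Unflattening the original leaves recovers an equal tree.
print(tree_unflatten(treedef, leaves) == tree)  # True
```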


param treedef


param treedef


Makes a tuple treedef from an iterable of child treedefs.