jax.tree_util module

Utilities for working with tree-like container data structures.

This module provides a small set of utility functions for working with tree-like data structures, such as nested tuples, lists, and dicts. We call these structures pytrees. They are trees in that they are defined recursively (any object that is not a pytree node is a pytree, namely a leaf, and any pytree node whose children are pytrees is itself a pytree) and can be operated on recursively (object identity is not preserved by mapping operations, and the structures cannot contain reference cycles).

The set of Python types that are considered pytree nodes (i.e. that can be mapped over, rather than treated as leaves) is extensible. There is a single module-level registry of types, and class hierarchy is ignored. By registering a new pytree node type, that type in effect becomes transparent to the utility functions in this module.

The primary purpose of this module is to enable interoperability between user-defined data structures and JAX transformations (e.g. jit). It is not meant to be a general-purpose library for handling tree-like data structures.

See the JAX pytrees note for examples.

List of Functions

Partial(func, *args, **kw)

A version of functools.partial that works in pytrees.
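As a minimal sketch, assuming a recent JAX release: the bound positional and keyword arguments of a Partial become pytree leaves while the wrapped callable is kept as static structure, so the object can be passed directly as an argument to a jitted function. The names add_one and apply are illustrative only.

```python
import operator

import jax
from jax.tree_util import Partial, tree_leaves

# Bind the first argument of operator.add. The bound value 1.0 is a leaf;
# operator.add itself is stored as static (hashable) structure.
add_one = Partial(operator.add, 1.0)

@jax.jit
def apply(f, x):
    # `f` is an ordinary pytree argument, so jit can trace through it.
    return f(x)

result = apply(add_one, 2.0)
leaves = tree_leaves(add_one)
print(float(result))  # 3.0
print(leaves)         # [1.0]
```

A plain functools.partial would be treated as a single opaque leaf and could not be passed to jit this way.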

all_leaves(iterable[, is_leaf])

Tests whether all elements in the given iterable are leaves.
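A short illustration: scalars and strings are leaves, while a tuple is an internal pytree node unless an is_leaf predicate says otherwise. The example lists are arbitrary.

```python
from jax.tree_util import all_leaves

flat = [1, 2.0, "label"]
nested = [1, (2, 3)]

print(all_leaves(flat))    # True: every element is a leaf
print(all_leaves(nested))  # False: (2, 3) is a tuple node
# With is_leaf, tuples are treated as leaves, so the check passes.
print(all_leaves(nested, is_leaf=lambda x: isinstance(x, tuple)))  # True
```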

build_tree(treedef, xs)

Builds a pytree with the structure of treedef from the nested iterable xs.

register_pytree_node(nodetype, flatten_func, ...)

Extends the set of types that are considered internal nodes in pytrees.
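A minimal sketch of registering a custom container with register_pytree_node. The Point class and its flatten/unflatten helpers are hypothetical. The flatten function returns (children, aux_data), where the children are mapped over and aux_data is kept as static, hashable metadata; the unflatten function receives (aux_data, children).

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_leaves, tree_map

class Point:
    """Hypothetical container used only for illustration."""
    def __init__(self, x, y, name):
        self.x, self.y, self.name = x, y, name

def flatten_point(p):
    # children (x, y) are visited by tree utilities; name is static aux data
    return (p.x, p.y), p.name

def unflatten_point(name, children):
    return Point(*children, name)

register_pytree_node(Point, flatten_point, unflatten_point)

p = Point(jnp.array(1.0), jnp.array(2.0), "origin")
doubled = tree_map(lambda v: v * 2, p)  # a new Point; name is carried over
```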

register_pytree_node_class(cls)

Extends the set of types that are considered internal nodes in pytrees.
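The class-decorator form expects the class itself to provide a tree_flatten method and a tree_unflatten classmethod. A minimal sketch, with Scaled as a hypothetical example class:

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map

@register_pytree_node_class
class Scaled:
    """Hypothetical example: an array plus a static unit label."""
    def __init__(self, value, unit):
        self.value = value
        self.unit = unit

    def tree_flatten(self):
        # Returns (children, aux_data).
        return (self.value,), self.unit

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, aux_data)

s = Scaled(jnp.arange(3.0), "m")
t = tree_map(lambda v: v + 1, s)  # value becomes [1., 2., 3.]; unit stays "m"
```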

register_pytree_with_keys(nodetype, ...[, ...])

Extends the set of types that are considered internal nodes in pytrees.
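The with-keys variant lets each child carry a key, so path-aware utilities can report where a leaf lives. A sketch assuming GetAttrKey from jax.tree_util as the key type; the Pair class and helper functions are hypothetical.

```python
import jax.numpy as jnp
from jax.tree_util import (GetAttrKey, keystr, register_pytree_with_keys,
                           tree_flatten_with_path)

class Pair:
    """Hypothetical two-field container."""
    def __init__(self, left, right):
        self.left, self.right = left, right

def flatten_with_keys(p):
    # An iterable of (key, child) pairs, plus aux_data (None here).
    return ((GetAttrKey("left"), p.left), (GetAttrKey("right"), p.right)), None

def unflatten(aux_data, children):
    return Pair(*children)

register_pytree_with_keys(Pair, flatten_with_keys, unflatten)

path_leaves, _ = tree_flatten_with_path(Pair(jnp.zeros(2), jnp.ones(2)))
paths = [keystr(path) for path, _ in path_leaves]
print(paths)  # ['.left', '.right']
```

When no separate flatten_func is passed, a plain flatten is derived from flatten_with_keys.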

register_pytree_with_keys_class(cls)

Extends the set of types that are considered internal nodes in pytrees.

tree_all(tree)

Call all() over the leaves of a tree.
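For example, tree_all applies Python's truthiness test to every leaf, regardless of how deeply it is nested:

```python
from jax.tree_util import tree_all

print(tree_all({"a": True, "b": [True, 1]}))  # True
print(tree_all({"a": True, "b": [0]}))        # False: the leaf 0 is falsy
```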

tree_flatten(tree[, is_leaf])

Flattens a pytree.
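A round-trip sketch pairing tree_flatten with tree_unflatten. Note that dict children are visited in sorted key order, so "a" comes before "b" in the leaf list.

```python
from jax.tree_util import tree_flatten, tree_unflatten

tree = {"b": (1, 2), "a": [3, {"c": 4}]}
leaves, treedef = tree_flatten(tree)
print(leaves)  # [3, 4, 1, 2]

# Rebuild the same structure with transformed leaves.
rebuilt = tree_unflatten(treedef, [x * 10 for x in leaves])
print(rebuilt)  # {'a': [30, {'c': 40}], 'b': (10, 20)}
```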

tree_flatten_with_path(tree[, is_leaf])

Flattens a pytree like tree_flatten, but also returns each leaf's key path.
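A short illustration of the (key path, leaf) pairs returned by tree_flatten_with_path, printed with keystr. Dict keys render as ['name'] and sequence indices as [i].

```python
from jax.tree_util import keystr, tree_flatten_with_path

tree = {"params": [1.0, 2.0], "step": 0}
path_leaves, treedef = tree_flatten_with_path(tree)
for path, leaf in path_leaves:
    print(keystr(path), leaf)
# ['params'][0] 1.0
# ['params'][1] 2.0
# ['step'] 0
```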

tree_leaves(tree[, is_leaf])

Gets the leaves of a pytree.

tree_leaves_with_path(tree[, is_leaf])

Gets the leaves of a pytree like tree_leaves and returns each leaf's key path.

tree_map(f, tree, *rest[, is_leaf])

Maps a multi-input function over pytree args to produce a new pytree.
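A typical multi-input use is a parameter update, where leaves of two trees with the same structure are paired position by position. The parameter names and the 0.1 step size are illustrative only.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
grads = {"w": jnp.full((2, 2), 0.5), "b": jnp.ones(2)}

# f receives one leaf from each tree; both trees must share a structure.
updated = tree_map(lambda p, g: p - 0.1 * g, params, grads)
# updated["w"] is filled with 0.95, updated["b"] with -0.1
```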

tree_map_with_path(f, tree, *rest[, is_leaf])

Maps a multi-input function over pytree key path and args to produce a new pytree.
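With tree_map_with_path, the mapped function receives the key path as its first argument, which makes path-based rules possible. The "zero every bias" rule below is a hypothetical example.

```python
from jax.tree_util import keystr, tree_map_with_path

params = {"dense": {"kernel": 1.0, "bias": 2.0}}

def zero_biases(path, leaf):
    # Hypothetical rule: zero any leaf whose key path mentions "bias".
    return 0.0 if "bias" in keystr(path) else leaf

out = tree_map_with_path(zero_biases, params)
print(out)  # {'dense': {'bias': 0.0, 'kernel': 1.0}}
```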

tree_reduce(function, tree[, initializer, ...])

Call reduce() over the leaves of a tree.
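tree_reduce behaves like functools.reduce applied to the flattened leaves, with an optional initializer:

```python
import operator

from jax.tree_util import tree_reduce

tree = {"a": 1, "b": [2, 3], "c": (4,)}
total = tree_reduce(operator.add, tree)          # 1 + 2 + 3 + 4
largest = tree_reduce(max, tree, initializer=0)  # running maximum, starting at 0
print(total, largest)  # 10 4
```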

tree_structure(tree[, is_leaf])

Gets the treedef for a pytree.

tree_transpose(outer_treedef, inner_treedef, ...)

Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).
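A common use of tree_transpose is turning a list of records into a record of lists. Here the outer treedef describes the list and the inner treedef describes each record; the data is arbitrary.

```python
from jax.tree_util import tree_structure, tree_transpose

records = [{"x": 1, "y": 2}, {"x": 3, "y": 4}]
outer = tree_structure([0, 0])             # a list of two items
inner = tree_structure({"x": 0, "y": 0})   # each item is a dict with x, y

columns = tree_transpose(outer, inner, records)
print(columns)  # {'x': [1, 3], 'y': [2, 4]}
```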

tree_unflatten(treedef, leaves)

Reconstructs a pytree from the treedef and the leaves.

treedef_children(treedef)

Returns the list of child treedefs of treedef.

treedef_is_leaf(treedef)

Tests whether treedef consists of a single node, such as a leaf.
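A small sketch inspecting a treedef one level at a time with treedef_children and treedef_is_leaf, assuming both helpers are available in the installed JAX version:

```python
from jax.tree_util import tree_structure, treedef_children, treedef_is_leaf

treedef = tree_structure((1, [2, 3]))
children = treedef_children(treedef)
print(len(children))                 # 2
print(treedef_is_leaf(children[0]))  # True: the first child is the leaf 1
print(treedef_is_leaf(children[1]))  # False: the second child is a list node
```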

treedef_tuple(treedefs)

Makes a tuple treedef from an iterable of child treedefs.
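For example, two treedefs can be combined into the treedef of a 2-tuple and then filled with leaves. Leaves are consumed in flattening order, so the dict's sorted keys ("b", then "w") come first.

```python
from jax.tree_util import tree_structure, tree_unflatten, treedef_tuple

a = tree_structure({"w": 0, "b": 0})
b = tree_structure([0])
pair = treedef_tuple([a, b])  # structure of ({'b': *, 'w': *}, [*])

rebuilt = tree_unflatten(pair, [1, 2, 3])
print(rebuilt)  # ({'b': 1, 'w': 2}, [3])
```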

keystr(keys)

Helper to pretty-print a tuple of keys.
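A brief example pairing keystr with tree_leaves_with_path to print the location of a deeply nested leaf (default formatting assumed):

```python
from jax.tree_util import keystr, tree_leaves_with_path

tree = {"layers": [{"w": 1.0}]}
(path, leaf), = tree_leaves_with_path(tree)
print(keystr(path), leaf)  # ['layers'][0]['w'] 1.0
```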