jax.tree_util package
Utilities for working with treelike container data structures.
This module provides a small set of utility functions for working with tree-like data structures, such as nested tuples, lists, and dicts. We call these structures pytrees. They are trees in that they are defined recursively (any non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and can be operated on recursively (object identity equivalence is not preserved by mapping operations, and the structures cannot contain reference cycles).
The set of Python types that are considered pytree nodes (i.e. types that can be mapped over, rather than treated as leaves) is extensible. There is a single module-level registry of types, and class hierarchy is ignored. By registering a new pytree node type, that type in effect becomes transparent to the utility functions in this module.
The primary purpose of this module is to enable interoperability between user-defined data structures and JAX transformations (e.g. jit). It is not meant to be a general-purpose library for handling tree-like data structures.
See the JAX pytrees note for examples.
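As a minimal sketch of the definitions above, assuming a recent JAX release and using jax.tree_util.tree_leaves (which this page does not document), built-in containers act as internal nodes, while every other object, including an unregistered user class, is a single leaf. The class `Opaque` is hypothetical.

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves

# Lists, tuples, and dicts are pytree nodes; numbers, strings, and arrays are leaves.
nested = [1, (2.0, "x"), {"k": jnp.zeros(3)}]
leaves = tree_leaves(nested)

class Opaque:
    """Hypothetical class that is NOT registered, so each instance is one leaf."""
    def __init__(self, a, b):
        self.a = a
        self.b = b

# Without registration, the fields a and b are not traversed.
opaque_leaves = tree_leaves([Opaque(1, 2), Opaque(3, 4)])
```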

class jax.tree_util.Partial
A version of functools.partial that works in pytrees.
Use it for partial function evaluation in a way that is compatible with JAX's transformations, e.g., Partial(func, *args, **kwargs).
(You need to explicitly opt in to this behavior because we didn't want to give functools.partial different semantics than normal function closures.)
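A minimal sketch, assuming a recent JAX release: because a Partial is itself a pytree, it can be passed as an argument to a jitted function. The helpers `scale` and `apply_fn` are hypothetical. The bound argument becomes a leaf of the Partial, while the wrapped function is kept in the tree structure rather than traced.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_leaves

def scale(factor, x):
    # Hypothetical helper: multiply x by a scalar factor.
    return factor * x

@jax.jit
def apply_fn(fn, x):
    # fn is a pytree argument; its bound args are traced like any other leaf.
    return fn(x)

double = Partial(scale, 2.0)
result = apply_fn(double, jnp.arange(3.0))

# The only leaf of the Partial is the bound argument 2.0.
bound_leaves = tree_leaves(double)
```

Passing a plain functools.partial to apply_fn the same way would be rejected by jit, since it is an opaque leaf rather than a valid JAX type.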

jax.tree_util.all_leaves(iterable)
Tests whether all elements in the given iterable are leaves.
>>> import jax
>>> from jax.tree_util import all_leaves
>>> tree = {"a": [1, 2, 3]}
>>> assert all_leaves(jax.tree_util.tree_leaves(tree))
>>> assert not all_leaves([tree])
This function is useful in advanced cases. For example, if a library allows arbitrary map operations on a flat list of leaves, it may want to check whether the result is still a flat list of leaves.
 Parameters
iterable – Iterable of leaves.
 Returns
A boolean indicating if all elements in the input are leaves.
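The use case described above can be sketched as follows, assuming a recent JAX release; the two list comprehensions are hypothetical stand-ins for a library's map operation. Mapping over a flat list of leaves may or may not keep it flat, and all_leaves detects the difference.

```python
from jax.tree_util import all_leaves, tree_leaves

leaves = tree_leaves({"a": [1, 2], "b": 3})  # a flat list of leaves

# A map that returns scalars keeps the list flat.
scaled = [x * 10 for x in leaves]

# A map that returns tuples does not: each tuple is a pytree node.
paired = [(x, x) for x in leaves]

scaled_is_flat = all_leaves(scaled)
paired_is_flat = all_leaves(paired)
```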

jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)
Extends the set of types that are considered internal nodes in pytrees.
See the JAX pytrees note for example usage.
 Parameters
nodetype (Type[~T]) – a Python type to treat as an internal pytree node.
flatten_func (Callable[[~T], Tuple[Sequence[Any], Any]]) – a function to be used during flattening, taking a value of type nodetype and returning a pair, with (1) an iterable for the children to be flattened recursively, and (2) some hashable auxiliary data to be stored in the treedef and to be passed to the unflatten_func.
unflatten_func (Callable[[Any, Sequence[Any]], ~T]) – a function taking two arguments: the auxiliary data that was returned by flatten_func and stored in the treedef, and the unflattened children. The function should return an instance of nodetype.
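A minimal sketch of registering a container by hand, assuming a recent JAX release. The `Point` class and its `label` field are hypothetical: the coordinates become children, and the label travels as hashable auxiliary data in the treedef, so tree_map leaves it untouched.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten, tree_map

class Point:
    """Hypothetical 2-D point with a static, hashable label."""
    def __init__(self, x, y, label):
        self.x = x
        self.y = y
        self.label = label

def point_flatten(p):
    # Returns (1) the children to flatten recursively and (2) hashable aux data.
    return (p.x, p.y), p.label

def point_unflatten(label, children):
    # Receives the aux data first, then the children.
    x, y = children
    return Point(x, y, label)

register_pytree_node(Point, point_flatten, point_unflatten)

p = Point(jnp.array(1.0), jnp.array(2.0), "p0")
leaves, treedef = tree_flatten(p)
doubled = tree_map(lambda v: 2 * v, p)
```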

jax.tree_util.register_pytree_node_class(cls)
Extends the set of types that are considered internal nodes in pytrees.
This function is a thin wrapper around register_pytree_node, and provides a class-oriented interface:

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Special:
  def __init__(self, x, y):
    self.x = x
    self.y = y
  def tree_flatten(self):
    return ((self.x, self.y), None)
  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)
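Once registered, instances of Special behave like any other pytree. A short usage sketch, assuming a recent JAX release (the class is repeated so the snippet runs on its own; the jitted function `total` is hypothetical): tree_flatten exposes x and y as leaves, and jax.jit accepts a Special directly as an argument.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten

@register_pytree_node_class
class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        return ((self.x, self.y), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

s = Special(1.0, 2.0)
leaves, treedef = tree_flatten(s)
rebuilt = tree_unflatten(treedef, [10.0, 20.0])

@jax.jit
def total(obj):
    # obj arrives as a Special whose fields are traced arrays.
    return obj.x + obj.y

out = total(Special(jnp.array(1.0), jnp.array(2.0)))
```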

jax.tree_util.tree_flatten(tree, is_leaf=None)
Flattens a pytree.
 Parameters
tree – a pytree to flatten.
is_leaf (Optional[Callable[[Any], bool]]) – an optionally specified function that will be called at each flattening step. It should return a boolean: True stops flattening at the current object and treats the whole subtree as a single leaf, while False lets flattening traverse the object as usual.
 Returns
A pair where the first element is a list of leaf values and the second element is a treedef representing the structure of the flattened tree.
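A minimal sketch of tree_flatten with and without is_leaf, assuming a recent JAX release. Dict entries are visited in sorted key order; here is_leaf stops traversal at tuples, so each tuple is returned as one leaf.

```python
from jax.tree_util import tree_flatten

tree = {"params": (1.0, 2.0), "state": [3.0]}

# Default flattening descends into every container.
leaves, treedef = tree_flatten(tree)

# is_leaf returning True stops traversal and keeps the tuple as a single leaf.
coarse, coarse_def = tree_flatten(tree, is_leaf=lambda node: isinstance(node, tuple))
```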

jax.tree_util.tree_map(f, tree)
Maps a function over a pytree to produce a new pytree.
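 Parameters
f (Callable[[Any], Any]) – function to be applied at each leaf.
tree (Any) – a pytree to be mapped over, with each leaf providing the argument to f.
 Return type
Any
 Returns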
A new pytree with the same structure as tree but with the value at each leaf given by f(x), where x is the value at the corresponding leaf in the input tree.
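A minimal sketch of tree_map scaling every array in a hypothetical parameter dictionary, assuming a recent JAX release. The result has the same keys and shapes as the input.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# f is applied to each leaf; the dict structure is rebuilt around the results.
halved = tree_map(lambda x: 0.5 * x, params)
```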

jax.tree_util.tree_multimap(f, tree, *rest)
Maps a multi-input function over pytree args to produce a new pytree.
 Parameters
f (Callable[..., Any]) – function that takes 1 + len(rest) arguments, to be applied at the corresponding leaves of the pytrees.
tree (Any) – a pytree to be mapped over, with each leaf providing the first positional argument to f.
*rest (Any) – a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
 Return type
Any
 Returns
A new pytree with the same structure as tree but with the value at each leaf given by f(x, *xs), where x is the value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rest.
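In recent JAX releases tree_multimap has been removed and tree_map accepts the extra trees itself, with the same f(x, *xs) semantics, so this sketch uses tree_map and assumes such a release. The first call performs a hypothetical gradient step; the second shows the prefix rule: the second tree is deeper, so f receives a whole subtree for the first leaf.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(0.05)}
lr = 0.1

# f takes 1 + len(rest) = 2 arguments: one leaf from each tree.
updated = tree_map(lambda p, g: p - lr * g, params, grads)

# [1, 2] is a prefix of [[3, 4], 5], so the first call gets x=1, y=[3, 4].
pairs = tree_map(lambda x, y: (x, y), [1, 2], [[3, 4], 5])
```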

jax.tree_util.tree_unflatten(treedef, leaves)
Reconstructs a pytree from the treedef and the leaves.
The inverse of tree_flatten().
 Parameters
treedef – the treedef describing the structure to reconstruct.
leaves – the list of leaves to use for reconstruction. The list must match the leaves of the treedef, with one entry per leaf.
 Returns
The reconstructed pytree, containing the leaves placed in the structure described by treedef.
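A minimal round-trip sketch, assuming a recent JAX release: unflattening the original leaves rebuilds an equal tree, and unflattening a modified list of the same length places the new values in the same positions.

```python
from jax.tree_util import tree_flatten, tree_unflatten

tree = {"a": [1, 2], "b": (3, {"c": 4})}
leaves, treedef = tree_flatten(tree)

# Same leaves in, equal tree out.
rebuilt = tree_unflatten(treedef, leaves)

# New values of the same count are placed into the original structure.
incremented = tree_unflatten(treedef, [x + 1 for x in leaves])
```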