jax.tree module


Utilities for working with tree-like container data structures.

The jax.tree namespace contains aliases of utilities from jax.tree_util.
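A minimal sketch, assuming a recent JAX release that provides the `jax.tree` namespace. It shows that the short `jax.tree` names give the same results as their `tree_`-prefixed counterparts in `jax.tree_util`. The sketch checks results rather than object identity, since an alias may be a thin wrapper rather than the same function object.

```python
import jax
import jax.numpy as jnp

# A small pytree: a dict holding an array leaf and a tuple of floats.
params = {"w": jnp.ones((2,)), "b": (1.0, 2.0)}

# The short jax.tree name and its jax.tree_util counterpart.
doubled_new = jax.tree.map(lambda x: x * 2, params)
doubled_old = jax.tree_util.tree_map(lambda x: x * 2, params)

# Both calls produce the same tree structure...
assert jax.tree.structure(doubled_new) == jax.tree_util.tree_structure(doubled_old)

# ...and equal leaves, compared position by position.
for a, b in zip(jax.tree.leaves(doubled_new), jax.tree_util.tree_leaves(doubled_old)):
    assert jnp.allclose(a, b)
```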

List of Functions

all(tree)

Calls Python's built-in all() over the leaves of a tree.

flatten(tree[, is_leaf])

Flattens a pytree into a list of leaves and a treedef.

leaves(tree[, is_leaf])

Gets the leaves of a pytree.

map(f, tree, *rest[, is_leaf])

Maps a multi-input function over pytree args to produce a new pytree.

reduce(function, tree[, initializer, is_leaf])

Calls functools.reduce() over the leaves of a tree.

structure(tree[, is_leaf])

Gets the treedef for a pytree.

transpose(outer_treedef, inner_treedef, ...)

Transforms a tree having tree structure (outer, inner) into one having structure (inner, outer).

unflatten(treedef, leaves)

Reconstructs a pytree from the treedef and the leaves.
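The flattening functions above fit together as a round trip. This is a minimal sketch using only standard containers (dict, list, tuple); the tree and the `is_leaf` predicate are illustrative choices, not taken from the source.

```python
import jax

# A nested pytree; None is treated as an empty subtree with no leaves.
tree = {"a": [1, 2], "b": (3, None), "c": 4}

# flatten returns the leaves (dict keys visited in sorted order) plus a treedef.
leaves, treedef = jax.tree.flatten(tree)
print(leaves)  # [1, 2, 3, 4]

# leaves and structure each return one half of flatten's result.
assert jax.tree.leaves(tree) == leaves
assert jax.tree.structure(tree) == treedef

# unflatten rebuilds a tree of the same structure from new leaves.
rebuilt = jax.tree.unflatten(treedef, [x * 10 for x in leaves])
print(rebuilt)  # {'a': [10, 20], 'b': (30, None), 'c': 40}

# is_leaf stops traversal early: here every tuple is kept as a single leaf.
custom = jax.tree.leaves(tree, is_leaf=lambda x: isinstance(x, tuple))
print(custom)  # [1, 2, (3, None), 4]
```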
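The leaf-wise operations map, reduce, and all can be sketched on plain Python numbers as follows. The trees `a` and `b` are hypothetical examples; the extra trees passed to map are assumed to share the first tree's structure.

```python
import operator

import jax

a = {"x": 1, "y": [2, 3]}
b = {"x": 10, "y": [20, 30]}

# map applies f to corresponding leaves of a and b.
summed = jax.tree.map(lambda u, v: u + v, a, b)
print(summed)  # {'x': 11, 'y': [22, 33]}

# reduce folds a binary function over the leaves, like functools.reduce.
total = jax.tree.reduce(operator.add, summed)
print(total)  # 66

# An initializer, passed third, is the starting value of the fold.
product = jax.tree.reduce(operator.mul, a, 1)

# all returns True only if every leaf is truthy.
all_positive = jax.tree.all(jax.tree.map(lambda v: v > 0, a))
print(all_positive)  # True
```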
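A minimal sketch of transpose, assuming the elided third argument in its signature is the tree to be transposed, passed positionally. It turns a list (outer) of dicts (inner) into a dict of lists. The outer and inner treedefs are built here with structure from placeholder trees, which is one convenient way to obtain them.

```python
import jax

# A list (outer structure) of two dicts (inner structure) with identical keys.
rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

# Treedefs describing the outer and inner levels; the 0s are placeholder leaves.
outer = jax.tree.structure([0, 0])
inner = jax.tree.structure({"a": 0, "b": 0})

# Swap the nesting: list-of-dicts becomes dict-of-lists.
cols = jax.tree.transpose(outer, inner, rows)
print(cols)  # {'a': [1, 3], 'b': [2, 4]}
```

The tree being transposed must have exactly the outer structure with the inner structure substituted at each outer leaf; otherwise transpose raises an error instead of guessing.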