jax.tree_util.tree_reduce

jax.tree_util.tree_reduce(function, tree, initializer=<object object>, is_leaf=None)

Call functools.reduce() over the leaves of a tree, taken in the order produced by flattening the tree.
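Going by the one-line description above, the call should behave like running functools.reduce over jax.tree_util.tree_leaves(tree). The sketch below checks that reading on a small, made-up tree (the dict `tree` is illustrative only). It also shows that dict keys are visited in sorted order, which matters when `function` is not commutative.

```python
import functools
import operator

import jax

# Illustrative pytree: a dict holding a list and a nested tuple/dict.
tree = {"b": [1, 2], "a": (3, {"c": 4})}

# Flatten first: dict keys are traversed in sorted order ("a" before "b").
leaves = jax.tree_util.tree_leaves(tree)

# Reduce the flat leaf list by hand, then via the API, and compare.
manual = functools.reduce(operator.add, leaves)
via_api = jax.tree_util.tree_reduce(operator.add, tree)

print(leaves)           # [3, 4, 1, 2]
print(manual, via_api)  # 10 10
```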

Parameters:
  • function (Callable[[T, Any], T]) – the reduction function. It is called as function(accumulator, leaf) and returns the updated accumulator.

  • tree (Any) – the pytree to reduce over

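  • initializer (Any) – the optional initial value of the accumulator. If it is omitted, the first leaf is used as the starting value, so reducing a tree that has no leaves raises a TypeError.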

  • is_leaf (Callable[[Any], bool] | None) – an optional function called on each node during flattening. If it returns True, flattening stops at that node and the whole subtree is treated as a single leaf; if it returns False, flattening continues into the node's children.
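The following sketch exercises the initializer and is_leaf parameters on small, made-up trees. The counting lambda and the tuple-stopping predicate are illustrative choices, not part of the API. It assumes the functools.reduce semantics described above, under which an empty tree without an initializer raises TypeError.

```python
import operator

import jax

# No initializer: the first leaf seeds the accumulator, ((1 + 2) + 3).
no_init = jax.tree_util.tree_reduce(operator.add, [1, 2, 3])

# With an initializer: reduction starts from it, (((100 + 1) + 2) + 3).
with_init = jax.tree_util.tree_reduce(operator.add, [1, 2, 3], initializer=100)

# None and [] are subtrees with no leaves, so only the initializer remains.
empty = jax.tree_util.tree_reduce(operator.add, [None, []], initializer=0)

# Without an initializer, a tree with no leaves cannot be reduced.
try:
    jax.tree_util.tree_reduce(operator.add, [])
    empty_raised = False
except TypeError:
    empty_raised = True

# is_leaf: count leaves with and without stopping flattening at tuples.
tree = [(1, 2), (3, 4)]
n_leaves_default = jax.tree_util.tree_reduce(
    lambda acc, _: acc + 1, tree, initializer=0
)
n_leaves_custom = jax.tree_util.tree_reduce(
    lambda acc, _: acc + 1,
    tree,
    initializer=0,
    is_leaf=lambda x: isinstance(x, tuple),  # each tuple becomes one leaf
)

print(no_init, with_init, empty, empty_raised)  # 6 106 0 True
print(n_leaves_default, n_leaves_custom)        # 4 2
```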

Returns:

the reduced value.

Return type:

T

Examples

The example below uses jax.tree.reduce, which is equivalent to jax.tree_util.tree_reduce.

>>> import jax
>>> import operator
>>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]])
21
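A common use is summarizing a pytree of arrays. In the sketch below, `params` and its keys "dense" and "out" are a hypothetical model parameter tree made up for illustration. Because the accumulator (a Python int or float) need not have the same type as the leaves (arrays), each call to `function` converts the leaf into the accumulator's terms.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree, for illustration only.
params = {
    "dense": {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "out": {"w": jnp.ones((4, 2)), "b": jnp.zeros(2)},
}

# Total number of scalar parameters: add each leaf's element count.
n_params = jax.tree_util.tree_reduce(lambda acc, x: acc + x.size, params, 0)

# Largest absolute value across all leaves, starting from 0.0.
max_abs = jax.tree_util.tree_reduce(
    lambda acc, x: jnp.maximum(acc, jnp.max(jnp.abs(x))), params, 0.0
)

print(n_params)  # 26  (12 + 4 + 8 + 2)
print(max_abs)   # 1.0
```

Starting from an explicit initializer keeps the result well defined even if the tree turns out to have no leaves.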