jax.tree_util.tree_map_with_path

jax.tree_util.tree_map_with_path(f, tree, *rest, is_leaf=None)

Maps a multi-input function over the key paths and leaves of one or more pytrees to produce a new pytree.

This is a more powerful alternative to tree_map: in addition to the leaves, f also receives the key path of each leaf as an input argument.
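As a minimal sketch (the params dict, the decay_kernels helper, and the 0.9 factor are illustrative assumptions, not part of the API), the key path handed to f can be rendered with jax.tree_util.keystr or inspected entry by entry to decide how each leaf is treated:

```python
import jax.numpy as jnp
from jax.tree_util import DictKey, keystr, tree_map_with_path

# Hypothetical parameter tree used only for illustration.
params = {
    "dense": {"bias": jnp.zeros(3), "kernel": jnp.ones((2, 3))},
    "scale": jnp.array(2.0),
}

# f receives (key_path, leaf); the key path is a tuple of key entries such as
# (DictKey(key='dense'), DictKey(key='kernel')). keystr renders it as a string.
paths = tree_map_with_path(lambda path, leaf: keystr(path), params)


def decay_kernels(path, leaf):
    # Hypothetical rule: shrink only leaves whose final key is "kernel".
    last = path[-1]
    if isinstance(last, DictKey) and last.key == "kernel":
        return leaf * 0.9
    return leaf  # every other leaf is returned unchanged


new_params = tree_map_with_path(decay_kernels, params)
print(paths)
```

The output tree has the same structure as params; only the kernel leaf differs.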

Parameters:
  • f (Callable[…, Any]) – function that takes 2 + len(rest) arguments: the key path, followed by the corresponding leaf of each pytree.

  • tree (Any) – a pytree to be mapped over, with each leaf’s key path as the first positional argument and the leaf itself as the second argument to f.

  • *rest (Any) – a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.

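  • is_leaf (Callable[[Any], bool] | None) – an optional function called at each flattening step. It should return a boolean: True stops the traversal and treats the whole subtree as a leaf, while False indicates that flattening should traverse into the current object.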
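The effect of is_leaf can be sketched as follows; the ranges pytree is a made-up example. Without is_leaf, each tuple is a container and f sees its individual numbers; with an is_leaf that returns True for tuples, f receives each whole tuple along with the key path that points at it:

```python
from jax.tree_util import keystr, tree_map_with_path

# Hypothetical pytree whose values are (low, high) pairs.
ranges = {"lr": (1e-4, 1e-2), "momentum": (0.8, 0.99)}

# Default: tuples are traversed, so f is called once per number,
# with key paths such as ['lr'][0] and ['lr'][1].
flat = tree_map_with_path(lambda path, x: keystr(path), ranges)

# With is_leaf, every tuple is treated as a single leaf,
# so f receives the whole (low, high) pair.
midpoints = tree_map_with_path(
    lambda path, pair: (pair[0] + pair[1]) / 2,
    ranges,
    is_leaf=lambda node: isinstance(node, tuple),
)
print(flat, midpoints)
```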

Return type:

Any

Returns:

A new pytree with the same structure as tree, but with the value at each leaf given by f(kp, x, *xs), where kp is the key path of the corresponding leaf in tree, x is the leaf value, and xs is the tuple of values at the corresponding nodes in rest.
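The f(kp, x, *xs) contract with additional trees in rest can be illustrated with a small sketch. The params, grads, and bounds trees, the sgd_step helper, the 0.1 step size, and the freeze-the-bias rule are hypothetical choices for illustration only. The second call shows the prefix case: each entry of bounds is a whole (low, high) tuple at a position where params has a single leaf, so that tuple is passed to f as one value.

```python
import jax.numpy as jnp
from jax.tree_util import DictKey, tree_map_with_path

# Hypothetical trees with identical structure.
params = {"bias": jnp.zeros(2), "kernel": jnp.ones(2)}
grads = {"bias": jnp.full(2, 0.5), "kernel": jnp.full(2, 0.5)}


def sgd_step(path, p, g):
    # kp = path, x = p (from params), xs = (g,) (from grads at the same key path).
    # Hypothetical policy: freeze biases, apply a plain SGD step to everything else.
    last = path[-1]
    if isinstance(last, DictKey) and last.key == "bias":
        return p
    return p - 0.1 * g


updated = tree_map_with_path(sgd_step, params, grads)

# Prefix case: `params` is a prefix of `bounds`, so f receives the whole
# (low, high) tuple that sits where `params` has a single leaf.
bounds = {"bias": (-1.0, 1.0), "kernel": (0.0, 0.5)}
clipped = tree_map_with_path(
    lambda path, p, b: jnp.minimum(jnp.maximum(p, b[0]), b[1]), params, bounds
)
print(updated, clipped)
```

Both results keep the structure of params, as the return description states.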