jax.tree_util.tree_map_with_path
- jax.tree_util.tree_map_with_path(f, tree, *rest, is_leaf=None)
Maps a multi-input function over pytree key paths and arguments to produce a new pytree.
This is a more powerful alternative to tree_map that also passes the key path of each leaf to f as an input argument.
- Parameters:
f (Callable[…, Any]) – function that takes 2 + len(rest) arguments: the key path, followed by the corresponding leaf of each pytree.
tree (Any) – a pytree to be mapped over, with each leaf’s key path as the first positional argument to f and the leaf itself as the second argument.
*rest (Any) – a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (Callable[[Any], bool] | None) – an optional function called at each flattening step; if it returns True, the current node is treated as a leaf and is not flattened further.
- Return type:
Any
- Returns:
A new pytree with the same structure as tree but with the value at each leaf given by f(kp, x, *xs), where kp is the key path of the corresponding leaf in tree, x is the leaf value, and xs is the tuple of values at the corresponding nodes in rest.
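A minimal sketch of the key-path argument in action, assuming a JAX release where jax.tree_util.keystr and the DictKey key-path entry (with its .key attribute) are available. The pytree and the helper name scale_weights are illustrative only, not part of the API.

```python
import jax.numpy as jnp
from jax.tree_util import keystr, tree_map_with_path

# Illustrative pytree: a dict of dicts holding arrays.
params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "out": {"w": jnp.ones((3, 1)), "b": jnp.zeros(1)},
}

# f receives (key_path, leaf); keystr renders the path as a readable string.
names = tree_map_with_path(lambda kp, x: keystr(kp), params)
# names["dense"]["w"] == "['dense']['w']"


def scale_weights(path, leaf):
    # path is a tuple of key entries; for nested dicts each entry is a DictKey,
    # so the last entry's .key is the innermost dictionary key ("w" or "b").
    if path[-1].key == "w":
        return leaf * 0.5  # halve weight matrices
    return leaf  # leave biases unchanged


scaled = tree_map_with_path(scale_weights, params)
```

Matching on path[-1].key works here only because every leaf sits under a dict key; sequence entries (SequenceKey) and attribute entries (GetAttrKey) expose .idx and .name instead.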
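The sketch below exercises *rest and is_leaf. It assumes params and grads share one structure, so each leaf of params has a matching leaf in grads. The freezing rule on the "b" key, the learning rate, and the names sgd_step and stats are hypothetical choices for illustration, not part of the API.

```python
import jax.numpy as jnp
from jax.tree_util import keystr, tree_map_with_path

# Two pytrees with identical structure: parameters and their gradients.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([0.5])}
grads = {"w": jnp.array([0.1, 0.1]), "b": jnp.array([1.0])}


def sgd_step(path, p, g, lr=0.1):
    # With one extra tree in *rest, f is called with 2 + 1 = 3 positional
    # arguments: the key path, the leaf from params, the matching leaf from grads.
    if path[-1].key == "b":
        return p  # hypothetical rule: keep the bias frozen
    return p - lr * g


updated = tree_map_with_path(sgd_step, params, grads)

# is_leaf stops flattening early: each (mean, std) tuple is passed to f as a
# single leaf instead of being split into two float leaves.
stats = {"x": (0.0, 1.0), "y": (2.0, 0.5)}
described = tree_map_with_path(
    lambda kp, pair: f"{keystr(kp)}: mean={pair[0]}, std={pair[1]}",
    stats,
    is_leaf=lambda node: isinstance(node, tuple),
)
# described["x"] == "['x']: mean=0.0, std=1.0"
```

Passing grads through *rest rather than closing over it lets tree_map_with_path align the two trees leaf by leaf, so f only ever deals with one position at a time.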