jax.tree_util.tree_map

jax.tree_util.tree_map(f, tree, *rest, is_leaf=None)

Maps a multi-input function over pytree arguments to produce a new pytree.
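As a minimal sketch of the single-tree case: when no `rest` trees are passed, `f` takes one argument and is applied to every leaf. The nested dict and tuple below are illustrative values, not part of the original documentation.

```python
import jax

# A pytree mixing a dict and a tuple; the Python ints are its leaves.
tree = {"a": 1, "b": (2, 3)}

# f takes exactly one argument because no extra trees are passed via *rest.
doubled = jax.tree_util.tree_map(lambda x: x * 2, tree)
print(doubled)  # {'a': 2, 'b': (4, 6)}
```

The container structure (a dict whose "b" entry is a tuple) is preserved; only the leaf values change.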

Parameters
  • f (Callable[…, Any]) – function that takes 1 + len(rest) arguments, to be applied at the corresponding leaves of the pytrees.

  • tree (Any) – a pytree to be mapped over, with each leaf providing the first positional argument to f.

  • *rest (Any) – a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.

  • is_leaf (Optional[Callable[[Any], bool]]) – an optional function that is called at each flattening step. It should return a boolean: True stops flattening at the current object and treats the whole subtree as a single leaf, while False lets flattening traverse into it.

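The sketch below illustrates two of the parameters above, using made-up example trees. The first part passes a `rest` tree that has `tree` as a prefix, so each leaf of `tree` lines up with a whole subtree that `f` receives as an argument. The second part uses `is_leaf` to stop flattening at tuples.

```python
import jax

# --- a *rest tree that has `tree` as a prefix ---
# Each leaf of `tree` lines up with a whole subtree (a tuple) in `other`,
# so f receives that tuple as its second argument.
tree = {"a": 1, "b": 2}
other = {"a": (10, 20), "b": (30, 40)}
combined = jax.tree_util.tree_map(lambda x, y: x + sum(y), tree, other)
print(combined)  # {'a': 31, 'b': 72}

# --- is_leaf stops flattening early ---
# Without is_leaf, the ints inside each tuple would be the leaves.
# Returning True for tuples makes each (first, second) pair a single leaf.
pairs = [(1, 2), (3, 4)]
sums = jax.tree_util.tree_map(
    lambda pair: pair[0] + pair[1],
    pairs,
    is_leaf=lambda node: isinstance(node, tuple),
)
print(sums)  # [3, 7]
```

Note that `is_leaf` is also called on the outer list; it returns False there, so flattening continues into the list's elements.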

Return type

Any

Returns

A new pytree with the same structure as tree, but with the value at each leaf given by f(x, *xs), where x is the value at the corresponding leaf in tree and xs is the tuple of values at the corresponding nodes in rest.
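As an illustration of the f(x, *xs) rule with two trees of identical structure, here is a sketch of a gradient-descent-style update. The names `params`, `grads`, and `learning_rate` are hypothetical and chosen only for this example; the values are exactly representable in float32, so the results are exact.

```python
import jax
import jax.numpy as jnp

# Two pytrees with identical structure (hypothetical "parameters" and "gradients").
params = {"w": jnp.array([3.0, 5.0]), "b": jnp.array(1.0)}
grads = {"w": jnp.array([2.0, 4.0]), "b": jnp.array(1.0)}

learning_rate = 0.5

# f(x, *xs): x comes from `params`, and xs == (g,) comes from `grads`.
updated = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
print(updated["w"], updated["b"])  # [2. 3.] 0.5
```

The result has the same dict keys as `params`; each array leaf is combined with the leaf at the same position in `grads`.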