tree_map(f, tree, *rest, is_leaf=None)
Maps a multi-input function over pytree args to produce a new pytree.
tree (Any) – a pytree to be mapped over, with each leaf providing the first positional argument to f.
*rest – a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf – an optional function that is called at each flattening step. It should return a boolean indicating whether flattening should continue into the current object (False) or stop immediately, with the whole subtree treated as a leaf (True).
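The effect of is_leaf can be illustrated with a minimal sketch. The tree contents and the tuple-based predicate below are hypothetical and chosen only for illustration; the sketch assumes the function is imported as jax.tree_util.tree_map.

```python
import jax

# Hypothetical example tree: a dict whose values are tuples of ints.
tree = {"a": (1, 2), "b": (3, 4)}

# Default behaviour: tuples are traversed, so f sees each int.
per_element = jax.tree_util.tree_map(lambda x: x * 2, tree)

# With is_leaf, every tuple is treated as a single leaf, so f sees
# the whole tuple. The root dict is not a tuple and is still traversed.
per_tuple = jax.tree_util.tree_map(
    lambda t: sum(t),
    tree,
    is_leaf=lambda node: isinstance(node, tuple),
)

print(per_element)
print(per_tuple)
```

In this sketch the predicate decides where flattening stops, so the same input yields either a tree of doubled ints or a tree with one sum per tuple.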
- Return type
A new pytree with the same structure as tree but with the value at each leaf given by f(x, *xs), where x is the value at the corresponding leaf in tree and xs is the tuple of values at corresponding nodes in rest.
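As a rough illustration of the mapping described above, the following sketch applies a one-argument function to a single tree, a two-argument function to two trees of identical structure, and a function to a tree paired with a second tree that has the first as a prefix. All variable names and values are made up for the example, and the import path jax.tree_util.tree_map is assumed.

```python
import jax

# Hypothetical pytrees with identical structure.
tree = {"w": [1, 2], "b": 3}
other = {"w": [10, 20], "b": 30}

# One tree: f receives each leaf of `tree`.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, tree)

# Two trees: f receives the leaf from `tree` and the matching leaf
# from `other`, i.e. f(x, *xs) with xs = (y,).
summed = jax.tree_util.tree_map(lambda x, y: x + y, tree, other)

# Prefix case: `scales` is a prefix of `values`, so for each leaf of
# `scales` f receives the whole corresponding subtree of `values`.
scales = {"w": 2, "b": 3}
values = {"w": [1, 2, 3], "b": [4]}
scaled_len = jax.tree_util.tree_map(lambda s, sub: s * len(sub), scales, values)

print(doubled)
print(summed)
print(scaled_len)
```

In every case the result has the structure of the first tree argument; only the leaf values change.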