- jax.lax.map(f, xs)
Map a function over leading array axes.
Like Python’s built-in map, except that inputs and outputs are in the form of stacked arrays. Consider using the vmap() transform instead, unless you need to apply the function one element at a time, either to reduce memory usage or to perform heterogeneous computation with other control flow primitives.
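As a hedged sketch of the trade-off described above, the example below applies a per-element function containing jax.lax.cond with both jax.lax.map and jax.vmap. Per the lax.cond documentation, a cond whose predicate is batched by vmap is converted to a select, so both branches are evaluated for every element; lax.map instead runs the function one element at a time, so only the chosen branch executes. The function `piecewise` is a hypothetical example, not part of the JAX API.

```python
import jax
import jax.numpy as jnp
from jax import lax


def piecewise(x):
    # Hypothetical per-element function: branch on the sign of a scalar.
    return lax.cond(x > 0, jnp.sqrt, jnp.negative, x)


xs = jnp.array([4.0, -2.0, 9.0, -1.0])

# lax.map: sequential loop; each call sees a scalar x and runs one branch.
mapped = lax.map(piecewise, xs)

# vmap: the batched predicate turns cond into a select, so sqrt and negative
# are both computed for every element (NaNs from sqrt of negatives are discarded).
vectorized = jax.vmap(piecewise)(xs)

print(mapped)      # values [2, 2, 3, 1]
print(vectorized)  # same values
```

Both calls produce the same values; the difference is in how the work is scheduled and how much of it is done.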
When xs is an array type, the semantics of map() are given by this Python implementation:
```python
def map(f, xs):
    return np.stack([f(x) for x in xs])
```
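To check the semantics above concretely, the sketch below compares jax.lax.map with the reference Python loop on a small array. The function `f` here is a hypothetical row-wise reduction chosen only for illustration.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def f(x):
    # Hypothetical row-wise function: sum of squares of one row.
    return jnp.sum(x ** 2)


xs = jnp.arange(6.0).reshape(3, 2)  # rows: [0, 1], [2, 3], [4, 5]

# Reference semantics: a Python loop over the leading axis, then stack.
reference = np.stack([f(x) for x in xs])

# lax.map computes the same result with a single compiled loop body.
result = lax.map(f, xs)

print(result)  # values [1, 13, 41]
```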
Like scan(), map() is implemented in terms of JAX primitives, so many of the same advantages over a Python loop apply: xs may be an arbitrary nested pytree type, and the mapped computation is compiled only once.
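A minimal sketch of mapping over a pytree, assuming a dict input whose leaves all share the same leading-axis length: each call to the function receives one slice of every leaf, and a pytree output (here a tuple) is stacked leaf by leaf. The names `xs` keys `"a"`/`"b"` and the function `step` are hypothetical.

```python
import jax.numpy as jnp
from jax import lax

# A pytree input: every leaf has the same leading-axis length (3).
xs = {
    "a": jnp.array([1.0, 2.0, 3.0]),                         # shape (3,)
    "b": jnp.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]),    # shape (3, 2)
}


def step(x):
    # x is a dict of slices: x["a"] has shape (), x["b"] has shape (2,).
    scaled = x["a"] * x["b"]          # shape (2,)
    return scaled, jnp.sum(scaled)    # tuple output is also a pytree


scaled, totals = lax.map(step, xs)
print(scaled.shape, totals.shape)  # (3, 2) (3,)
print(totals)                      # values [2, 8, 18]
```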
f – a Python function to apply element-wise over the first axis or axes of xs.
xs – values over which to map along the leading axis.
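To illustrate the two parameters, the sketch below (an illustrative example, not taken from the JAX docs) maps over the leading axis of a 3-D array, so f receives 2-D slices, and closes over a value `v` that is not mapped. The names `v` and `apply_matrix` are hypothetical.

```python
import jax.numpy as jnp
from jax import lax

v = jnp.array([1.0, -1.0])  # closed over by f; not part of xs, so not mapped

# xs: a stack of four 2x2 matrices; lax.map iterates over the leading axis (4).
xs = jnp.stack([jnp.eye(2) * k for k in range(1, 5)])  # shape (4, 2, 2)


def apply_matrix(m):
    # m is one (2, 2) slice of xs.
    return m @ v  # shape (2,)


out = lax.map(apply_matrix, xs)
print(out.shape)  # (4, 2)
```

Values that should be shared across all elements, like `v`, can simply be captured by f rather than stacked into xs.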