jax.experimental.maps module

API

xmap(fun, in_axes, out_axes, *[, ...])

Assign a positional signature to a program that uses named array axes.
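The snippet below is a minimal sketch of the idea behind xmap. It does not call xmap itself. xmap is experimental and may be missing from the installed JAX version, so the sketch swaps in the stable `jax.vmap` with its `axis_name` argument, which also lets a function refer to a mapped axis by name. Inside `normalize`, a hypothetical helper, the array has the positional shape `(3,)`, and the mapped dimension is reachable only through the name `"batch"`, for example in the `jax.lax.psum` collective. The `in_axes`/`out_axes` arguments then give the whole program a positional signature: axis 0 of the input is mapped to `"batch"`, and `"batch"` is mapped back to axis 0 of the output.

```python
import jax
import jax.numpy as jnp


def normalize(x):
    # Hypothetical helper. Inside the mapped function, x has positional
    # shape (3,); the mapped dimension is only visible via its name "batch".
    # psum sums the per-example totals across the whole "batch" axis.
    total = jax.lax.psum(jnp.sum(x), axis_name="batch")
    return x / total


# vmap (standing in for xmap here) binds positional axis 0 of the input to
# the named axis "batch" and places that axis back at position 0 of the
# output, so callers see an ordinary positional signature.
batched_normalize = jax.vmap(normalize, in_axes=0, out_axes=0, axis_name="batch")

x = jnp.arange(1.0, 13.0).reshape(4, 3)  # 4 examples of length 3
y = batched_normalize(x)                 # every element divided by sum(x)
print(y.shape)
```

Under xmap the analogous mapping would be written by naming the axis in `in_axes` and `out_axes` (for example `['batch', ...]`). Treat this correspondence as approximate rather than exact.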