jax.experimental.maps module

API

Mesh(devices, axis_names)

Declare the hardware resources available in the scope of this manager.
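A minimal sketch of building a Mesh and entering its scope may help make the signature concrete. It assumes the devices argument is a NumPy array of JAX devices whose number of dimensions matches the number of axis names. The axis names 'x' and 'y' and the variable names are illustrative choices, not part of the API. The import fallback is an assumption about where Mesh lives across JAX releases: newer versions expose it from jax.sharding, older ones from jax.experimental.maps.

```python
import numpy as np
import jax

try:
    from jax.sharding import Mesh            # assumed location in newer JAX releases
except ImportError:
    from jax.experimental.maps import Mesh   # location documented on this page

# Arrange every visible device in a 2D grid of shape (n_devices, 1).
devices = np.array(jax.devices()).reshape(-1, 1)

# 'x' and 'y' are arbitrary labels chosen for this sketch; one name per grid dimension.
mesh = Mesh(devices, axis_names=('x', 'y'))

with mesh:
    # Inside this block the mesh's named resources are in scope.
    axis_sizes = dict(mesh.shape)  # maps each axis name to its size

print(mesh.axis_names)
print(axis_sizes)
```

On a machine with a single device, both mesh axes have size 1; with more devices, all of them are placed along 'x' in this arrangement.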

xmap(fun, in_axes, out_axes, *[, ...])

Assign a positional signature to a program that uses named array axes.