jax.experimental.maps module

API

jax.experimental.maps.mesh(devices, axis_names)

Declare the hardware resources available in the scope of this manager.

In particular, all axis_names become valid resource names inside the managed block and can be used, for example, in the axis_resources argument of xmap().

Parameters
  • devices (ndarray) – A NumPy ndarray object containing JAX device objects (as obtained e.g. from jax.devices()).

  • axis_names (Sequence[Hashable]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
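To make the relationship between the two parameters concrete, here is a minimal sketch of how a mesh can be thought of as a mapping from each resource axis name to the size of the matching dimension of devices. The helper mesh_axis_sizes is hypothetical (it is not part of jax.experimental.maps), and strings stand in for JAX device objects so that the sketch runs without accelerators.

```python
import numpy as np

def mesh_axis_sizes(devices, axis_names):
    """Hypothetical helper: map each mesh axis name to its number of devices.

    Enforces the documented requirement that len(axis_names) equals the
    rank of the devices array. Not part of jax.experimental.maps.
    """
    devices = np.asarray(devices, dtype=object)
    if len(axis_names) != devices.ndim:
        raise ValueError(
            f"got {len(axis_names)} axis names for a rank-{devices.ndim} device array")
    # Axis names are paired with the dimensions of `devices` in order.
    return dict(zip(axis_names, devices.shape))

# Strings stand in for JAX device objects (e.g. from jax.devices()).
fake_devices = np.array(["d0", "d1", "d2", "d3"], dtype=object).reshape((2, 2))
sizes = mesh_axis_sizes(fake_devices, ("x", "y"))  # {'x': 2, 'y': 2}
```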

Example:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import mesh, xmap

x = jnp.arange(10).reshape((2, 5))
devices = np.array(jax.devices())[:4].reshape((2, 2))  # requires at least 4 devices
with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
  distributed_out = xmap(
    jnp.vdot,
    in_axes=({0: 'left'}, {1: 'right'}),  # one mapping per argument of vdot
    out_axes=['left', 'right', ...],
    axis_resources={'left': 'x', 'right': 'y'})(x, x.T)

jax.experimental.maps.xmap(fun, in_axes, out_axes, *, axis_sizes={}, axis_resources={}, donate_argnums=(), backend=None)

Assign a positional signature to a program that uses named array axes.

Warning

This is an experimental feature and the details can change at any time. Use at your own risk!

Warning

This docstring is aspirational. Not all features of the named axis programming model have been implemented just yet.

The usual programming model of JAX (or really NumPy) associates each array with two pieces of metadata describing its type: the element type (dtype) and the shape. xmap() extends this model by adding support for named axes. In particular, each array used in a function wrapped by xmap() can additionally have a non-empty named_shape attribute, which can be used to query the set of named axes (introduced by xmap()) appearing in that value along with their sizes. Furthermore, in most places where positional axis indices are allowed (for example, the axis argument of sum()), bound axis names are also accepted. The einsum() language is extended inside xmap() to additionally allow contractions that involve named axes. Broadcasting of named axes happens by name: all axes with equal names are expected to have equal sizes in all arguments of a broadcasting operation, and the named shape of the result is the (set) union of the named axes of all arguments. The positional semantics of the program remain unchanged, and broadcasting still implicitly right-aligns positional axes for unification. For an extended description of the xmap() programming model, please refer to the xmap() tutorial notebook in the main JAX documentation.
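The broadcasting-by-name rule described above can be modeled in a few lines of plain Python. This is a sketch of the rule, not JAX's implementation: named shapes are represented as dicts from axis name to size, and the hypothetical helper broadcast_named_shapes returns their union, rejecting any name that appears with two different sizes.

```python
def broadcast_named_shapes(*named_shapes):
    """Union of named shapes, requiring equal sizes for equal names.

    Each named shape is a dict {axis_name: size}. Positional axes are not
    modeled here; they keep their usual right-aligned broadcasting.
    """
    result = {}
    for named_shape in named_shapes:
        for name, size in named_shape.items():
            if name in result and result[name] != size:
                raise ValueError(
                    f"axis {name!r} has size {result[name]} in one argument "
                    f"and {size} in another")
            result[name] = size
    return result

# 'batch' appears in both operands with the same size, so it is unified;
# 'out_features' only appears in the second and is carried into the result.
out = broadcast_named_shapes({"batch": 8}, {"batch": 8, "out_features": 3})
```

Note that, unlike positional broadcasting, the order in which names appear carries no meaning; only the names and their sizes matter.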

Note that since all top-level JAX expressions are interpreted in the NumPy programming model, xmap() can also be seen as an adapter that converts a function that uses named axes (including in arguments and returned values) into one that takes and returns values that only have positional axes.
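As an illustration of this adapter view, the following NumPy sketch spells out with explicit loops what a call like xmap(vdot, in_axes=({0: 'left'}, {1: 'right'}), out_axes=['left', 'right', ...]) computes. The function name xmap_vdot_emulation is hypothetical and the loops only model the semantics; they say nothing about how xmap actually lowers the program.

```python
import numpy as np

def xmap_vdot_emulation(a, b):
    """Loop-based model of xmap(vdot, in_axes=({0: 'left'}, {1: 'right'}),
    out_axes=['left', 'right', ...]).

    Axis 0 of `a` becomes the named axis 'left' and axis 1 of `b` becomes
    'right'. Inside the loop body each operand only has its remaining
    positional axis, which np.vdot contracts. The named axes are turned
    back into leading positional axes of the output, in out_axes order.
    """
    left_size, right_size = a.shape[0], b.shape[1]
    out = np.empty((left_size, right_size), dtype=np.result_type(a, b))
    for i in range(left_size):       # iterate over named axis 'left'
        for j in range(right_size):  # iterate over named axis 'right'
            out[i, j] = np.vdot(a[i, :], b[:, j])
    return out

x = np.arange(10).reshape((2, 5))
result = xmap_vdot_emulation(x, x.T)  # [[30, 80], [80, 255]], same as x @ x.T
```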

The default lowering strategy of xmap() converts all named axes into positional axes, working similarly to multiple applications of vmap(). However, this behavior can be further customized by the axis_resources argument. When specified, each axis introduced by xmap() can be assigned to one or more resource axes. Those include the axes of the hardware mesh, as defined by the mesh() context manager. Each value that has a named axis in its named_shape will be partitioned over all mesh axes that axis is assigned to. Hence, xmap() can be seen as an alternative to pmap() that also exposes a way to automatically partition the computation over multiple devices.
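The partitioning rule in the previous paragraph determines how large each device's piece of a value is. The sketch below computes that per-device block size for each named axis, under two assumptions of this example only: every axis size divides evenly by the number of shards, and SerialLoop entries in axis_resources are ignored. The helper local_block_shape is hypothetical.

```python
def local_block_shape(named_shape, axis_resources, mesh_shape):
    """Hypothetical helper: size of each named axis inside one device's block.

    named_shape:    {axis_name: global_size}
    axis_resources: {axis_name: mesh_axis or tuple of mesh axes}
    mesh_shape:     {mesh_axis: number of devices along it}
    """
    block = {}
    for name, size in named_shape.items():
        resources = axis_resources.get(name, ())
        if not isinstance(resources, tuple):
            resources = (resources,)
        # An axis assigned to several mesh axes is split over all of them.
        num_shards = 1
        for r in resources:
            num_shards *= mesh_shape[r]
        if size % num_shards:
            raise ValueError(
                f"axis {name!r} of size {size} does not split into {num_shards} shards")
        block[name] = size // num_shards
    return block

mesh_shape = {"x": 2, "y": 2}
# 'left' is split over mesh axis 'x'; 'right' has no resources and stays whole.
block = local_block_shape({"left": 4, "right": 6}, {"left": "x"}, mesh_shape)
# Assigning 'left' to both mesh axes splits it four ways.
block_both = local_block_shape({"left": 4}, {"left": ("x", "y")}, mesh_shape)
```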

Warning

While it is possible to assign multiple axis names to a single resource axis, care has to be taken to ensure that none of those named axes co-occur in a named_shape of any value in the named program. At the moment this is completely unchecked and will result in undefined behavior. The final release of xmap() will enforce this invariant, but it is a work in progress.

Note that you do not have to worry about any of this as long as no resource axis is repeated in axis_resources.values().
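Since, per the warning, xmap() does not currently check this invariant, it can help to state it explicitly. The following is a sketch of such a check for a single value's named shape; find_resource_conflicts is a hypothetical helper, not part of the JAX API, and it ignores SerialLoop entries.

```python
from itertools import combinations

def find_resource_conflicts(named_shape, axis_resources):
    """Return (axis_a, axis_b, shared_resources) for every pair of named axes
    in `named_shape` that are assigned to at least one common resource axis."""
    def resources_of(name):
        r = axis_resources.get(name, ())
        return set(r) if isinstance(r, tuple) else {r}

    conflicts = []
    for a, b in combinations(list(named_shape), 2):
        shared = resources_of(a) & resources_of(b)
        if shared:
            conflicts.append((a, b, shared))
    return conflicts

# 'left' and 'right' both map to mesh axis 'x': only a problem if they co-occur.
shared_resource = {"left": "x", "right": "x"}
bad = find_resource_conflicts({"left": 2, "right": 2}, shared_resource)
ok = find_resource_conflicts({"left": 2, "batch": 8}, shared_resource)
```

Here bad reports the pair ('left', 'right'), while ok is empty because the value never carries both axes at once.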

Note that the choice of axis_resources never changes the results of the computation, only how it is carried out (e.g. how many devices are used). This makes it easy to try out various ways of partitioning a single program in many distributed scenarios (both small- and large-scale) to maximize performance. As such, xmap() can be seen as a way to seamlessly interpolate between vmap()- and pmap()-style execution.
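The claim that partitioning does not affect results can be checked on the matrix-product example used later in this docstring. The NumPy sketch below splits the rows of one operand and the columns of the other into a grid of blocks, as if each block ran on its own device, and compares every choice of grid against the unpartitioned product. The grid is an assumed stand-in for axis_resources on a 2x2 mesh: (1, 1) corresponds to no partitioning, and (2, 2) to assigning both named axes to mesh axes. No devices are involved.

```python
import numpy as np

def blocked_matmul(a, b, grid):
    """Compute a @ b by splitting the rows of `a` into grid[0] blocks and the
    columns of `b` into grid[1] blocks, then reassembling the output blocks."""
    row_blocks = np.array_split(a, grid[0], axis=0)
    col_blocks = np.array_split(b, grid[1], axis=1)
    # Block (i, j) only needs row block i and column block j.
    return np.block([[r @ c for c in col_blocks] for r in row_blocks])

x = np.arange(20).reshape((4, 5))
reference = x @ x.T
results = {grid: blocked_matmul(x, x.T, grid)
           for grid in [(1, 1), (2, 1), (1, 2), (2, 2)]}
```

Every entry of results equals reference; only the amount of work per block changes.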

Parameters
  • fun (Callable) – Function that uses named axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof (in general: valid pytrees).

  • in_axes – A Python object with the same container (pytree) structure as the signature of arguments to fun, but with a positional-to-named axis mapping in place of every array argument. The valid positional-to-named mappings are: (1) a Dict[int, AxisName] specifying that the positional dimensions given by the dictionary keys are to be converted to named axes with the given names; (2) a list of axis names that ends with the Ellipsis object (...), in which case as many leading positional axes of the argument as there are names will be converted into named axes inside the function. Note that in_axes can also be a prefix of the argument container structure, in which case the mapping is repeated for all arrays in the collapsed subtree.

  • out_axes – A Python object with the same container (pytree) structure as the returns of fun, but with a positional-to-named axis mapping in place of every returned array. The valid positional-to-named mappings are the same as in in_axes. Note that out_axes can also be a prefix of the return container structure, in which case the mapping is repeated for all arrays in the collapsed subtree.

  • axis_sizes (Dict[Hashable, int]) – A dict mapping axis names to their sizes. All axes defined by xmap have to appear either in in_axes or axis_sizes. Sizes of axes that appear in in_axes are inferred from arguments whenever possible. In multi-host scenarios, the user-specified sizes are expected to be the global axis sizes (and might not match the expected size of local inputs).

  • axis_resources (Dict[Hashable, Union[Hashable, SerialLoop, Tuple[Union[Hashable, SerialLoop], …]]]) – A dictionary mapping the axes introduced in this xmap() to one or more resource axes. Any array whose shape contains a named axis with assigned resources will be partitioned over the resource axes that named axis is mapped to.

  • backend (Optional[str]) – This is an experimental feature and the API is likely to change. An optional string naming the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.
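The two mapping forms accepted by in_axes and out_axes, together with the size inference described under axis_sizes, can be sketched in plain Python. The helpers normalize_axis_mapping and infer_axis_sizes are hypothetical. For simplicity the sketch takes a flat list of argument shapes instead of a pytree (so prefix matching is not modeled), and it assumes single-host execution, where argument sizes must agree with any user-specified axis_sizes.

```python
def normalize_axis_mapping(spec, ndim):
    """Turn an in_axes/out_axes entry into {positional_dim: axis_name}.

    Accepts a dict {int: name}, or a list of names ending with Ellipsis,
    which names the leading positional axes in order.
    """
    if isinstance(spec, dict):
        return dict(spec)
    if isinstance(spec, list) and spec and spec[-1] is Ellipsis:
        names = spec[:-1]
        if len(names) > ndim:
            raise ValueError("more axis names than array dimensions")
        return {dim: name for dim, name in enumerate(names)}
    raise TypeError(f"unsupported axis mapping: {spec!r}")

def infer_axis_sizes(shapes, specs, axis_sizes=None):
    """Collect named-axis sizes from argument shapes, checking consistency."""
    sizes = dict(axis_sizes or {})
    for shape, spec in zip(shapes, specs):
        for dim, name in normalize_axis_mapping(spec, len(shape)).items():
            if name in sizes and sizes[name] != shape[dim]:
                raise ValueError(f"inconsistent sizes for axis {name!r}")
            sizes[name] = shape[dim]
    return sizes

# Shapes of (x, y, w, b) for a linear regression with batch 32,
# 10 input features and 3 output features.
sizes = infer_axis_sizes(
    [(32, 10), (32, 3), (10, 3), (3,)],
    [["batch", "in_features", ...], ["batch", "out_features", ...],
     ["in_features", "out_features", ...], ["out_features", ...]])
```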

Returns

A version of fun that takes in arrays with positional axes in place of the named axes bound in this xmap() call, and returns results with all named axes converted to positional axes. If axis_resources is specified, fun can additionally execute in parallel on multiple devices.

For example, xmap() makes it very easy to convert a function that computes the vector inner product (such as jax.numpy.vdot()) into one that computes a matrix multiplication:

>>> import jax.numpy as jnp
>>> from jax.experimental.maps import xmap
>>> x = jnp.arange(10).reshape((2, 5))
>>> xmap(jnp.vdot,
...      in_axes=({0: 'left'}, {1: 'right'}),
...      out_axes=['left', 'right', ...])(x, x.T)
DeviceArray([[ 30,  80],
             [ 80, 255]], dtype=int32)

Note that the contraction in the program is performed over the positional axes, while named axes are just a convenient way to achieve batching. While this might seem like a silly example at first, it can be useful in practice: in conjunction with axis_resources, it makes it possible to implement a distributed matrix multiplication in just a few lines of code:

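import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import mesh, xmap

x = jnp.arange(10).reshape((2, 5))
devices = np.array(jax.devices())[:4].reshape((2, 2))  # requires at least 4 devices
with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
  distributed_out = xmap(
    jnp.vdot,
    in_axes=({0: 'left'}, {1: 'right'}),
    out_axes=['left', 'right', ...],
    axis_resources={'left': 'x', 'right': 'y'})(x, x.T)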

Still, the above examples are quite simple. After all, the xmapped computation was a simple NumPy function that didn’t use the axis names at all! So, let’s explore a slightly larger example: linear regression.

import jax.numpy as jnp
from jax.experimental.maps import xmap

def regression_loss(x, y, w, b):
  # Contract over in_features. Batch and out_features are present in
  # both inputs and output, so they don't need to be mentioned.
  y_pred = jnp.einsum('{in_features},{in_features}->{}', x, w) + b
  error = jnp.sum((y - y_pred) ** 2, axis='out_features')
  return jnp.mean(error, axis='batch')

loss_fn = xmap(regression_loss,
               in_axes=(['batch', 'in_features', ...],
                        ['batch', 'out_features', ...],
                        ['in_features', 'out_features', ...],
                        ['out_features', ...]),
               out_axes={})  # Loss is reduced over all axes, including batch!
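For comparison, here is a sketch of the same loss written with positional axes only, in plain NumPy. It assumes the positional layout implied by the in_axes above (batch first for x and y, in_features before out_features for w), and regression_loss_positional is a hypothetical name. Each named reduction in regression_loss becomes a reduction over an explicit axis index here.

```python
import numpy as np

def regression_loss_positional(x, y, w, b):
    """Positional NumPy counterpart of the named-axis regression_loss.

    Shapes: x (batch, in_features), y (batch, out_features),
    w (in_features, out_features), b (out_features,).
    """
    y_pred = x @ w + b                         # contract over in_features
    error = np.sum((y - y_pred) ** 2, axis=1)  # reduce over out_features
    return np.mean(error, axis=0)              # reduce over batch -> scalar

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))
w = rng.normal(size=(4, 3))
b = rng.normal(size=(3,))
y = x @ w + b  # targets that these parameters fit exactly
loss = regression_loss_positional(x, y, w, b)
```

With exact targets the loss is zero; shifting b by 1 makes every residual 1, so the per-example error is 3 (one per output feature) and so is the mean.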

Note

When using axis_resources along with a mesh that is controlled by multiple JAX hosts, keep in mind that in any given process xmap() only expects the data slice that corresponds to its local devices to be specified. This is in line with the current multi-host pmap() programming model.
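As an illustration of what "the data slice that corresponds to its local devices" can mean, the NumPy sketch below assumes a layout chosen for this example only: the named axis 'left' is assigned to a mesh axis whose devices are split evenly and contiguously across processes, so each process supplies one contiguous block of rows. The real correspondence depends on how the mesh devices are laid out across hosts. local_rows is a hypothetical helper.

```python
import numpy as np

def local_rows(global_array, process_index, num_processes):
    """Hypothetical helper: the rows of the 'left'-named axis that one process
    passes in, assuming a contiguous, even split of that axis across processes."""
    return np.array_split(global_array, num_processes, axis=0)[process_index]

global_x = np.arange(20).reshape((4, 5))
per_process = [local_rows(global_x, p, 2) for p in range(2)]
```

Each process then sees an argument of shape (2, 5), while the global size of 'left' remains 4. This is why, as noted under axis_sizes, user-specified sizes in multi-host scenarios are global sizes and may not match the local inputs.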

Parameters

donate_argnums (Union[int, Sequence[int]]) –