jax.experimental.maps module
API
- jax.experimental.maps.mesh(devices, axis_names)
Declare the hardware resources available in the scope of this manager.
In particular, all axis_names become valid resource names inside the managed block and can be used, e.g., in the axis_resources argument of xmap().
- Parameters
devices (ndarray) – A NumPy ndarray object containing JAX device objects (as obtained, e.g., from jax.devices()).
axis_names (Sequence[Hashable]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
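Example:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.maps import mesh, xmap

    x = jnp.arange(10).reshape((2, 5))
    # Requires at least 4 devices.
    devices = np.array(jax.devices())[:4].reshape((2, 2))
    with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
      distributed_out = xmap(
        jnp.vdot,
        in_axes=({0: 'left'}, {1: 'right'}),  # one mapping per vdot argument
        out_axes=['left', 'right', ...],
        axis_resources={'left': 'x', 'right': 'y'})(x, x.T)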
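As a rough mental model (not the library's implementation), mesh() pairs each name in axis_names with the corresponding dimension of devices, which is why the two must have matching length and rank. The sketch below mimics that pairing in plain NumPy; mesh_axis_sizes is a hypothetical helper, and the strings stand in for real device objects from jax.devices().

```python
import numpy as np


def mesh_axis_sizes(devices: np.ndarray, axis_names) -> dict:
    """Hypothetical helper: map each resource axis name to the size of the
    matching dimension of the device array."""
    axis_names = tuple(axis_names)
    # One name per dimension of `devices`, as the parameter docs require.
    if len(axis_names) != devices.ndim:
        raise ValueError(
            f"expected {devices.ndim} axis names, got {len(axis_names)}")
    return dict(zip(axis_names, devices.shape))


# Strings stand in for JAX device objects (normally taken from jax.devices()).
devices = np.array([f"dev{i}" for i in range(4)]).reshape((2, 2))
sizes = mesh_axis_sizes(devices, ("x", "y"))
print(sizes)          # {'x': 2, 'y': 2}
print(devices[:, 0])  # the devices laid out along mesh axis 'x' (with 'y' fixed at 0)
```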
- jax.experimental.maps.xmap(fun, in_axes, out_axes, *, axis_sizes={}, axis_resources={}, donate_argnums=(), backend=None)
Assign a positional signature to a program that uses named array axes.
Warning
This is an experimental feature and the details can change at any time. Use at your own risk!
Warning
This docstring is aspirational. Not all features of the named axis programming model have been implemented just yet.
The usual programming model of JAX (or really NumPy) associates each array with two pieces of metadata describing its type: the element type (dtype) and the shape. xmap() extends this model by adding support for named axes. In particular, each array used in a function wrapped by xmap() can additionally have a non-empty named_shape attribute, which can be used to query the set of named axes (introduced by xmap()) appearing in that value along with their sizes. Furthermore, in most places where positional axis indices are allowed (for example the axes arguments in sum()), bound axis names are also accepted. The einsum() language is extended inside xmap() to additionally allow contractions that involve named axes. Broadcasting of named axes happens by name, i.e. all axes with equal names are expected to have equal sizes in all arguments of a broadcasting operation, while the result has the (set) union of all named axes. The positional semantics of the program remain unchanged, and broadcasting still implicitly right-aligns positional axes for unification. For an extended description of the xmap() programming model, please refer to the xmap() tutorial notebook in the main JAX documentation.
Note that since all top-level JAX expressions are interpreted in the NumPy programming model, xmap() can also be seen as an adapter that converts a function that uses named axes (including in arguments and returned values) into one that takes and returns values that only have positional axes.
The default lowering strategy of xmap() converts all named axes into positional axes, working similarly to multiple applications of vmap(). However, this behavior can be further customized by the axis_resources argument. When specified, each axis introduced by xmap() can be assigned to one or more resource axes. Those include the axes of the hardware mesh, as defined by the mesh() context manager. Each value that has a named axis in its named_shape will be partitioned over all mesh axes that axis is assigned to. Hence, xmap() can be seen as an alternative to pmap() that also exposes a way to automatically partition the computation over multiple devices.
Warning
While it is possible to assign multiple axis names to a single resource axis, care has to be taken to ensure that none of those named axes co-occur in the named_shape of any value in the named program. At the moment this is completely unchecked and will result in undefined behavior. The final release of xmap() will enforce this invariant, but it is a work in progress.
Note that you do not have to worry about any of this as long as no resource axis is repeated in axis_resources.values().
Note that any assignment of axis_resources never changes the results of the computation, only how it is carried out (e.g. how many devices are used). This makes it easy to try out various ways of partitioning a single program in many distributed scenarios (both small- and large-scale) to maximize performance. As such, xmap() can be seen as a way to seamlessly interpolate between vmap()- and pmap()-style execution.
- Parameters
fun (Callable) – Function that uses named axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof (in general: valid pytrees).
in_axes – A Python object with the same container (pytree) structure as the signature of arguments to fun, but with a positional-to-named axis mapping in place of every array argument. The valid positional-to-named mappings are: (1) a Dict[int, AxisName] specifying that the positional dimensions given by the dictionary keys are to be converted to named axes of the given names; (2) a list of axis names that ends with the Ellipsis object (...), in which case that many leading positional axes of the argument (one per listed name) will be converted into named axes inside the function. Note that in_axes can also be a prefix of the argument container structure, in which case the mapping is repeated for all arrays in the collapsed subtree.
out_axes – A Python object with the same container (pytree) structure as the returns of fun, but with a positional-to-named axis mapping in place of every returned array. The valid positional-to-named mappings are the same as in in_axes. Note that out_axes can also be a prefix of the return container structure, in which case the mapping is repeated for all arrays in the collapsed subtree.
axis_sizes (Dict[Hashable, int]) – A dict mapping axis names to their sizes. All axes defined by xmap have to appear either in in_axes or in axis_sizes. Sizes of axes that appear in in_axes are inferred from arguments whenever possible. In multi-host scenarios, the user-specified sizes are expected to be the global axis sizes (and might not match the expected size of local inputs).
axis_resources (Dict[Hashable, Union[Hashable, SerialLoop, Tuple[Union[Hashable, SerialLoop], ...]]]) – A dictionary mapping the axes introduced in this xmap() to one or more resource axes. Any array that has in its shape an axis with some resources assigned will be partitioned over the resources associated with the respective resource axes.
backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
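The two accepted forms of a positional-to-named mapping, and the size inference mentioned under axis_sizes, can be sketched in plain Python. Both helpers below (to_named_dims, infer_axis_sizes) are hypothetical and only model the documented rules; they ignore the pytree-prefix form of in_axes, and treating a size mismatch for the same name as an error is an assumption of this sketch rather than a quote of xmap's behavior.

```python
from collections.abc import Hashable, Sequence
from typing import Dict, Union

AxisSpec = Union[Dict[int, Hashable], Sequence]


def to_named_dims(spec: AxisSpec) -> Dict[int, Hashable]:
    """Hypothetical helper: normalize either accepted mapping form into a
    {positional dimension: axis name} dict."""
    if isinstance(spec, dict):
        return dict(spec)
    names = list(spec)
    # The list form must end with Ellipsis; the names cover leading dimensions.
    if not names or names[-1] is not Ellipsis:
        raise ValueError("list form must end with ...")
    return dict(enumerate(names[:-1]))


def infer_axis_sizes(shapes, specs) -> Dict[Hashable, int]:
    """Hypothetical helper: infer named-axis sizes from argument shapes,
    requiring every use of a name to agree on its size."""
    sizes: Dict[Hashable, int] = {}
    for shape, spec in zip(shapes, specs):
        for dim, name in to_named_dims(spec).items():
            size = shape[dim]
            if sizes.setdefault(name, size) != size:
                raise ValueError(f"inconsistent sizes for axis {name!r}")
    return sizes


# Shapes from the vdot example: x is (2, 5) and x.T is (5, 2).
print(infer_axis_sizes([(2, 5), (5, 2)], [{0: "left"}, {1: "right"}]))
# {'left': 2, 'right': 2}
print(to_named_dims(["batch", "in_features", ...]))
# {0: 'batch', 1: 'in_features'}
```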
- Returns
A version of fun that takes in arrays with positional axes in place of named axes bound in this xmap() call, and returns results with all named axes converted to positional axes. If axis_resources is specified, fun can additionally execute in parallel on multiple devices.
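The conversion of a result's named axes back into positional ones can be modeled as a shape calculation. In the hypothetical helper output_shape below, a dict-form out_axes inserts each named axis at its key, while the list form places the named axes in front of the remaining positional axes. Rejecting an out_axes that omits a named axis of the result is an assumption of this sketch, based on the statement that all named axes are converted.

```python
from collections.abc import Hashable, Sequence
from typing import Dict, Tuple, Union


def output_shape(named_shape: Dict[Hashable, int],
                 positional_shape: Tuple[int, ...],
                 out_spec: Union[Dict[int, Hashable], Sequence]) -> Tuple[int, ...]:
    """Hypothetical helper: positional shape of a result once its named axes
    are turned back into positional axes according to out_axes."""
    if isinstance(out_spec, dict):
        placement = dict(out_spec)
    else:
        names = list(out_spec)
        if not names or names[-1] is not Ellipsis:
            raise ValueError("list form must end with ...")
        placement = dict(enumerate(names[:-1]))
    # Assumed rule: every named axis of the result needs a position.
    if set(placement.values()) != set(named_shape):
        raise ValueError("out_axes must mention every named axis of the result")
    shape = list(positional_shape)
    # Insert in increasing order so each key indexes into the final shape.
    for dim in sorted(placement):
        shape.insert(dim, named_shape[placement[dim]])
    return tuple(shape)


# vdot yields a scalar per ('left', 'right') pair, so the result is (2, 2).
print(output_shape({"left": 2, "right": 2}, (), ["left", "right", ...]))
# A named axis placed after one remaining positional axis.
print(output_shape({"batch": 4}, (3,), {1: "batch"}))
```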
For example, xmap() makes it very easy to convert a function that computes the vector inner product (such as jax.numpy.vdot()) into one that computes a matrix multiplication:

    >>> import jax.numpy as jnp
    >>> from jax.experimental.maps import xmap
    >>> x = jnp.arange(10).reshape((2, 5))
    >>> xmap(jnp.vdot,
    ...      in_axes=({0: 'left'}, {1: 'right'}),
    ...      out_axes=['left', 'right', ...])(x, x.T)
    DeviceArray([[ 30,  80],
                 [ 80, 255]], dtype=int32)
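Binding 'left' to the rows of x and 'right' to the columns of x.T means vdot runs once per ('left', 'right') pair, which is exactly a matrix product. The NumPy sketch below involves no xmap at all; it simply spells out that batching as explicit loops and checks the result against x @ x.T, reproducing the values in the DeviceArray above.

```python
import numpy as np

x = np.arange(10).reshape((2, 5))
xt = x.T

# One vdot per (left, right) pair: 'left' indexes rows of x, 'right' columns of x.T.
out = np.empty((x.shape[0], xt.shape[1]), dtype=x.dtype)
for left in range(x.shape[0]):
    for right in range(xt.shape[1]):
        out[left, right] = np.vdot(x[left, :], xt[:, right])

print(out)  # same values as the DeviceArray above
assert np.array_equal(out, x @ xt)
```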
Note that the contraction in the program is performed over the positional axes, while named axes are just a convenient way to achieve batching. While this might seem like a silly example at first, it can be useful in practice, since in conjunction with axis_resources it makes it possible to implement a distributed matrix multiplication in just a few lines of code:

    import numpy as np
    import jax
    from jax.experimental.maps import mesh

    # Reuses x, jnp and xmap from the example above; requires at least 4 devices.
    devices = np.array(jax.devices())[:4].reshape((2, 2))
    with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'
      distributed_out = xmap(
        jnp.vdot,
        in_axes=({0: 'left'}, {1: 'right'}),
        out_axes=['left', 'right', ...],
        axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
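To picture what axis_resources={'left': 'x', 'right': 'y'} asks for, the sketch below simulates the 2x2 mesh in NumPy. It assumes each named axis is split into equal contiguous blocks along its mesh axis, so simulated device (i, j) receives block i of the rows of x and block j of the columns of x.T and computes one output block. That block layout and the helper name blocks are illustrative assumptions; the property checked at the end, that partitioning leaves the result unchanged, is the one the documentation states.

```python
import numpy as np

x = np.arange(10).reshape((2, 5))
xt = x.T
mesh_shape = {"x": 2, "y": 2}  # stands in for the 2x2 device mesh


def blocks(a: np.ndarray, axis: int, n: int):
    """Hypothetical helper: split `a` into n equal contiguous blocks along `axis`."""
    if a.shape[axis] % n:
        raise ValueError("axis size must divide evenly over the mesh axis")
    return np.split(a, n, axis=axis)


left_blocks = blocks(x, 0, mesh_shape["x"])    # 'left'  -> mesh axis 'x'
right_blocks = blocks(xt, 1, mesh_shape["y"])  # 'right' -> mesh axis 'y'

# Each simulated device (i, j) computes its output block independently.
out_blocks = [[lb @ rb for rb in right_blocks] for lb in left_blocks]
out = np.block(out_blocks)

assert np.array_equal(out, x @ xt)  # partitioning does not change the result
```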
Still, the above examples are quite simple. After all, the xmapped computation was a simple NumPy function that didn't use the axis names at all! So, let's explore a slightly larger example, linear regression:

    def regression_loss(x, y, w, b):
      # Contract over in_features. Batch and out_features are present in
      # both inputs and output, so they don't need to be mentioned
      y_pred = jnp.einsum('{in_features},{in_features}->{}', x, w) + b
      error = jnp.sum((y - y_pred) ** 2, axis='out_features')
      return jnp.mean(error, axis='batch')

    xmap(regression_loss,
         in_axes=(['batch', 'in_features', ...],
                  ['batch', 'out_features', ...],
                  ['in_features', 'out_features', ...],
                  ['out_features', ...]),
         out_axes={})  # Loss is reduced over all axes, including batch!
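For comparison, here is the same loss written with positional axes only, in NumPy. The axis order assumed for each argument follows the in_axes above (x: batch by in_features, y: batch by out_features, w: in_features by out_features, b: out_features); the function name regression_loss_positional and the random data are illustrative only.

```python
import numpy as np


def regression_loss_positional(x, y, w, b):
    # x: (batch, in), y: (batch, out), w: (in, out), b: (out,)
    y_pred = np.einsum("bi,io->bo", x, w) + b  # contract over in_features
    error = np.sum((y - y_pred) ** 2, axis=1)  # reduce out_features
    return np.mean(error, axis=0)              # reduce batch -> scalar


rng = np.random.default_rng(0)
batch, n_in, n_out = 8, 3, 2
x = rng.normal(size=(batch, n_in))
w = rng.normal(size=(n_in, n_out))
b = rng.normal(size=(n_out,))
y = x @ w + b  # targets produced by the same model, so the loss is close to 0

print(regression_loss_positional(x, y, w, b))
```

Writing the named program keeps the reductions tied to axis names ('out_features', 'batch') instead of the positional indices 1 and 0 used here.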
Note
When using axis_resources along with a mesh that is controlled by multiple JAX hosts, keep in mind that in any given process xmap() only expects the data slice that corresponds to its local devices to be specified. This is in line with the current multi-host pmap() programming model.
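The multi-host rule, together with the note under axis_sizes that user-specified sizes are global, can be illustrated with a small calculation. Assuming a named axis of global size N is split into equal contiguous blocks across M devices, and that a process owns a contiguous range of those devices, the process supplies only the matching slice of the global array. The helper local_slice and its arguments are hypothetical and model only that arithmetic; they are not part of jax.experimental.maps.

```python
def local_slice(global_size: int, num_devices: int, local_device_ids) -> slice:
    """Hypothetical helper: the slice of a globally sized axis that a process
    provides, given the positions of its devices along the mesh axis."""
    if global_size % num_devices:
        raise ValueError("global axis size must be divisible by the device count")
    ids = sorted(local_device_ids)
    # This sketch assumes a process's devices are adjacent along the mesh axis.
    if not ids or ids != list(range(ids[0], ids[-1] + 1)):
        raise ValueError("expected a non-empty contiguous range of devices")
    block = global_size // num_devices
    return slice(ids[0] * block, (ids[-1] + 1) * block)


# Two hosts with 4 devices each along one mesh axis; global batch size of 64.
print(local_slice(64, 8, [0, 1, 2, 3]))  # slice(0, 32, None)
print(local_slice(64, 8, [4, 5, 6, 7]))  # slice(32, 64, None)
```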