# SPMD multi-device parallelism with `shard_map`

`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or *instances*, communicate with each other via explicit collective communication operations.

`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and the compiler can automatically partition computation over multiple devices, generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.

If you’re familiar with `pmap`, think of `shard_map` as an evolution. It’s more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see a detailed comparison to `pmap`.)

By reading this tutorial, you’ll learn how to use `shard_map` to get full control over your multi-device code. You’ll see in detail how it composes with `jax.jit`’s automatic parallelization and `jax.grad`’s automatic differentiation. We’ll also give some basic examples of neural network parallelization strategies.

We’ll assume this tutorial is being run in an environment with eight devices:

```
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
```

## So, let’s see a `shard_map`!

Without further ado, here’s a toy example:

```
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
```

```
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)

@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 4]
  c_partialsum = jnp.dot(a_block, b_block)
  c_block = jax.lax.psum(c_partialsum, 'y')
  # c_block: f32[2, 4]
  return c_block

c = matmul_basic(a, b)  # c: f32[8, 4]
```

This function computes a matrix multiply in parallel by performing local block matrix multiplies followed by a collective sum operation. We can check the result is correct:

```
from jax.tree_util import tree_map, tree_all

def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

allclose(c, jnp.dot(a, b))
```

```
True
```
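To see why the `psum` over `'y'` is needed, here is a minimal plain-Python sketch (no JAX; the helpers `matmul` and `add` and the small matrices are made up for illustration). Splitting the contraction axis into two halves yields two partial products, and only their sum equals the full product, which is the role the collective plays in `matmul_basic`.

```python
def matmul(a, b):
  """Naive matrix product of two matrices given as lists of rows."""
  return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
           for j in range(len(b[0]))]
          for i in range(len(a))]

def add(p, q):
  """Elementwise sum of two equally shaped matrices."""
  return [[x + y for x, y in zip(rp, rq)] for rp, rq in zip(p, q)]

a = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
b = [[1, 0],
     [0, 1],
     [2, 1],
     [1, 3]]

# Split the contraction axis (columns of a, rows of b) into two halves,
# mimicking how the 'y' mesh axis splits the inputs of matmul_basic.
a_halves = [[row[:2] for row in a], [row[2:] for row in a]]
b_halves = [b[:2], b[2:]]

# Each hypothetical "device" along 'y' computes a partial product locally.
partials = [matmul(a_blk, b_blk) for a_blk, b_blk in zip(a_halves, b_halves)]

# Summing the partial products stands in for psum over 'y'.
summed = add(partials[0], partials[1])
full = matmul(a, b)
print(summed == full)
```

Neither partial product alone matches the full result; the sum does.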

The result is sharded along its rows:

```
jax.debug.visualize_array_sharding(c)
```

(Visualization output: `c` is split into four row blocks, held by devices CPU 0,1; CPU 2,3; CPU 4,5; and CPU 6,7.)

At a high level, `shard_map` is kind of like `vmap` or `pmap`, in that we’re mapping a function over pieces of array data, but notice that

- `shard_map` slices up inputs into blocks (and the output is formed by concatenating result blocks), keeping the rank the same, whereas `vmap` would reduce the rank by mapping away an axis;
- the `mesh` argument lets us control precise device placement of computation and results;
- we’re mapping over multiple data axes at once, and setting up multiple axis names for collectives (both `'x'` and `'y'` here);
- since we’re not using `jax.jit` yet, everything is eagerly evaluated, and we can even `print` intermediate values for debugging.

The above code is performing the same computation as this `jax.jit` automatic parallelization code:

```
from jax.sharding import NamedSharding

a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))

@jax.jit
def matmul_reference(a, b):
  c = jnp.dot(a, b)
  return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))

c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))
```

```
True
```

We can think of `shard_map` as performing a `device_put` or `with_sharding_constraint` on its inputs according to its `mesh` and `in_specs` arguments, so the blocks over which `matmul_basic` operates are the same as in `matmul_reference`:

```
print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
```

```
a blocks:
CPU 0        CPU 1
CPU 2        CPU 3
CPU 4        CPU 5
CPU 6        CPU 7
b blocks:
CPU 0,2,4,6
CPU 1,3,5,7
c blocks:
CPU 0,1
CPU 2,3
CPU 4,5
CPU 6,7
```

## Slow down, start with the basics!

### Rank-reducing vs rank-preserving maps

We can think of `vmap` and `pmap` as unstacking each array input along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body function to each piece, and stacking the results back together, at least when collectives aren’t involved:

```
def check_vmap(f, xs):
  ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
  expected = jnp.stack([f(x) for x in xs])  # vmap reference semantics
  print(allclose(ans, expected))

check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
```

```
True
```

For example, if `xs` had shape `f32[8,5]` then each `x` would have shape `f32[5]`, and if each `f(x)` had shape `f32[3,7]` then the final stacked result `vmap(f)(xs)` would have shape `f32[8,3,7]`. That is, each application of the body function `f` takes as argument inputs with one fewer axis than the corresponding argument to `vmap(f)`. We can say these are *rank-reducing maps* with unstacking/stacking of inputs/outputs.

The number of logical applications of `f`, or *instances* of `f`, is determined by the size of the input axis being mapped over: for example, if we map over an input axis of size 8, semantically we get 8 logical applications of the function.

In contrast, `shard_map` does not have this rank-reducing behavior. Instead, we can think of it as slicing (or “unconcatenating”) along input axes into blocks, applying the body function, and concatenating the results back together (again when collectives aren’t involved):

```
import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

def check_shmap(f, y):
  ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
  expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
  print(allclose(ans, expected))

check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
```

```
True
```

Recall that `jnp.split` slices its input into equally-sized blocks with the same rank, so that if in the above example `y` had shape `f32[8,5]` then each `y_blk` would have shape `f32[2,5]`, and if each `f(y_blk)` had shape `f32[3,7]` then the final concatenated result `shard_map(f, ...)(y)` would have shape `f32[12,7]`. So `shard_map` maps over *shards*, or blocks, of its inputs. We can say it’s a *rank-preserving map* with unconcatenating/concatenating of its inputs/outputs.

The number of logical applications of `f` is determined by the mesh size, not by any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4 devices) then semantically we get 4 logical applications of the function, corresponding to the 4 devices physically computing them.
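The shape arithmetic of the two kinds of map can be sketched in a few lines of plain Python. The helper names `stacked_shape` and `concatenated_shape` are hypothetical and only model result shapes for a map over a single leading axis; they are not part of JAX.

```python
def stacked_shape(num_instances, out_shape):
  """Rank-reducing map (vmap-style): per-instance outputs are stacked
  along a new leading axis, so the result gains one axis."""
  return (num_instances, *out_shape)

def concatenated_shape(num_blocks, out_block_shape):
  """Rank-preserving map (shard_map-style): output blocks are
  concatenated along the existing leading axis."""
  return (num_blocks * out_block_shape[0], *out_block_shape[1:])

# xs of shape (8, 5) mapped over its leading axis gives 8 instances.
print(stacked_shape(8, (3, 7)))
# A mesh of size 4 gives 4 blocks, whatever the input axis size is.
print(concatenated_shape(4, (3, 7)))
```

The first call models the `vmap` example in the text and the second models the `shard_map` example over a mesh of size 4.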

### Controlling how each input is split (unconcatenated) and tiled with `in_specs`

Each of the `in_specs` identifies some of the corresponding input array’s axes with mesh axes by name using `PartitionSpec`s, representing how to split (or unconcatenate) that input into the blocks to which the body function is applied. That identification determines the shard sizes; when an input axis is identified with a mesh axis, the input is split (unconcatenated) along that logical axis into a number of pieces equal to the corresponding mesh axis size. (It’s an error if the corresponding mesh axis size does not evenly divide the input array axis size.) If an input’s pspec does not mention a mesh axis name, then there’s no splitting over that mesh axis. For example:

```
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('i', 'j'))

@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)  # prints (3, 12)
  return x_block

x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1)
```

```
(3, 12)
```

Here, because the input pspec did not mention the mesh axis name `'j'`, no input array axis is split over that mesh axis; similarly, because the second axis of the input array is not identified with (and hence split over) any mesh axis, application of `f1` gets a full view of the input along that axis.

When a mesh axis is not mentioned in an input pspec, we can always rewrite to a less efficient program where all mesh axes are mentioned but the caller performs a `jnp.tile`, for example:

```
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)
```

```
(3, 12)
```

In other words, because each input pspec can mention each mesh axis name zero or one times, rather than having to mention each name exactly once, we can say that in addition to the `jnp.split` built into its input, `shard_map` also has a `jnp.tile` built into its input, at least logically (though the tiling may not need to be carried out physically, depending on the arguments’ physical sharding layout). The tiling to use is not unique; we could also have tiled along the first axis, and used the pspec `P(('j', 'i'), None)`.

Physical data movement is possible on inputs, as each device needs to have a copy of the appropriate data.

### Controlling how each output is assembled by concatenation, block transposition, and untiling using `out_specs`

Analogously to the input side, each of the `out_specs` identifies some of the corresponding output array’s axes with mesh axes by name, representing how the output blocks (one for each application of the body function, or equivalently one for each physical device) should be assembled back together to form the final output value. For example, in both the `f1` and `f2` examples above the `out_specs` indicate we should form the final output by concatenating together the block results along both axes, resulting in both cases in an array `y` of shape `(12, 24)`. (It’s an error if an output shape of the body function, i.e. an output block shape, has a rank too small for the concatenation described by the corresponding output pspec.)

When a mesh axis name is not mentioned in an output pspec, it represents an un-tiling: when the user writes an output pspec which does not mention one of the mesh axis names, they promise that the output blocks are equal along that mesh axis, and so only one block along that axis is used in the output (rather than concatenating all the blocks together along that mesh axis). For example, using the same mesh as above:

```
x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z) # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
```

```
[[3. 3.]
[3. 3.]
[3. 3.]
[3. 3.]]
[[3.]
[3.]
[3.]
[3.]]
[[3.]]
```

The body function closing over an array value is equivalent to passing it as an argument with a corresponding input pspec of `P(None, None)`. As another example, following more closely the other examples above:

```
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)
```

```
(12, 6)
```

The result has a second axis size of 6, half the size of the input’s second axis. In this case, the un-tile expressed by not mentioning the mesh axis name `'j'` in the output pspec was safe because of the collective `psum`, which ensures each output block is equal along the corresponding mesh axis. Here are two more examples where we vary which mesh axes are mentioned in the output pspec:

```
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)
```

```
(3, 12)
(3, 6)
```

On the physical side, not mentioning a mesh axis name in an output pspec assembles an `Array` from the output device buffers with replicated layout along that mesh axis.

There is no runtime check that the output blocks are actually equal along a mesh axis to be un-tiled along, or equivalently that the corresponding physical buffers have equal values and thus can be interpreted as a replicated layout for a single logical array. But we can provide a static check mechanism which raises an error on all potentially-incorrect programs.

Because the `out_specs` can mention mesh axis names zero or one times, and because they can be mentioned in any order, we can say that in addition to the `jnp.concatenate` built into its output, `shard_map` also has both an *untile* and a *block transpose* built into its output.

Physical data movement is not possible on outputs, no matter the output pspec. Instead, `out_specs` just encodes how to assemble the block outputs into `Array`s, or physically how to interpret the buffers across devices as the physical layout of a single logical `Array`.
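As a rough plain-Python model of the assembly rule, the sketch below computes the shape of the assembled output from one output block's shape. The function `output_shape`, the tuple stand-in for a `PartitionSpec`, and the dict stand-in for `mesh.shape` are assumptions made for illustration only, not JAX APIs.

```python
def output_shape(block_shape, out_spec, mesh_shape):
  """Shape of the assembled output, given one output block's shape.

  out_spec stands in for a PartitionSpec: a tuple whose entries are None,
  a mesh axis name, or a tuple of mesh axis names. mesh_shape maps each
  mesh axis name to its size.
  """
  if len(out_spec) > len(block_shape):
    raise ValueError("output block rank is too small for this out_spec")
  # Array axes beyond the end of the spec are treated as unmentioned.
  spec = tuple(out_spec) + (None,) * (len(block_shape) - len(out_spec))
  shape = []
  for size, names in zip(block_shape, spec):
    if names is None:
      names = ()
    elif isinstance(names, str):
      names = (names,)
    factor = 1
    for name in names:
      factor *= mesh_shape[name]  # concatenate one block per mesh index
    shape.append(size * factor)
  return tuple(shape)

mesh_shape = {'i': 4, 'j': 2}
print(output_shape((3, 6), ('i', None), mesh_shape))   # models f3
print(output_shape((3, 6), (None, 'j'), mesh_shape))   # models f4
print(output_shape((3, 6), (None, None), mesh_shape))  # models f5
```

An unmentioned mesh axis contributes a factor of 1, which is the shape-level effect of the un-tiling described above. The sketch says nothing about whether the blocks really are equal along that axis.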

# API Specification

```
from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(
    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
    auto: collections.abc.Set[AxisName] = frozenset([]),
    check_rep: bool = True,
) -> Callable:
  ...
```

where:

- communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`;
- `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;
- `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;
- `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;
- `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see JEP).

The shapes of the arguments passed to `f` have the same ranks as the arguments passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and the corresponding `PartitionSpec` `spec` as roughly `tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.
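The "roughly" in that formula hides two details: a spec entry may name several mesh axes at once, and a spec may be shorter than the array's rank. A minimal plain-Python sketch that handles both is shown below. The name `block_shape`, the tuple stand-in for a `PartitionSpec`, and the dict stand-in for `mesh.shape` are hypothetical.

```python
def block_shape(shape, spec, mesh_shape):
  """Shape of the block passed to the body function.

  spec stands in for a PartitionSpec: a tuple whose entries are None,
  a mesh axis name, or a tuple of mesh axis names.
  """
  # Array axes beyond the end of the spec are not split.
  spec = tuple(spec) + (None,) * (len(shape) - len(spec))
  out = []
  for size, names in zip(shape, spec):
    if names is None:
      names = ()
    elif isinstance(names, str):
      names = (names,)
    factor = 1
    for name in names:
      factor *= mesh_shape[name]
    if size % factor != 0:
      raise ValueError(f"axis of size {size} is not divisible by {factor}")
    out.append(size // factor)
  return tuple(out)

mesh_shape = {'i': 4, 'j': 2}
print(block_shape((12, 12), ('i', None), mesh_shape))  # models f1's input
print(block_shape((12, 24), ('i', 'j'), mesh_shape))   # models f2's input
```

The divisibility check mirrors the rule stated earlier that the mesh axis size must evenly divide the input axis size.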

# Collectives tutorial

A `shard_map` need not be a pure map: function applications can communicate with each other via *collectives*, using axis names defined in the `mesh` argument.

Recall that `shard_map` maps a function over shards, or blocks, of input data, so that this:

```
mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)
```

computes the same values, evaluating applications of `f` to the same argument values, as this reference function:

```
def f_shmapped_ref(x):
  # mesh.shape is keyed by axis name, so index it with 'i' rather than 0
  x_blocks = jnp.array_split(x, mesh.shape['i'])
  y_blocks = [f(x_blk) for x_blk in x_blocks]
  return jnp.concatenate(y_blocks)
```

We call these applications of `f` to different argument shards *function instances*. Each function instance is executed on a different device (or subset of devices).

These reference semantics work when `f` has no communication collectives in it. But what if we want the function instances to communicate, corresponding to having cross-device communication? That is, what are the reference semantics when `f` contains a collective? Say `f` has just one collective, and is of the form

```
def f(x_blk):
  z_blk = f_part1(x_blk)
  u_blk = collective(z_blk, axis_name)
  v_blk = f_part2(x_blk, z_blk, u_blk)
  return v_blk
```

where we’re assuming there’s only one mesh axis we’re mapping over, and `axis_name` is the corresponding name for it. Then the reference semantics would look more like:

```
def f_shmapped_ref(x):
  # mesh.shape is keyed by axis name, so index it with axis_name rather than 0
  x_blocks = jnp.array_split(x, mesh.shape[axis_name])
  z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
  u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
  v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
              in zip(x_blocks, z_blocks, u_blocks)]
  return jnp.concatenate(v_blocks)
```

Notice that `collective_ref` might depend on all the `z_blocks`. That is, while `f_part1` and `f_part2` are mapped over blocks independently, a collective introduces some amount of cross-block dependence. Physically, that means communication across devices. Exactly what communication happens, and what values are computed, depend on the collective.

## `psum`

The simplest collective may be `jax.lax.psum`, which computes an all-reduce-sum along a device mesh axis (or multiple axes). Here’s a toy example:

```
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
```

```
mesh1d = Mesh(jax.devices()[:4], ('i',))

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block
```

```
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
```

```
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]
FINAL RESULT:
[22 20 12 17]
```

The prints show that each function application starts with its own chunk of the argument value `x_block`. After the `psum`, each function application has the same value of `y_block`, computed by summing the applications’ `x_block` values together.

In the case where there’s a single axis name in the computation, we could say that the `collective_ref` reference implementation for `psum` is

```
def psum_ref(_, x_blocks):
  tot = sum(x_blocks)
  return [tot] * len(x_blocks)
```

Notice also that because `f1` returns `y_block`, the result of a `psum` over `'i'`, we can use `out_specs=P()` so the caller gets a single logical copy of the result value, rather than a tiled result.

When there is more than one mesh axis, we can perform a `psum` over each one separately, or over multiple axes at once:

```
mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))

@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block

y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
```

```
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
[20 22]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
[16 18]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
[20 22]]
FINAL RESULT:
[[ 8 10 12 14]
[16 18 20 22]]
```

By applying a `psum` over mesh axis `'i'`, we get values of `y_block` which are equal along axis `'i'`, but not axis `'j'`. (So we can use `out_specs=P(None, 'j')` to get a single logical result along that axis.)

If we apply the `psum` over both axes, the `y_block` value is equal along both axes:

```
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, ('i', 'j'))
  print('AFTER:\n', y_block)
  return y_block

y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
```

```
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
[4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
[6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 9]
[12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
[14 15]]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
[36 40]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
[36 40]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
[36 40]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
[36 40]]
FINAL RESULT:
[[20 24]
[36 40]]
```

In machine learning, we often use `psum` to compute total losses or, when we have a `grad` inside the `shard_map`ped function body, total gradients.

In the sequel, we’ll see how `psum` can be implemented in terms of other primitives, which gives some intuition about its communication cost.
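The single-axis `psum_ref` above does not say which instances are summed together when the mesh has two axes. A minimal plain-Python sketch of that grouping rule follows. The names `add_blocks` and `psum_grid` are hypothetical, blocks are plain nested lists rather than JAX arrays, and the input grid is the 2x2 mesh example from this section.

```python
def add_blocks(blocks):
  """Elementwise sum of equally shaped 2D blocks (lists of rows)."""
  return [[sum(vals) for vals in zip(*rows)] for rows in zip(*blocks)]

def psum_grid(grid, axes):
  """Reference psum over a 2D mesh with axes named 'i' and 'j'.

  grid[i][j] is the block held at mesh coordinates (i, j); axes is the
  set of mesh axis names to sum over.
  """
  n_i, n_j = len(grid), len(grid[0])
  result = [[None] * n_j for _ in range(n_i)]
  for i in range(n_i):
    for j in range(n_j):
      # Instances are summed together when they agree on every mesh
      # coordinate that is NOT being reduced.
      group = [grid[a][b]
               for a in range(n_i) for b in range(n_j)
               if ('i' in axes or a == i) and ('j' in axes or b == j)]
      result[i][j] = add_blocks(group)
  return result

# Blocks of jnp.arange(16).reshape(4, 4) split over a 2x2 mesh.
grid = [[[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
        [[[8, 9], [12, 13]], [[10, 11], [14, 15]]]]

over_i = psum_grid(grid, {'i'})
over_both = psum_grid(grid, {'i', 'j'})
print(over_i[0][0], over_i[0][1])
print(over_both[0][0])
```

After reducing over `'i'` only, blocks still differ along `'j'`, which is why the corresponding example needs `'j'` in its `out_specs`.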

## `all_gather`

Another fundamental operation is gathering array shards along an axis, so that each function application has a full copy of the data along that axis:

```
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
```

```
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]
FINAL RESULT:
[3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]
```

The prints show that each function application again starts with its own chunk of the argument value `x_block`. After the `all_gather`, they have a common value, computed by concatenating the values of `x_block`.

(Notice that we actually can’t set `out_specs=P()` here. For technical reasons related to automatic differentiation, we consider the output of `all_gather` not to be guaranteed invariant across devices. If we wanted it to be guaranteed invariant, we could use `jax.lax.all_gather_invariant`, or in this case we could just avoid doing the `all_gather` in the function body and instead just use `out_specs=P('i')` to perform the concatenation.)

When `tiled=False` (the default), results are stacked along a new axis instead of concatenated:

```
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
  print('AFTER:\n', y_block)
  return y_block

y = f5(x)
print('FINAL RESULT:\n', y)
```

```
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
[9]
[5]
[2]]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
[9]
[5]
[2]]
FINAL RESULT:
[[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]
[3]
[9]
[5]
[2]]
```

We could write the `collective_ref` reference semantics function for `all_gather` as

```
def all_gather_ref(_, x_blocks, *, tiled=False):
  combine = jnp.concatenate if tiled else jnp.stack
  return [combine(x_blocks)] * len(x_blocks)
```

In deep learning, we might use `all_gather`s on parameters in fully sharded data parallelism (FSDP).

## `psum_scatter`

The `jax.lax.psum_scatter` collective is a bit less intuitive. It’s like `psum` except each function instance gets only one shard of the result:

```
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)
```

```
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
```

As shown by the prints, each resulting `y_block` has a smaller size than the argument `x_block`, unlike with `psum`. Moreover, compared to `psum`, here each `y_block` only represents a slice of the sum of the `x_block`s across function instances. (Even though each function instance gets only one shard of the sum, the final output `y` is the same as in the `psum` example because here we use `out_specs=P('i')` to concatenate each function instance’s output.)

In terms of what values are computed, a `collective_ref` reference implementation might look like:

```
def psum_scatter_ref(i, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  tot = sum(x_blocks)
  if tiled:
    tot = tot.reshape(axis_size, -1, *tot.shape[1:])  # split leading axis
  return [tot[i] for i in range(tot.shape[0])]
```

It’s not captured in the semantics reference implementation, but `psum_scatter` is useful because these results can be computed more efficiently, with less communication, than a full `psum`. In fact, one way to think of `psum_scatter` is as “the first half of a `psum`, before an `all_gather`”. That is, one way to implement `psum` is:

```
def psum(x, axis_name):
  summed_chunk = jax.lax.psum_scatter(x, axis_name)
  return jax.lax.all_gather(summed_chunk, axis_name)
```

Indeed, this implementation is often used on both TPU and GPU!
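The identity "`psum` equals `psum_scatter` followed by `all_gather`" can be checked on plain Python lists. This is only a sketch of the values computed, using the tiled variants of both collectives and hypothetical helper names; it says nothing about how devices actually communicate.

```python
def psum_scatter_tiled(x_blocks):
  """Instance i receives the i-th chunk of the elementwise sum."""
  n = len(x_blocks)
  total = [sum(vals) for vals in zip(*x_blocks)]
  chunk = len(total) // n
  return [total[i * chunk:(i + 1) * chunk] for i in range(n)]

def all_gather_tiled(x_blocks):
  """Every instance receives the concatenation of all blocks."""
  gathered = [v for blk in x_blocks for v in blk]
  return [list(gathered) for _ in x_blocks]

def psum_lists(x_blocks):
  """Every instance receives the full elementwise sum."""
  total = [sum(vals) for vals in zip(*x_blocks)]
  return [list(total) for _ in x_blocks]

# One block per function instance, as in the printed examples.
x_blocks = [[3, 1, 4, 1], [5, 9, 2, 6], [5, 3, 5, 8], [9, 7, 1, 2]]

scattered = psum_scatter_tiled(x_blocks)
composed = all_gather_tiled(scattered)
print(scattered)
print(composed == psum_lists(x_blocks))
```

The intermediate `scattered` value matches the `AFTER` prints of the `psum_scatter` example above.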

The reason `psum_scatter` can require about half the communication of a full `psum` is illustrated in the `ppermute` section.

Another intuition is that we can use `psum_scatter` to implement a distributed matrix multiplication with inputs and outputs sharded over the same axis. In machine learning, `psum_scatter` can be used in tensor-parallel matrix multiplies or fully-sharded data parallel gradient accumulation, as shown in the examples to follow.

## `ppermute`

The `jax.lax.ppermute` collective provides the most direct way for function instances to send data to one another. Given a mesh axis and a list of `(source_index, destination_index)` pairs representing indices along that mesh axis, `ppermute` sends its argument value from each source function instance to each destination:

```
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
  sz = jax.lax.psum(1, 'i')
  print('BEFORE:\n', x_block)
  y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
  print('AFTER:\n', y_block)
  return y_block

y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)
```

```
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]
FINAL RESULT:
[6 7 0 1 2 3 4 5]
```

In this case, with four function instances arranged in a ring, each instance’s value of `y_block` is its predecessor’s value of `x_block`.

Source indices and destination indices can’t be repeated. If an index does not appear as a destination, then the value of the corresponding function instance’s result is an array of zeros.

A `collective_ref` reference implementation could look like

```
def ppermute_ref(i, x_blocks, perm):
  results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
  for src, dst in perm:
    results[dst] = x_blocks[src]
  return results
```
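To make the zero-fill rule concrete, here is a plain-list analogue of the reference semantics. The name `ppermute_lists` is hypothetical, and the uniqueness check is an assumption about how the "can't be repeated" rule might be enforced in a sketch, not a description of JAX's own error handling.

```python
def ppermute_lists(x_blocks, perm):
  """Plain-list analogue of the reference semantics:
  results[dst] = x_blocks[src], and zeros where nothing arrives."""
  srcs = [src for src, _ in perm]
  dsts = [dst for _, dst in perm]
  if len(set(srcs)) != len(srcs) or len(set(dsts)) != len(dsts):
    raise ValueError("source and destination indices can't be repeated")
  results = [[0] * len(x_blocks[0]) for _ in x_blocks]
  for src, dst in perm:
    results[dst] = list(x_blocks[src])
  return results

x_blocks = [[0, 1], [2, 3], [4, 5], [6, 7]]
size = len(x_blocks)

# Full ring shift, matching the permutation used in f7.
ring = [(i, (i + 1) % size) for i in range(size)]
rotated = ppermute_lists(x_blocks, ring)

# Only instance 0 sends; every instance except 1 receives zeros.
single_send = ppermute_lists(x_blocks, [(0, 1)])
print(rotated)
print(single_send)
```

The ring case reproduces the `AFTER` values printed for `f7`.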

Other collectives can be implemented efficiently, in terms of total communication, using `ppermute`s where each function instance passes data only to its neighbors. For example, we could implement `psum_scatter` using a sequence of `ppermute`s and local additions.

Intuitively, on each iteration each function instance sends ‘up’ the value it received on the previous iteration, and reduces (adds) the value it receives this iteration. In code, it might look like this:

```
def psum_scatter(x, axis_name, *, tiled=False):
  size = jax.lax.psum(1, axis_name)
  idx = jax.lax.axis_index(axis_name)  # function instance index along axis_name
  if tiled:
    x = x.reshape(size, -1, *x.shape[1:])  # split leading axis
  shift = partial(jax.lax.ppermute, axis_name=axis_name,
                  perm=[(i, (i - 1) % size) for i in range(size)])
  for i in range(1, size):
    update = shift(x[(idx + i) % size])
    x = x.at[(idx + i + 1) % size].add(update)
  return x[idx]
```

```
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
  print('BEFORE:\n', x_block)
  y_block = psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)
```

```
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]
FINAL RESULT:
[22 20 12 17]
```

On TPU, there are higher-dimensional variants of this algorithm to exploit multiple bidirectional physical mesh axes.
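The ring schedule in the `ppermute`-based `psum_scatter` above can be traced without any devices. The following plain-Python simulation is a sketch under simplifying assumptions: one list per device, the tiled case only, and a hypothetical function name `ring_psum_scatter`. It follows the same index arithmetic as the JAX code and also counts how many chunks each device sends.

```python
def ring_psum_scatter(x_blocks):
  """Simulate the ppermute-based psum_scatter (tiled) on plain lists.

  x_blocks[d] is device d's block, split into one chunk per device.
  Returns (per-device result chunks, number of chunks each device sent).
  """
  size = len(x_blocks)
  chunk = len(x_blocks[0]) // size
  # state[d][c] is device d's current value of chunk c.
  state = [[blk[c * chunk:(c + 1) * chunk] for c in range(size)]
           for blk in x_blocks]
  sends_per_device = 0
  for i in range(1, size):
    # Every device d sends chunk (d + i) % size to its neighbor d - 1,
    # so device d receives from device (d + 1) % size.
    outgoing = [state[d][(d + i) % size] for d in range(size)]
    for d in range(size):
      update = outgoing[(d + 1) % size]
      target = (d + i + 1) % size
      state[d][target] = [a + b for a, b in zip(state[d][target], update)]
    sends_per_device += 1
  # After size - 1 steps, device d holds the fully reduced chunk d.
  return [state[d][d] for d in range(size)], sends_per_device

x_blocks = [[3, 1, 4, 1], [5, 9, 2, 6], [5, 3, 5, 8], [9, 7, 1, 2]]
result, sends = ring_psum_scatter(x_blocks)
print(result, sends)
```

In this model each device sends `size - 1` chunks, each `1 / size` of its block, so it transmits slightly less than one block in total.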

Notice that `psum_scatter`

is the transpose of `all_gather`

. Indeed, a way to
implement `all_gather`

in terms of `ppermute`

looks like the reverse of the
above process:

In deep learning, we might use `ppermute` when implementing SPMD pipeline
parallelism, where we divide our network along its depth into stages and
evaluate the applications of stages in parallel. Or we might use `ppermute` in
parallelizing the evaluation of convolutional layers, where we shard over
spatial axes and thus devices must communicate “halos” to each other. Or it
may be used under the hood in tensor-parallel matrix multiplies.
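To make the halo idea concrete, here is a minimal single-host NumPy sketch of a 3-tap stencil over a 1D signal sharded across instances, with periodic boundaries. It is an illustration under stated assumptions: `ring_shift` is a hypothetical stand-in for `jax.lax.ppermute`, `sharded_conv3` is a made-up name, and the list of blocks stands in for the function instances.

```python
import numpy as np

def ring_shift(values, offset):
    """Hypothetical stand-in for ppermute: instance d receives the value
    held by instance (d - offset) % n."""
    n = len(values)
    return [values[(d - offset) % n] for d in range(n)]

def sharded_conv3(x_blocks, kernel):
    """Apply a 3-tap stencil to a signal sharded along its only axis."""
    # Each instance needs one element from each neighbor: its "halo".
    left_halo = ring_shift([x[-1:] for x in x_blocks], offset=1)    # from d - 1
    right_halo = ring_shift([x[:1] for x in x_blocks], offset=-1)   # from d + 1
    out = []
    for left, x, right in zip(left_halo, x_blocks, right_halo):
        padded = np.concatenate([left, x, right])
        # 'valid' keeps exactly one output per locally owned element.
        out.append(np.convolve(padded, kernel, mode='valid'))
    return out

x = np.arange(16.0)
kernel = np.array([1.0, 2.0, 3.0])
sharded = np.concatenate(sharded_conv3(np.split(x, 4), kernel))

# Unsharded reference with the same periodic padding.
reference = np.convolve(np.concatenate([x[-1:], x, x[:1]]), kernel, mode='valid')
print(np.allclose(sharded, reference))
```

Only the boundary elements cross instance boundaries, which is why the communication cost of this pattern scales with the halo width and not with the block size.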

## `all_to_all`

A final collective is `all_to_all`, which is essentially a block matrix
transpose operating along one positional axis and one cross-device axis:

```
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
                               tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)
```

```
BEFORE:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]
FINAL RESULT:
[3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]
```

The `split_axis` argument indicates which positional axis should be sharded
and partitioned across the mesh axis. The `concat_axis` argument indicates the
axis along which the communicated results should be concatenated or stacked.

When `tiled=False` (the default), the `split_axis` axis size must equal the
size of the mesh axis named `axis_name`, and a new axis of that size is
created at position `concat_axis` for the stacked results. When `tiled=True`,
the `split_axis` axis size need only be evenly divisible by the size of the
mesh axis, and results are concatenated along the existing axis `concat_axis`.

The `collective_ref` reference semantics when `split_axis=0` and
`concat_axis=0` might look like:

```
def all_to_all_ref(_, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  if tiled:
    splits = [jnp.array_split(x, axis_size) for x in x_blocks]
    return [jnp.concatenate(s) for s in zip(*splits)]
  else:
    splits = [list(x) for x in x_blocks]
    return [jnp.stack(s) for s in zip(*splits)]
```
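The "block matrix transpose" description can be checked directly. The sketch below is a single-host NumPy emulation for the `split_axis=0`, `concat_axis=0` case only: it stacks the per-instance blocks into one array and swaps the instance axis with the chunk axis. The function name `all_to_all_as_transpose` is hypothetical, and the list of blocks stands in for the function instances.

```python
import numpy as np

def all_to_all_as_transpose(x_blocks, *, tiled=False):
    """Emulate all_to_all with split_axis=0 and concat_axis=0 as a transpose."""
    n = len(x_blocks)
    x = np.stack(x_blocks)                    # (source instance, split axis, ...)
    if tiled:
        # Cut each block's leading axis into n chunks:
        # (source instance, chunk index, chunk length, ...)
        x = x.reshape(n, n, -1, *x.shape[2:])
        y = x.swapaxes(0, 1)                  # (destination, source, chunk length, ...)
        # Concatenate the received chunks along the existing leading axis.
        y = y.reshape(n, -1, *y.shape[3:])
    else:
        # The split axis must have exactly one entry per instance.
        assert x.shape[1] == n
        y = x.swapaxes(0, 1)                  # received entries are stacked on axis 0
    return list(y)

x = np.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
blocks = np.split(x, 4)
tiled_out = all_to_all_as_transpose(blocks, tiled=True)
stacked_out = all_to_all_as_transpose(blocks, tiled=False)
print(np.concatenate(tiled_out))
```

With four instances and blocks of length four, every chunk has length one, so the tiled and untiled variants produce the same values here; they differ once the split axis is longer than the mesh axis.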

In deep learning, we might use `all_to_all` in mixture-of-experts routing,
where we first sort our local batch of examples according to which expert they
should go to, then apply an `all_to_all` to redistribute examples to experts.
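A toy single-host sketch of that sort-then-redistribute pattern, in NumPy. Everything here is hypothetical: two instances, one expert per instance, made-up token values, and the simplifying assumption that every instance routes the same number of tokens to each expert, so that a tiled `all_to_all` splits the sorted batch exactly at expert boundaries. `all_to_all_tiled` is a stand-in for `jax.lax.all_to_all(..., split_axis=0, concat_axis=0, tiled=True)`.

```python
import numpy as np

def all_to_all_tiled(x_blocks):
    """Stand-in for a tiled all_to_all with split_axis=0 and concat_axis=0."""
    n = len(x_blocks)
    splits = [np.array_split(x, n) for x in x_blocks]
    return [np.concatenate(s) for s in zip(*splits)]

# Per-instance local batches and the expert each token is routed to.
tokens = [np.array([10, 11, 12, 13]), np.array([20, 21, 22, 23])]
expert_ids = [np.array([1, 0, 1, 0]), np.array([0, 0, 1, 1])]

# Step 1: sort each local batch by destination expert.
order = [np.argsort(e, kind='stable') for e in expert_ids]
sorted_tokens = [t[o] for t, o in zip(tokens, order)]

# Step 2: redistribute so that instance e holds every token routed to expert e.
per_expert = all_to_all_tiled(sorted_tokens)
for e, toks in enumerate(per_expert):
    print('expert', e, 'receives', toks)
```

Real routers generally cannot assume balanced assignments; handling uneven loads (for example with padding to a fixed capacity) is outside the scope of this sketch.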

# Toy examples

How might we use `shard_map` and collective communication in practice? These
examples, while simple, give some idea.

## Matrix multiplies

Parallelizing matrix multiplication is central in scaling up deep learning
models, both for training and for inference. When `jax.jit` automatically
parallelizes matrix multiplication, it can use one of several different
strategies, depending on matrix sizes, hardware details, and other factors. How
might we write some of those parallelized routines more explicitly using
`shard_map`? And how can we optimize them to get better compute/communication
overlap and thus improve FLOP utilization?

```
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
```

```
mesh = Mesh(jax.devices()[:4], ('i',))

def device_put(x, pspec):
  return jax.device_put(x, NamedSharding(mesh, pspec))
```

### Example 1: `all-gather` on one side

Consider performing a matrix multiplication where we shard the left-hand side argument (think of it as the parameters) on its leading (non-contracting) dimension:

```
lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)
```

And we shard the right-hand side argument (think of it as the activations) on its contracting dimension, with a similar sharding for the output:

```
rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)
```

To perform this matrix multiplication, we can first all-gather the right-hand side and then perform local matrix multiplies against the sharded left-hand side:

```
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
  rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
  return lhs_block @ rhs
```

```
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
```

```
True
```
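To see why this works, it can help to follow the block shapes by hand. The sketch below replays the same computation on a single host in NumPy, with plain array splits standing in for the `P('i', None)` shardings and a concatenation standing in for the tiled `all_gather`. It uses different random values from the JAX example and is only meant to illustrate the data movement.

```python
import numpy as np

n = 4  # size of the mesh axis 'i'
rng = np.random.default_rng(0)
lhs = rng.normal(size=(8, 8))
rhs = rng.normal(size=(8, 4))

# P('i', None): split the leading axis across the n instances.
lhs_blocks = np.split(lhs, n, axis=0)   # n blocks of shape (2, 8)
rhs_blocks = np.split(rhs, n, axis=0)   # n blocks of shape (2, 4)

# Tiled all_gather: every instance reassembles the full right-hand side.
rhs_full = np.concatenate(rhs_blocks, axis=0)   # shape (8, 4)

# Local matmul: each instance produces its own 2 rows of the output.
out_blocks = [lhs_block @ rhs_full for lhs_block in lhs_blocks]

# Output sharded as P('i', None): blocks are concatenated along the leading axis.
out = np.concatenate(out_blocks, axis=0)
print(np.allclose(out, lhs @ rhs))
```

Because the left-hand side is sharded on a non-contracting dimension, no reduction across instances is needed once the right-hand side has been gathered.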

That’s great, but we’re not getting any compute/communication overlap
here: before we can start the matmul, we need the `all_gather` to complete.
Here’s a profile using the same code, but on larger example shapes
(`(8192, 8192)` for `lhs` and `(8192, 1024)` for `rhs`):