SPMD multi-device parallelism with shard_map#

shard_map is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or instances, communicate with each other via explicit collective communication operations.

shard_map is complementary to, and composable with, the automatic compiler-based parallelization built into jit. With jit you write code as if for a single device, and the compiler can automatically partition computation over multiple devices, generating per-device code and communication collectives behind the scenes. With shard_map you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.

If you’re familiar with pmap, think of shard_map as an evolution. It’s more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see a detailed comparison to pmap.)

By reading this tutorial, you’ll learn how to use shard_map to get full control over your multi-device code. You’ll see in detail how it composes with jax.jit’s automatic parallelization and jax.grad’s automatic differentiation. We’ll also give some basic examples of neural network parallelization strategies.

We’ll assume this tutorial is being run in an environment with eight devices:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices

So, let’s see a shard_map!#

Without further ado, here’s a toy example:

from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 *  4.).reshape(16, 4)

@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 4]
  c_partialsum = jnp.dot(a_block, b_block)
  c_block = jax.lax.psum(c_partialsum, 'y')
  # c_block: f32[2, 4]
  return c_block

c = matmul_basic(a, b)   # c: f32[8, 4]

This function computes a matrix multiply in parallel by performing local block matrix multiplies followed by a collective sum operation. We can check the result is correct:

from jax.tree_util import tree_map, tree_all

def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

allclose(c, jnp.dot(a, b))
True

The result is sharded along its rows:

jax.debug.visualize_array_sharding(c)
            
  CPU 0,1
  CPU 2,3
  CPU 4,5
  CPU 6,7

At a high level, shard_map is kind of like vmap or pmap, in that we’re mapping a function over pieces of array data, but notice that

  • shard_map slices up inputs into blocks (and the output is formed by concatenating result blocks), keeping the rank the same, whereas vmap would reduce the rank by mapping away an axis;

  • the mesh argument lets us control precise device placement of computation and results;

  • we’re mapping over multiple data axes at once, and setting up multiple axis names for collectives (both 'x' and 'y' here);

  • since we’re not using jax.jit yet, everything is eagerly evaluated, and we can even print intermediate values for debugging.

The above code is performing the same computation as this jax.jit automatic parallelization code:

from jax.sharding import NamedSharding

a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))

@jax.jit
def matmul_reference(a, b):
  c = jnp.dot(a, b)
  return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))

c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b))
True

We can think of shard_map as performing a device_put or with_sharding_constraint on its inputs according to its mesh and in_specs arguments, so the blocks over which matmul_basic operates are the same as in matmul_reference:

print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks:
  CPU 0    CPU 1
  CPU 2    CPU 3
  CPU 4    CPU 5
  CPU 6    CPU 7

b blocks:
  CPU 0,2,4,6
  CPU 1,3,5,7

c blocks:
  CPU 0,1
  CPU 2,3
  CPU 4,5
  CPU 6,7

Slow down, start with the basics!#

Rank-reducing vs rank-preserving maps#

We can think of vmap and pmap as unstacking each array input along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body function to each piece, and stacking the results back together, at least when collectives aren’t involved:

def check_vmap(f, xs):
  ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
  expected = jnp.stack([f(x) for x in xs])  # vmap reference semantics
  print(allclose(ans, expected))

check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True

For example, if xs had shape f32[8,5] then each x would have shape f32[5], and if each f(x) had shape f32[3,7] then the final stacked result vmap(f)(xs) would have shape f32[8,3,7]. That is, each application of the body function f takes as argument inputs with one fewer axis than the corresponding argument to vmap(f). We can say these are rank-reducing maps with unstacking/stacking of inputs/outputs.
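As a tiny concrete check of that rank-reducing behavior, here's a shape-only sketch (the zero arrays and the function g below are just illustrative placeholders, not part of the examples above):

xs = jnp.zeros((8, 5))
g = lambda x: jnp.zeros((3, 7))   # ignores its f32[5] argument, returns f32[3,7]
print(jax.vmap(g)(xs).shape)      # (8, 3, 7)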

The number of logical applications of f, or instances of f, is determined by the size of the input axis being mapped over: for example, if we map over an input axis of size 8, semantically we get 8 logical applications of the function.

In contrast, shard_map does not have this rank-reducing behavior. Instead, we can think of it as slicing (or “unconcatenating”) along input axes into blocks, applying the body function, and concatenating the results back together (again when collectives aren’t involved):

import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

def check_shmap(f, y):
  ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
  expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
  print(allclose(ans, expected))

check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True

Recall that jnp.split slices its input into equally-sized blocks with the same rank, so that if in the above example y had shape f32[8,5] then each y_blk would have shape f32[2,5], and if each f(y_blk) had shape f32[3,7] then the final concatenated result shard_map(f, ...)(y) would have shape f32[12,7]. So shard_map maps over shards, or blocks, of its inputs. We can say it’s a rank-preserving map with unconcatenating/concatenating of its inputs/outputs.
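Here's the analogous shape-only sketch for shard_map, reusing the one-axis mesh of four devices defined just above (again, the zero arrays and g are illustrative placeholders):

ys = jnp.zeros((8, 5))
g = lambda y_blk: jnp.zeros((3, 7))   # maps each f32[2,5] block to an f32[3,7] block
print(shard_map(g, mesh, in_specs=P('i'), out_specs=P('i'))(ys).shape)  # (12, 7)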

The number of logical applications of f is determined by the mesh size, not by any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4 devices) then semantically we get 4 logical applications of the function, corresponding to the 4 devices physically computing them.

Controlling how each input is split (unconcatenated) and tiled with in_specs#

Each of the in_specs identifies some of the corresponding input array’s axes with mesh axes by name using PartitionSpecs, representing how to split (or unconcatenate) that input into the blocks to which the body function is applied. That identification determines the shard sizes; when an input axis is identified with a mesh axis, the input is split (unconcatenated) along that logical axis into a number of pieces equal to the corresponding mesh axis size. (It’s an error if the corresponding mesh axis size does not evenly divide the input array axis size.) If an input’s pspec does not mention a mesh axis name, then there’s no splitting over that mesh axis. For example:

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('i', 'j'))

@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)  # prints (3, 12)
  return x_block

x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1)
(3, 12)

Here, because the input pspec did not mention the mesh axis name 'j', no input array axis is split over that mesh axis; similarly, because the second axis of the input array is not identified with (and hence split over) any mesh axis, application of f1 gets a full view of the input along that axis.

When a mesh axis is not mentioned in an input pspec, we can always rewrite to a less efficient program where all mesh axes are mentioned but the caller performs a jnp.tile, for example:

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)
(3, 12)

In other words, because each input pspec can mention each mesh axis name zero or one times, rather than having to mention each name exactly once, we can say that in addition to the jnp.split built into its input, shard_map also has a jnp.tile built into its input, at least logically (though the tiling may not need to be carried out physically, depending on the arguments’ physical sharding layout). The tiling to use is not unique; we could also have tiled along the first axis, and used the pspec P(('j', 'i'), None).
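Here is a sketch of that alternative tiling (the name f2_alt is ours): tile along the first axis instead, and split that axis over both mesh axes, recovering the same (3, 12) blocks as f1 and f2.

@partial(shard_map, mesh=mesh, in_specs=P(('j', 'i'), None), out_specs=P('i', 'j'))
def f2_alt(x_block):
  print(x_block.shape)  # prints (3, 12), the same block shape as in f1 and f2
  return x_block

x_alt = jnp.tile(x, (mesh.shape['j'], 1))  # x_alt has shape (24, 12)
y = f2_alt(x_alt)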

Physical data movement is possible on inputs, as each device needs to have a copy of the appropriate data.

Controlling how each output is assembled by concatenation, block transposition, and untiling using out_specs#

Analogously to the input side, each of the out_specs identifies some of the corresponding output array’s axes with mesh axes by name, representing how the output blocks (one for each application of the body function, or equivalently one for each physical device) should be assembled back together to form the final output value. For example, in both the f1 and f2 examples above the out_specs indicate we should form the final output by concatenating together the block results along both axes, resulting in both cases in an array y of shape (12, 24). (It’s an error if an output shape of the body function, i.e. an output block shape, has a rank too small for the concatenation described by the corresponding output pspec.)

When a mesh axis name is not mentioned in an output pspec, it represents an un-tiling: when the user writes an output pspec which does not mention one of the mesh axis names, they promise that the output blocks are equal along that mesh axis, and so only one block along that axis is used in the output (rather than concatenating all the blocks together along that mesh axis). For example, using the same mesh as above:

x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))

z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x
[[3. 3.]
 [3. 3.]
 [3. 3.]
 [3. 3.]]
[[3.]
 [3.]
 [3.]
 [3.]]
[[3.]]

The body function closing over an array value is equivalent to passing that value as an argument with a corresponding input pspec of P(None, None). Here is a small sketch of that equivalence (the names via_closure and via_arg are ours), reusing the mesh and the x = jnp.array([[3.]]) defined just above:
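via_closure = shard_map(lambda: x * 2., mesh=mesh, in_specs=(), out_specs=P(None, None))()
via_arg = shard_map(lambda x_in: x_in * 2., mesh=mesh, in_specs=P(None, None),
                    out_specs=P(None, None))(x)
print(jnp.allclose(via_closure, via_arg))  # expected: True

As another example, following more closely to the other examples above: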

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)
(12, 6)

The result has a second axis size of 6, half the size of the input’s second axis. In this case, the un-tile expressed by not mentioning the mesh axis name 'j' in the output pspec was safe because of the collective psum, which ensures each output block is equal along the corresponding mesh axis. Here are two more examples where we vary which mesh axes are mentioned in the output pspec:

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)
(3, 12)
(3, 6)

On the physical side, not mentioning a mesh axis name in an output pspec assembles an Array from the output device buffers with replicated layout along that mesh axis.

There is no runtime check that the output blocks are actually equal along a mesh axis to be un-tiled along, or equivalently that the corresponding physical buffers have equal values and thus can be interpreted as a replicated layout for a single logical array. But we can provide a static check mechanism which raises an error on all potentially-incorrect programs.

Because the out_specs can mention mesh axis names zero or one times, and because they can be mentioned in any order, we can say that in addition to the jnp.concatenate built into its output, shard_map also has both an untile and a block transpose built into its output.

Physical data movement is not possible on outputs, no matter the output pspec. Instead, out_specs just encodes how to assemble the block outputs into Arrays, or physically how to interpret the buffers across devices as the physical layout of a single logical Array.

API Specification#

from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(
    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
    auto: collections.abc.Set[AxisName] = frozenset([]),
    check_rep: bool = True,
) -> Callable:
  ...

where:

  • communication collectives like psum in the body of f can mention the axis names of mesh;

  • mesh encodes devices arranged in an array and with associated axis names, just like it does for sharding.NamedSharding;

  • in_specs and out_specs are PartitionSpecs which can affinely mention axis names from mesh to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;

  • auto is an optional set of axis names corresponding to the subset of names of mesh to treat automatically in the body, as in the caller, rather than manually;

  • check_rep is an optional boolean indicating whether to check statically for any replication errors in out_specs, and also whether to enable a related automatic differentiation optimization (see JEP).

The shapes of the arguments passed to f have the same ranks as the arguments passed to shard_map-of-f, and the shape of an argument to f is computed from the shape shape of the corresponding argument to shard_map-of-f and the corresponding PartitionSpec spec as roughly tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec)).
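For example, here is that shape rule evaluated by hand for the f1 example above, a small sketch assuming the (4, 2) mesh with axis names ('i', 'j') defined earlier:

shape, spec = (12, 12), P('i', None)
block_shape = tuple(sz // (1 if n is None else mesh.shape[n])
                    for sz, n in zip(shape, spec))
print(block_shape)  # (3, 12), matching the block shape printed by f1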

Collectives tutorial#

A shard_map need not be a pure map: function applications can communicate with each other via collectives, using axis names defined in the mesh argument.

Recall that shard_map maps a function over shards, or blocks, of input data, so that this:

mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x)

computes the same values, evaluating applications of f to the same argument values, as this reference function:

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape[0])
  y_blocks = [f(x_blk) for x_blk in x_blocks]
  return jnp.concatenate(y_blocks)

We call these applications of f to different argument shards function instances. Each function instance is executed on a different device (or subset of devices).

These reference semantics work when f has no communication collectives in it. But what if we want the function instances to communicate, corresponding to having cross-device communication? That is, what are the reference semantics when f contains a collective? Say f has just one collective, and is of the form

def f(x_blk):
  z_blk = f_part1(x_blk)
  u_blk = collective(z_blk, axis_name)
  v_blk = f_part2(x_blk, z_blk, u_blk)
  return v_blk

where we’re assuming there’s only one mesh axis we’re mapping over, and axis_name is the corresponding name for it. Then the reference semantics would look more like:

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape[0])
  z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
  u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
  v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
              in zip(x_blocks, z_blocks, u_blocks)]
  return jnp.concatenate(v_blocks)

Notice that collective_ref might depend on all the z_blocks. That is, while f_part1 and f_part2 are mapped over blocks independently, a collective introduces some amount of cross-block dependence. Physically, that means communication across devices. Exactly what communication happens, and what values are computed, depend on the collective.

psum#

The simplest collective may be jax.lax.psum, which computes an all-reduce-sum along a device mesh axis (or multiple axes). Here’s a toy example:

Illustration of a psum computation.
import jax
import jax.numpy as jnp
from jax import lax

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh1d = Mesh(jax.devices()[:4], ('i',))

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]

FINAL RESULT:
 [22 20 12 17]

The prints show that each function application starts with its own chunk of the argument value x_block. After the psum, each function application has the same value of y_block, computed by summing the applications’ x_block values together.

In the case where there’s a single axis name in the computation, we could say that the collective_ref reference implementation for psum is

def psum_ref(_, x_blocks):
  tot = sum(x_blocks)
  return [tot] * len(x_blocks)

Notice also that because f1 returns y_block, the result of a psum over 'i', we can use out_specs=P(None) (or equivalently P()) so the caller gets a single logical copy of the result value, rather than a tiled result.

When there is more than one mesh axis, we can perform a psum over each one separately, or over multiple axes at once:

mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))

@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block

y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
 [16 18]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
 [20 22]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
 [16 18]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
 [20 22]]

FINAL RESULT:
 [[ 8 10 12 14]
 [16 18 20 22]]

By applying a psum over mesh axis 'i', we get values of y_block which are equal along axis 'i', but not axis 'j'. (So we can use out_specs=P(None, 'j') to get a single logical result along that axis.)

If we apply the psum over both axes, the y_block value is equal along both axes:

@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, ('i', 'j'))
  print('AFTER:\n', y_block)
  return y_block

y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
 [36 40]]

On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
 [36 40]]

On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
 [36 40]]

On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
 [36 40]]

FINAL RESULT:
 [[20 24]
 [36 40]]

In machine learning, we often use psum to compute total losses or, when we have a grad inside the shard_mapped function body, total gradients.

In the sequel, we’ll see how psum can be implemented in terms of other primitives, which gives some intuition about its communication cost.

all_gather#

Another fundamental operation is gathering array shards along an axis, so that each function application has a full copy of the data along that axis:

Illustration of an all_gather computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 9, 5, 2])
y = f4(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 9 5 2]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[3 9 5 2]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[3 9 5 2]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[3 9 5 2]

FINAL RESULT:
 [3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]

The prints show that each function application again starts with its own chunk of the argument value x_block. After the all_gather, they have a common value, computed by concatenating the values of x_block.

(Notice that we actually can’t set out_specs=P() here. For technical reasons related to automatic differentiation, we consider the output of all_gather not to be guaranteed invariant across devices. If we wanted it to be guaranteed invariant, we could use jax.lax.all_gather_invariant, or in this case we could just avoid doing the all_gather in the function body and instead just use out_specs=P('i') to perform the concatenation.)
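A small sketch of that second alternative (the name f4_alt is ours): return the block unchanged and let out_specs=P('i') perform the concatenation, so the caller sees the gathered value once rather than tiled four times.

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f4_alt(x_block):
  return x_block

print(f4_alt(jnp.array([3, 9, 5, 2])))  # [3 9 5 2]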

When tiled=False (the default), results are stacked along a new axis instead of concatenated:

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f5(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_gather(x_block, 'i', tiled=False)
  print('AFTER:\n', y_block)
  return y_block

y = f5(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[9]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[[3]
 [9]
 [5]
 [2]]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[[3]
 [9]
 [5]
 [2]]

FINAL RESULT:
 [[3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]
 [3]
 [9]
 [5]
 [2]]

We could write the collective_ref reference semantics function for all_gather as

def all_gather_ref(_, x_blocks, *, tiled=False):
  combine = jnp.concatenate if tiled else jnp.stack
  return [combine(x_blocks)] * len(x_blocks)

In deep learning, we might use all_gathers on parameters in fully sharded data parallelism (FSDP).

psum_scatter#

The jax.lax.psum_scatter collective is a bit less intuitive. It’s like psum except each function instance gets only one shard of the result:

Illustration of a psum_scatter computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f6(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f6(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]

FINAL RESULT:
 [22 20 12 17]

As shown by the prints, each resulting y_block has a smaller size than the argument x_block, unlike with psum. Moreover, compared to psum, here each y_block only represents a slice of the sum of the x_blocks across function instances. (Even though each function instance gets only one shard of the sum, the final output y is the same as in the psum example because here we use out_specs=P('i') to concatenate each function instance’s output.)

In terms of what values are computed, a collective_ref reference implementation might look like:

def psum_scatter_ref(i, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  tot = sum(x_blocks)
  if tiled:
    tot = tot.reshape(axis_size, -1, *tot.shape[1:])  # split leading axis
  return [tot[i] for i in range(tot.shape[0])]

It’s not captured in the semantics reference implementation, but psum_scatter is useful because these results can be computed more efficiently, with less communication, than a full psum. In fact, one way to think of psum_scatter is as “the first half of a psum, before an all_gather”. That is, one way to implement psum is:

def psum(x, axis_name):
  summed_chunk = jax.lax.psum_scatter(x, axis_name)
  return jax.lax.all_gather(summed_chunk, axis_name)

Indeed, this implementation is often used on both TPU and GPU!

The reason psum_scatter can require about half the communication of a full psum is illustrated in the ppermute section.

Another intuition is that we can use psum_scatter to implement a distributed matrix multiplication with inputs and outputs sharded over the same axis. In machine learning, psum_scatter can be used in tensor-parallel matrix multiplies or fully-sharded data parallel gradient accumulation, as shown in the examples to follow.

ppermute#

The jax.lax.ppermute collective provides the most direct way for function instances to send data to one another. Given a mesh axis and a list of (source_index, destination_index) pairs representing indices along that mesh axis, ppermute sends its argument value from each source function instance to each destination:

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7(x_block):
  sz = jax.lax.psum(1, 'i')
  print('BEFORE:\n', x_block)
  y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)])
  print('AFTER:\n', y_block)
  return y_block

y = f7(jnp.arange(8))
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[0 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[2 3]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 5]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[6 7]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[6 7]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[0 1]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[2 3]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[4 5]

FINAL RESULT:
 [6 7 0 1 2 3 4 5]

In this case, with four function instances arranged along the mesh axis, each instance’s value of y_block is the x_block value of the instance before it in the cyclic permutation (and the first instance receives the last instance’s block).

Source indices and destination indices can’t be repeated. If an index does not appear as a destination, then the value of the corresponding function instance’s result is an array of zeros.
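Here is a small sketch of that zero-filling behavior (the name f7_partial is ours): only instance 0 sends, to instance 1, so every instance other than 1 produces zeros.

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f7_partial(x_block):
  return jax.lax.ppermute(x_block, 'i', perm=[(0, 1)])

print(f7_partial(jnp.arange(8.)))  # expected: [0. 0. 0. 1. 0. 0. 0. 0.]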

A collective_ref reference implementation could look like

def ppermute_ref(i, x_blocks, perm):
  results = [jnp.zeros_like(x_blocks[0])] * len(x_blocks)
  for src, dst in perm:
    results[dst] = x_blocks[src]
  return results

Other collectives can be implemented efficiently, in terms of total communication, using ppermutes where each function passes data only to its neighbors. For example, we could implement psum_scatter using a sequence of ppermutes and local additions this way:

Illustration of a psum_scatter implementation.

Or, with a numerical example:

Illustration of a psum_scatter implementation.

Intuitively, on each iteration each function instance sends ‘up’ the value it received on the previous iteration, and reduces (adds) the value it receives this iteration. In code, it might look like this:

def psum_scatter(x, axis_name, *, tiled=False):
  size = jax.lax.psum(1, axis_name)
  idx = jax.lax.axis_index(axis_name)  # function instance index along axis_name
  if tiled:
    x = x.reshape(size, -1, *x.shape[1:])  # split leading axis
  shift = partial(jax.lax.ppermute, axis_name=axis_name,
                  perm=[(i, (i - 1) % size) for i in range(size)])
  for i in range(1, size):
    update = shift(x[(idx + i) % size])
    x = x.at[(idx + i + 1) % size].add(update)
  return x[idx]
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f8(x_block):
  print('BEFORE:\n', x_block)
  y_block = psum_scatter(x_block, 'i', tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f8(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[20]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[12]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[17]

FINAL RESULT:
 [22 20 12 17]

On TPU, there are higher-dimensional variants of this algorithm to exploit multiple bidirectional physical mesh axes.

Notice that psum_scatter is the transpose of all_gather. Indeed, a way to implement all_gather in terms of ppermute looks like the reverse of the above process:

Illustration of an all_gather implementation.
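In code, a sketch of that reverse process might look like the following. This is our illustration rather than the library’s implementation; the names all_gather_ppermute and f_gather are ours, and the loop mirrors the structure of the psum_scatter implementation above.

def all_gather_ppermute(x_block, axis_name, *, tiled=False):
  size = jax.lax.psum(1, axis_name)
  idx = jax.lax.axis_index(axis_name)
  shift = partial(jax.lax.ppermute, axis_name=axis_name,
                  perm=[(i, (i + 1) % size) for i in range(size)])
  # place our own block in its slot, then fill the rest as blocks arrive
  # from our neighbor around the ring
  out = jnp.zeros((size, *x_block.shape), x_block.dtype).at[idx].set(x_block)
  chunk, src = x_block, idx
  for _ in range(1, size):
    chunk = shift(chunk)       # receive the block our lower neighbor just held
    src = (src - 1) % size     # it originated one position further down the ring
    out = out.at[src].set(chunk)
  return out.reshape(-1, *x_block.shape[1:]) if tiled else out

@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f_gather(x_block):
  return all_gather_ppermute(x_block, 'i', tiled=True)

print(f_gather(jnp.array([3, 9, 5, 2])))  # expected: same as f4 above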

In deep learning, we might use ppermute when implementing SPMD pipeline parallelism, where we divide our network along its depth into stages and evaluate the applications of stages in parallel. Or we might use ppermute in parallelizing the evaluation of convolutional layers, where we shard over spatial axes and thus devices must communicate “halos” to each other. Or it may be used under-the-hood in tensor-parallel matrix multiplies.

all_to_all#

A final collective is all_to_all, which is essentially a block matrix transpose operating along one positional axis and one cross-device axis:

Illustration of an all_to_all computation.
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))
def f9(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,
                               tiled=True)
  print('AFTER:\n', y_block)
  return y_block

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f9(x)
print('FINAL RESULT:\n', y)
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]

AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 5 5 9]

On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[1 9 3 7]

On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[4 2 5 1]

On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[1 6 8 2]

FINAL RESULT:
 [3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]

The split_axis argument indicates which positional axis should be sharded and partitioned across the mesh axis. The concat_axis argument indicates the axis along which the communicated results should be concatenated or stacked.

When tiled=False (the default), the split_axis axis size must equal the size of the mesh axis named axis_name, and a new axis of that size is created at position concat_axis for the stacked results. When tiled=True, the split_axis axis size need only be evenly divisible by the size of the mesh axis, and results are concatenated along the existing axis concat_axis.

The collective_ref reference semantics when split_axis=0 and concat_axis=0 might look like:

def all_to_all_ref(_, x_blocks, *, tiled=False):
  axis_size = len(x_blocks)
  if tiled:
    splits = [jnp.array_split(x, axis_size) for x in x_blocks]
    return [jnp.concatenate(s) for s in zip(*splits)]
  else:
    splits = [list(x) for x in x_blocks]
    return [jnp.stack(s) for s in zip(*splits)]

In deep learning, we might use all_to_all in mixture-of-expert routing, where we first sort our local batch of examples according to which expert they should go to, then apply an all_to_all to redistribute examples to experts.

Toy examples#

How might we use shard_map and collective communication in practice? These examples, while simple, give some idea.

Matrix multiplies#

Parallelizing matrix multiplication is central in scaling up deep learning models, both for training and for inference. When jax.jit automatically parallelizes matrix multiplication, it can use one of several different strategies, depending on matrix sizes, hardware details, and other factors. How might we write some of those parallelized routines more explicitly using shard_map? And how can we optimize them to get better compute/communication overlap and thus improve FLOP utilization?

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = Mesh(jax.devices()[:4], ('i',))

def device_put(x, pspec):
  return jax.device_put(x, NamedSharding(mesh, pspec))

Example 1: all-gather on one side#

Consider performing a matrix multiplication where we shard the left-hand side argument (can think: parameters) on its leading (non-contracting) dimension:

lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)

And we shard the right-hand side argument (can think: activations) on its contracting dimension, with a similar sharding for the output:

rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)

To perform this matrix multiplication, we can first all-gather the right-hand side and then perform local matrix multiplies against the sharded left-hand side:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
  rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
  return lhs_block @ rhs
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

That’s great, but we’re not getting any compute/communication overlap here: before we can start the matmul, we need the all_gather to complete. Here’s a profile using the same code, but on larger example shapes ((8192, 8192) for lhs and (8192, 1024) for rhs):

Profile of an all-gather matmul without overlap.

We can get compute/communication overlap if instead of calling all_gather we basically inline our above implementation of all_gather in terms of ppermute, then interleave steps of the gather permutation with local matrix multiplies:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i + 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

  out_block = lhs_blocks(idx) @ rhs_block
  for i in range(1, size):
    rhs_block = shift(rhs_block)
    out_block += lhs_blocks((idx - i) % size) @ rhs_block
  return out_block
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

This implementation allows overlap between communication and computation, and also avoids gathering a large intermediate onto each device. But on TPU it uses only half the interconnect bandwidth by permuting in only one direction along the ring. To permute bidirectionally, we just split the blocks in half and send each half in each direction:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)

  rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
  out_block  = lhs_blocks(idx, 0) @ rhs_block_lo
  out_block += lhs_blocks(idx, 1) @ rhs_block_hi
  for i in range(1, size):
    rhs_block_lo = shift_up(rhs_block_lo)
    rhs_block_hi = shift_dn(rhs_block_hi)
    out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
    out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
  return out_block
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
Profile of an all-gather matmul with overlap.

In practice, to reduce compile times we would probably roll this into a jax.lax.fori_loop. We might also have additional axes of parallelism involved.
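For instance, here is a sketch of the unidirectional overlapped version rolled into jax.lax.fori_loop, under the same mesh and specs as above (the name matmul_allgather_overlapped_fori is ours, and we assume collectives inside fori_loop are supported in your JAX version):

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_fori(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i + 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

  def body(i, carry):
    out_block, rhs_block = carry
    rhs_block = shift(rhs_block)                       # same permutation step as before
    out_block += lhs_blocks((idx - i) % size) @ rhs_block
    return out_block, rhs_block

  out_block, _ = jax.lax.fori_loop(1, size, body, (lhs_blocks(idx) @ rhs_block, rhs_block))
  return out_block

out = matmul_allgather_overlapped_fori(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))  # expected: True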

Example 2: psum_scatter the result#

Another sharding we might start with has both lhs and rhs sharded along their contracting dimensions, with the output sharded like rhs again:

lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)

rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)

Here we can use a psum_scatter (a reduce-scatter) to perform the contraction sum over shards:

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
  out_summand = lhs_block @ rhs_block
  return jax.lax.psum_scatter(out_summand, 'i', tiled=True)

out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

But the scattering communication must wait for the entire local matrix multiply to finish before it can start. To get communication/computation overlap, we can inline an implementation of psum_scatter in terms of ppermute, then interleave the communication steps with local matrix multiplies:

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i - 1) % size) for i in range(size)])
  lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1])  # split 1st axis

  out_summand = lhs_block[(idx + 1) % size] @ rhs_block
  for i in range(1, size):
    out_summand = shift(out_summand)
    out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
  return out_summand
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

As in the previous example, to fully utilize interconnects on TPU, we’d run a bidirectional version:

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[0] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)

  out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
  out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
  for i in range(1, size):
    out_summand_lo = shift_up(out_summand_lo)
    out_summand_hi = shift_dn(out_summand_hi)
    out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
    out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
  return jnp.concatenate([out_summand_lo, out_summand_hi])
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True

Neural networks#

We can use shard_map to parallelize computation in neural networks, either by itself or in combination with the automatic partitioning in jax.jit. This section has a few examples based on this toy neural network and random data:

import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
def init_layer(key, n_in, n_out):
    k1, k2 = jax.random.split(key)
    W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
    b = jax.random.normal(k2, (n_out,))
    return W, b

def init(key, layer_sizes, batch_size):
    key, *keys = jax.random.split(key, len(layer_sizes))
    params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

    key, *keys = jax.random.split(key, 3)
    inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
    targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

    return params, (inputs, targets)
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32

params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)

Compare these examples with the purely automatic partitioning examples in the “Distributed arrays and automatic partitioning” doc. While in those automatic partitioning examples we don’t need to edit the model functions to use different parallelization strategies, with shard_map we often do.

8-way batch data parallelism#

The simplest multi-device parallelism strategy is to shard the batch of inputs and targets over multiple devices, replicate the parameters over those devices, and apply the model in parallel to those shards of data. To evaluate the total loss, the devices need only communicate with a scalar-sized all-reduce-sum at the end. (To evaluate the gradient of the loss, the devices must perform all-reduce-sums of parameter gradients in the backward pass.)

from functools import partial

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((8,))

# replicate initial params on all devices, shard data batch over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))

# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
  def loss_spmd(local_batch):
    inputs, targets = local_batch
    predictions = predict(params, inputs)  # use reference `predict`
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(batch)

We can check that the loss and its gradients match the reference (base) model:

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
22.779888
22.779888
def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_dp))(params, batch)))
True

We can print the compiler IR to inspect the gradient computation and verify that the collective all-reduce-sum operations happen where we’d expect: at the end of the forward pass to compute the loss value, and in the backward pass to compute the total parameter gradients.
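For example, one way to do that is to lower and compile the gradient function and print its HLO text, then look for the all-reduce ops (this is just a sketch; the exact textual output depends on your JAX and XLA versions):

print(jax.jit(jax.grad(loss_dp)).lower(params, batch).compile().as_text())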

8-way fully sharded data parallelism (FSDP)#

Another strategy is to additionally shard the parameters over the devices, all-gathering each one when the full value is needed for the jnp.dot or bias addition. Since we only have one full parameter in local device memory at a time, rather than keeping all parameters in all device memories as in the preceding DP example, we free up significant memory that we can use for larger models or larger batch sizes. And because XLA will overlap computation and inter-device communication, the wall-clock time doesn’t suffer.

So now we need collectives in two places: the model prediction function predict needs to all-gather the parameters before they’re used, and as in the DP case the loss function needs to sum the local losses to compute the total loss.

There’s one other ingredient we need: we don’t want to store the fully gathered parameters from the forward pass for use on the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using jax.remat with a custom policy (or a custom_vjp), though XLA typically does that rematerialization automatically.

This general FSDP approach is similar to weight update sharding (WUS) and ZeRO-3.

# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss_fsdp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
  def loss_spmd(local_params, local_batch):
    inputs, targets = local_batch
    predictions = predict_fsdp(local_params, inputs)
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(params, batch)

Again we can check that the loss and its gradients match the reference model:

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp))(params, batch)))
22.779888
22.779888
True

8-way tensor parallelism (TP)#

Usually we don’t use tensor model parallelism by itself, but seeing it in isolation is a good warmup on parallel matrix multiplication. It’s also a good example of using shard_map in a library function, called in a larger jit-based computation.

The parallelization idea is that we’ll keep the data/activations sharded over its feature axis (rather than its batch axis), and we’ll similarly shard weight matrices over their input-feature axis (and biases over their feature axis). Then to perform the parallel matrix multiplication, we’ll perform local matrix multiplications followed by a psum_scatter to sum the local results and efficiently scatter the result’s shards.

devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, ('feats',))

batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))

def predict_tp(params, inputs):
  for W, b in params:
    outputs = gemm_tp(inputs, W, b)
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
         out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
  block_result = jnp.dot(inputs, W)
  return jax.lax.psum_scatter(block_result, 'feats',
                              scatter_dimension=1, tiled=True) + b

def loss_tp(params, batch):
  inputs, targets = batch
  predictions = predict_tp(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))  # NOTE psum!
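As a quick sanity check (ours, using the 'feats'-sharded params and batch defined just above), the TP loss should agree with the reference loss:

print(allclose(jax.jit(loss)(params, batch), jax.jit(loss_tp)(params, batch)))  # expected: True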

FSDP + TP, with shard_map at the top level#

We can compose these strategies together, using multiple axes of parallelism.

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('batch', 'feats'))

batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))

# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    block_result = jnp.dot(inputs, W)
    outputs = jax.lax.psum_scatter(block_result, 'feats',
                                   scatter_dimension=1, tiled=True) + b
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
         out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
  inputs, targets = local_batch
  predictions = predict_fsdp_tp(local_params, inputs)
  sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
  return jax.lax.pmean(jnp.mean(sq_err), 'batch')

Notice how we have to do two collective reductions: one over 'feats' and one over 'batch'. In the pure TP example, we didn’t write the 'feats' reduction explicitly because we only used shard_map within gemm_tp; in the caller loss_tp, the compiler automatically translated our use of jnp.sum to perform a psum as needed given the sharded result returned by predict_tp.

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))
22.779886
22.779886
True

SPMD pipeline parallelism (PP)#

With pipeline parallelism we aim to parallelize the evaluation of layers at different depths in our network. For example, one device might compute the application of the first layer while another device computes the application of the second; when they finish, the first device passes its results to the second while the second passes its results to the device responsible for the third layer, and the process repeats. In general the number of pipeline stages may be different from the number of layers, as each stage may be responsible for multiple layers.

With SPMD pipelining, we exploit the fact that most layers in the network apply the same computation, just with different parameter values. In particular, we can stack together all the parameters except for those for the first and last layers, then use a shard_map to map over blocks of those layer parameters, where each block of parameters corresponds to a pipeline stage. We then use the jax.lax.ppermute collective to shift data down the parallel pipeline.

This particular pipelining strategy is essentially the GPipe strategy. There are several variants, as well as quite different strategies, and which is appropriate can depend on the speed of the networking between stages and batch sizes. But for this tutorial we’ll focus on just one strategy.

First, we choose some pipeline parameters:

L = len(params) - 2        # num layers, excluding first and last
N = batch_size             # batch size
F = params[0][0].shape[1]  # num features

# choose some pipeline parameters
S = 2      # number of stages
B = 8      # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"

# compute some useful quantities
M, ragged = divmod(N, B)  # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S)  # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))

def predict_pp(params, inputs):
  (W_first, b_first), inner_params, (W_last, b_last) = params
  inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
  inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
                        inner_params, inputs)
  outputs = jnp.dot(inputs, W_last) + b_last
  return outputs

@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
         out_specs=P())
def loss_pp(params, batch):
  inputs, targets = batch
  predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
  stage = jax.lax.axis_index('stages')
  outputs = jnp.zeros_like(inputs) * jnp.nan
  state = jnp.zeros((L // S, B, F)) * jnp.nan
  for i in range(M+L-1):
    state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
    state = jax.vmap(fn)(stage_params, state)
    outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
    state, inputs, outputs = shift(i, state, inputs, outputs)
  outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
  return outputs

def shift(i, state, inputs, outputs):
  sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
  state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
  if (i % K) == (-1 % K):
    inputs = sh(inputs, +1)
  if ((i-L+1) % K) == (-1 % K):
    outputs = sh(outputs, +1)
  return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params

batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
22.779886
22.779884
_ = jax.jit(jax.grad(loss_pp))(params_, batch_)   # don't crash