Efficient transposition of replication-inducing collectives

mattjj@, dougalm@

August 2023


We have an efficiency problem in automatically transposing shmaps containing certain collectives. The issue arises with psum and all_gather, specifically when the output of the collective is returned to the caller as an unmapped output. And it’s not an edge case: for example, it arises when applying grad to a shmap-based batch data parallel neural network loss function which uses psum to compute the total loss.

We’ve known about this problem for some time. An analogous issue exists with pmap, though it’s been worked around by keeping grad inside pmap rather than outside. A primary goal of the incomplete avals-with-names work was to address a version of this transpose efficiency problem. This doc draws on those ideas, while extending and revising them to handle more cases and to be much easier to land. Indeed the solution proposed here only affects the shmap implementation. The rest of the system need not be changed (yet).

The main purpose of this doc is to define this transpose efficiency problem and propose an easy-to-land solution.

This doc is not about:

  • logical axis names on arrays (the only axis names here are just like in shmap and OG pmap);

  • changing autodiff semantics (all the numbers and (non)errors are staying the same, we’re just making things more efficient);

  • allowing user code to reflect on any new information, or really affecting user code at all.

Problem: efficient transpose of psum or all_gather depends on whether cotangents are invariant across devices

Consider this semi-realistic example, meant to resemble a replicated-parameter batch data parallel loss function:

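from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map as shmap  # shmap abbreviates shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)  # predict is a user-defined model function
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch')
  return global_loss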

Notice the out_specs=P(), which indicates an unmapped output. If you’re not familiar with the notion of unmapped outputs, see the appendix at the bottom of this document.

Most of the details in the loss example aren’t important. All that matters for our purposes is that we’re applying psum (or rather pmean = lambda x, name: psum(x, name) / psum(1, name)) at the end. So a distilled version looks like this:

# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

We even simplified notation by suppressing the mesh argument. In the examples to follow it can be inferred from context.

What does the transpose look like? Writing t to mean function transpose, we could evaluate t(f1)(ybar) for any ybar efficiently by applying the function ¿f1_transpose? below:

# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))

But that’s not the transpose we currently get as t(f1).

Instead, the current recipe for transposition is roughly that we switch in_specs and out_specs, do some division rescaling for unmapped outputs, and transpose the body. Because psum is its own transpose (as an all-reduce sum), we end up producing this transpose:

# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))

This transpose gets the numbers right, but it’s wasteful. We know statically from the transpose’s in_specs=P() that ybar has the same value for each function instance, i.e. that its value is device-invariant for devices along the mesh axis named i, and yet we apply a psum to it! That uses expensive communication just to multiply the value on each device by 8. (Here 8 refers to the size of axis i. The division by 8 comes from the original function’s out_specs=P(); it and the trivial psum basically cancel each other out.)
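The cancellation can be checked with a small pure-Python model of a single mesh axis. This is a sketch and not JAX code: the names AXIS_SIZE, psum_model and wasteful_prefix are hypothetical, and a value on the axis is modeled as a list holding one entry per device.

```python
# Pure-Python model of one mesh axis of size 8 (hypothetical helpers, not JAX).
AXIS_SIZE = 8


def psum_model(per_device):
    """All-reduce sum: every device ends up holding the total."""
    total = sum(per_device)
    return [total] * len(per_device)


def wasteful_prefix(ybar):
    """Model of `psum(ybar / 8, 'i')` applied to a device-invariant ybar."""
    replicated = [ybar] * AXIS_SIZE                  # in_specs=P(): same value on every device
    rescaled = [v / AXIS_SIZE for v in replicated]   # the division rescaling
    return psum_model(rescaled)                      # the defensive all-reduce


result = wasteful_prefix(3.0)
print(result)  # each device holds the ybar it started with
```

In this model every device finishes with the value it already had, so the all-reduce contributes communication but no new information.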

What are we doing wrong? We’re not exploiting the fact that cotangents ybar corresponding to f1’s unmapped outputs are guaranteed to be device-invariant; instead, we’re defensively psumming them as if they weren’t because psum’s transpose can’t be sure given the local information it has. Sometimes the psum is necessary, as in transposing f2 with respect to its first argument:

# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))

Intuitively, if our transpose machinery could tell the difference between Example 1 and Example 2, we could do better by avoiding the psum and division where possible.

The inefficient examples can be even smaller. Consider transposing this cursed identity function:

# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())

# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i'), P(), P())

It keeps getting bigger the more we transpose. How embarrassing!

And psum isn’t the only culprit. Something analogous holds true for all_gather:

# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))

This program is a bit artificial. Why do an all_gather and feed the result into an unmapped output, rather than skipping the all_gather in the body and just using out_specs=P('i') to collect the results? But even though it’s cooked-up, this example nevertheless exhibits a transpose which unnecessarily performs communication (we could have just performed a non-communicating slice), analogous to Example 1 for psum.
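To see that the communication in t(f4) is redundant, the following pure-Python sketch compares the psum_scatter of a device-invariant ybar / 8 against a plain per-device slice. It is a model rather than JAX code: psum_scatter_model and the list-of-lists representation (one list per device) are assumptions made for illustration.

```python
# Pure-Python model of one mesh axis of size 8 (hypothetical helpers, not JAX).
AXIS_SIZE = 8


def psum_scatter_model(per_device):
    """Reduce-scatter sum: sum elementwise across devices, then device d keeps chunk d."""
    n = len(per_device)
    total = [sum(column) for column in zip(*per_device)]
    chunk = len(total) // n
    return [total[d * chunk:(d + 1) * chunk] for d in range(n)]


ybar = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]  # device-invariant cotangent

# What the current transpose computes: psum_scatter(ybar / 8, 'i')
replicated = [[v / AXIS_SIZE for v in ybar] for _ in range(AXIS_SIZE)]
via_collective = psum_scatter_model(replicated)

# What a non-communicating slice would compute on device d
via_slice = [ybar[d:d + 1] for d in range(AXIS_SIZE)]

print(via_collective == via_slice)
```

Both routes give each device the same one-element chunk, which is why a local slice would suffice when ybar is known to be device-invariant.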

Also analogously to the psum examples, the defensive psum_scatter is necessary in some cases:

# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))

So how do we avoid these inefficient transposes?


Here are two solution ideas. They aren’t mutually exclusive. But (spoilers) the second one is better, and it’s all we need.

Partial solution “P-sum”: build the ability to express a psum into out_specs

This solution is a bit of a strawperson because it would offer only an awkward way to write programs. And it wouldn’t even fix everything! But it’s worth considering, if only to motivate a more complete solution.

Example 4 above is artificial because we could have just used out_specs instead of an all_gather in the body:

# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))

The f4_better version doesn’t have any transposition problems, since the transpose problems arise from collectives in the body.

Analogously, we could fix Example 1 by extending out_specs so that they can express summing:

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), out_specs=P(sum='i'))

So offering psums built into out_specs fixes the transpose problem of Example 1. But it doesn’t fully fix the cursed identity transpose in Example 3:

# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())

# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))

It’s an improvement since the program doesn’t continue to get bigger as we keep transposing, but we’re still doing wasteful communication.

Full solution: statically track device-varying vs device-invariant intermediates, plus new primitives

This solution has two components:

  1. track when values are guaranteed to be device-invariant vs device-varying over particular mesh axes, and

  2. decompose psum into a two-step process, introducing a new pbroadcast primitive, and introduce new primitives for all_gather and its transposes.

Morally, the tracking of device-invariant vs device-varying information is a type-level consideration. But for the expedience of our first implementation, we don’t need to literally add the information to abstract values or jaxpr types. Before we get to implementation, we’ll first introduce the idea using types.

Also to follow is a discussion of making the user API convenient and backward compatible. But to first introduce the idea, we’ll ignore convenience and instead write code that is as explicit as possible.

Tracking device invariance in avals (a.k.a. avals-with-names, revived)

We can sometimes tell from static information alone that the values of some intermediate variables in the body of a shmap are guaranteed to be invariant along a mesh axis, in the sense that the function instances (and their corresponding devices) along the mesh axis must all be computing with the same value. We’ll call such values device-invariant. For values that are not device-invariant, we’ll say they’re device-varying, though really we mean potentially device-varying from the point of view of the type system.

To encode device variance in types, we’ll extend the syntax of types for arrays. We’ll write things like x:f32[3,4]{i} to indicate that x is (potentially) device-varying along mesh axis i (and device-invariant over any other mesh axes of the shmap). More generally, we’ll say the grammar for array type syntax is something like

shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}

We’ll also update the typing rules to handle device variance types:

  • for first-order primitives other than collectives

    • for multi-arity primitives, the operand device variance types must be equal where shapes must be equal, e.g. mul x:f32[s1]{r1} y:f32[s2]{r2} requires r1 == r2 in addition to s1 == s2

    • the output device variance type must be the same as the operand(s)

  • for higher-order primitives

    • we just instantiate any type variables including the device variance type (and checking types for equality checks their device variance types are equal)

    • (when performing type inference, e.g. for branches of a cond, we take the union of the sets of axis names in device variance types)

  • for first-order collectives

    • a collective can accept either a device-varying or a device-invariant input (along a mesh axis corresponding to its axis name parameter); it’s an error to pass a device-invariant operand to a collective which accepts device-varying operands, and vice versa

    • a collective can produce either a device-varying or a device-invariant output

    • see the table below

As a side benefit, whatever logic implements this type checking can subsume shmap’s “static analysis” check for whether a shmap body function is compatible with any unmapped out_specs.

Here’s a table summarizing the device variance typing for collective primitives:


| Device variance type | Example | Lowers to HLO |
| --- | --- | --- |
| Varying -> Invariant | y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i') | AllReduceSum (communication) |
| Invariant -> Varying | y:f32[3]{i} = pbroadcast(x:f32[3], 'i') | no-op (no communication) |
| Varying -> Varying | y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) | AllToAll (communication) |
| () -> Varying | idx:i32[]{i} = axis_index('i') | ReplicaId and some arithmetic (no communication) |
| Varying -> Varying | y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i') | ReduceScatterSum (communication) |
| Varying -> Varying | y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i') | AllGather (communication) |
| Invariant -> Varying | y:f32[2]{i} = pscatter(x:f32[16], 'i') | lambda x: x[axis_index('i'), None] (no communication) |
| Varying -> Invariant | y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i') | AllGather (communication) |

There are some surprising things here!

  • We introduced several new primitives, including

    • pbroadcast, which interestingly lowers to a no-op

    • all_gather_invariant, which lowers to the same thing as all_gather but has a different device variance type (essentially all_gather has a pbroadcast fused into it, whereas all_gather_invariant does not)

    • pscatter which is the dual (transpose) of all_gather_invariant

  • all_gather has a device-varying result

Intuitively, the reason to introduce pbroadcast (other than to make the typing rules work) is so that psum can transpose to a physical no-op. The reason we need all_gather to have a device-varying result is so that we can transpose it to psum_scatter; if we instead left it with a device-invariant result, we might need a downstream pbroadcast, and that composition would transpose to an inefficient psum followed by slicing / pscatter. So instead we have a pbroadcast “fused into” the all_gather, thus allowing for an efficient transpose into psum_scatter. We provide all_gather_invariant and its transpose pscatter mainly for completeness; it’s unlikely users will need it (it corresponds to the situation in Example 4, which is easy to write differently using out_specs).

Interestingly, the psum and pbroadcast transpose pair correspond to the psum_idrev and id_psumrev that users introduced while training LLMs with pmap.
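The typing rules for collectives can be sketched as a small checker. This is an illustrative model under stated assumptions, not the shmap implementation: device variance types are represented as frozensets of mesh axis names, and COLLECTIVE_VARIANCE, collective_result_type and elementwise_result_type are hypothetical names. The entries follow the table of collectives above; axis_index is left out because it takes no array operand.

```python
# Device variance types are modeled as frozensets of mesh axis names.
# Each entry: (operand must be varying along `axis`, result is varying along `axis`).
COLLECTIVE_VARIANCE = {
    "psum": (True, False),
    "pbroadcast": (False, True),
    "all_to_all": (True, True),
    "psum_scatter": (True, True),
    "all_gather": (True, True),
    "pscatter": (False, True),
    "all_gather_invariant": (True, False),
}


def collective_result_type(name, operand_type, axis):
    """Return the result's device variance type, or raise TypeError on a mismatch."""
    needs_varying, gives_varying = COLLECTIVE_VARIANCE[name]
    if (axis in operand_type) != needs_varying:
        expected = "device-varying" if needs_varying else "device-invariant"
        raise TypeError(f"{name} expects a {expected} operand along {axis!r}")
    other_axes = operand_type - {axis}
    return (other_axes | {axis}) if gives_varying else other_axes


def elementwise_result_type(lhs_type, rhs_type):
    """Non-collective binary primitives require equal device variance types."""
    if lhs_type != rhs_type:
        raise TypeError(f"device variance mismatch: {set(lhs_type)} vs {set(rhs_type)}")
    return lhs_type


# x:f32[3]{i,j} -> y:f32[3]{j}
psum_type = collective_result_type("psum", frozenset({"i", "j"}), "i")
# pbroadcast lifts a device-invariant value to a device-varying one
pbroadcast_type = collective_result_type("pbroadcast", frozenset(), "i")

# psum of an operand that is already invariant along 'i' is a type error
try:
    collective_result_type("psum", frozenset(), "i")
    psum_of_invariant_rejected = False
except TypeError:
    psum_of_invariant_rejected = True
```

Keeping the rules in one table makes the pairing visible: each Varying -> Invariant entry has an Invariant -> Varying partner.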

How this system solves the inefficient transpose examples

Consider again the simplified motivating example:

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y

With these new rules, the transpose is:

# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar

where evaluating the pbroadcast application involves no communication or FLOPs at all; it’s a no-op. Notice that if we keep transposing the body does not grow in size; indeed t(t(f1)) == f1. Efficiency achieved!
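One way to see why pbroadcast is a sensible transpose partner for psum is to check the defining identity of a transpose, <psum(x), y> = <x, pbroadcast(y)>, in a toy model. The sketch below is pure Python with hypothetical helper names. It assumes a device-varying value is a list with one entry per device, and that a device-invariant value counts once in the inner product, in line with the appendix's statement that only one copy of an unmapped output is used.

```python
def psum_model(per_device):
    """Varying -> Invariant: one shared total for the whole axis."""
    return sum(per_device)


def pbroadcast_model(value, axis_size):
    """Invariant -> Varying: every device gets a copy; no arithmetic is involved."""
    return [value] * axis_size


def dot_varying(a, b):
    """Inner product of two device-varying scalars: sum over devices."""
    return sum(p * q for p, q in zip(a, b))


x = [1, 2, 3, 4, 5, 6, 7, 8]  # device-varying, one entry per device
y = 3                         # device-invariant

lhs = psum_model(x) * y                             # <psum(x), y>
rhs = dot_varying(x, pbroadcast_model(y, len(x)))   # <x, pbroadcast(y)>
print(lhs, rhs)
```

The two sides agree for these inputs, and the same algebra goes through for any x and y since both sides equal y times the sum of the entries of x.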

And we wouldn’t mess up the other examples either, so long as we pbroadcast to make the types check where needed:

# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())

Intuitively, in Example 1 we now only have “half the original psum”, whereas in Example 2 we get both “halves”. For Example 3 we never need any operations in the body at all.

For the all_gather examples, Example 4 would need to use all_gather_invariant to have an efficient transpose (though it’d be better to use out_specs instead of the collective in the body):

# Example 4 rewritten with explicit all_gather_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar

For Example 5, using the device-varying all_gather works as we’d want:

# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar

How to make the API convenient for users (and backward compatible)

But what user wants to write pbroadcasts? And what developer wants to break lots of existing user code involving psums which are not fed into unmapped outputs? Not me!

Instead we can automatically insert the pbroadcasts. It’s a bit analogous to how we do automatic rank promotion at the jax.numpy layer, inserting broadcasts to avoid rank mismatch errors in binary operators. But it’s much simpler since we don’t need to contend with shape tuples. The typical rule is: whenever we see a multi-arity operation where the operands disagree in their device variance types, take the union of operands’ device variance types’ axis name sets and insert pbroadcasts to lift each operand to the resulting device variance type.

Automatically inserting pbroadcasts just before they’re needed may mean we apply the same pbroadcast to the same operand multiple times, creating common subexpressions. When we transpose, those could turn into a sum-of-psums rather than a psum-of-sum. We’ll rely on the compiler to clean that up as appropriate. If it’s a problem then we could add some simple memoization to the pbroadcast-insertion pass.
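The union rule for inserting pbroadcasts can be sketched in a few lines. This is a model under assumptions, not the shmap pass itself: device variance types are frozensets of axis names, and plan_pbroadcasts is a hypothetical helper that reports, for each operand, which axes it would need to be pbroadcast over.

```python
def plan_pbroadcasts(operand_types):
    """Return the common result type and, per operand, the axes to pbroadcast over."""
    target = frozenset().union(*operand_types)
    missing = [sorted(target - t) for t in operand_types]
    return target, missing


# Example 2: psum(g(x), 'i') has type {} and y has type {i}, so only the
# first operand of the multiplication needs a pbroadcast over 'i'.
target, missing = plan_pbroadcasts([frozenset(), frozenset({"i"})])
print(sorted(target), missing)

# Operands varying over different axes are both lifted to the union.
target2, missing2 = plan_pbroadcasts([frozenset({"i"}), frozenset({"j"})])
print(sorted(target2), missing2)
```

A memoized variant would cache the lifted operand per (operand, axes) pair so that repeated uses share one pbroadcast.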

The user API for all_gather will mean all_gather_p by default (not all_gather_invariant_p), covering the common case and meaning no pbroadcasts must be inserted.

We can provide an option on shmap to disable this automatic insertion of pbroadcasts, in which case it’ll be up to the user to ensure type-correctness. This explicit option may be appealing to some who want to be explicit about where the psums occur in the backward pass.

How to implement the solution

The key to making the implementation lightweight is that we aren’t going to add these types to avals or jaxprs. At least, not at first. That can be expensive because it requires updating the rest of JAX, e.g. all consumers of avals and jaxprs may need to handle the new types. We’re not falling for that again!

Instead we’re going to keep these extended types as metadata internal to shmap, just like the current “replication checking for out_specs” machinery is internal to shmap. Indeed this solution amounts to a relatively small extension to that existing machinery: it was already tracking the same information; now we’re just adding the pbroadcasts.

We have at least two options for where to perform the pbroadcast insertion:

  1. just before transposition, in the transpose rule, where we have a jaxpr of the computation to be transposed;

  2. in every shmap body, whether eagerly executed or staged out, like the current “replication checking for out_specs” machinery.

The former may end up being easier since we only have to handle the jaxpr case, and only linear primitives. But we’ll start by trying the latter so the implementation here is a strict revision/extension to the existing replication-checking logic.

Appendix: defining and motivating maps with unmapped inputs and outputs

For concreteness, we’ll mostly focus on shmap, though these same ideas apply to e.g. pmap and probably xmap.

An argument/input is unmapped along a mesh axis when the corresponding entry of in_specs doesn’t mention that mesh axis’s name. Logically it means that each function instance along that mesh axis gets the same value for the argument. To the caller, each operand is sliced according to the mesh axes over which the operand is mapped, whereas there is no slicing for mesh axes over which the operand is unmapped.

An output is unmapped along a mesh axis when the corresponding entry of out_specs doesn’t mention that mesh axis’s name. Logically it means each function instance along that mesh axis must return the same value. To the caller, each result of the shmap is formed by concatenating the return values of the function instances along the mesh axes over which the output is mapped, whereas for mesh axes over which the output is unmapped only one copy of the value is used.
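These definitions can be illustrated with a pure-Python model of a one-argument map over a single mesh axis. It is a sketch, not shmap: shmap_model and its boolean flags are hypothetical, arrays are plain lists, and the agreement of unmapped outputs is checked at run time here only for illustration (the doc describes a static check).

```python
def shmap_model(body, xs, in_mapped, out_mapped, axis_size):
    """Model a one-argument map over a single mesh axis using Python lists."""
    if in_mapped:
        # Mapped input: each function instance gets its own slice.
        chunk = len(xs) // axis_size
        per_device_in = [xs[d * chunk:(d + 1) * chunk] for d in range(axis_size)]
    else:
        # Unmapped input: every function instance gets the same value.
        per_device_in = [xs] * axis_size

    per_device_out = [body(x) for x in per_device_in]

    if out_mapped:
        # Mapped output: concatenate the per-instance results.
        return [v for out in per_device_out for v in out]
    # Unmapped output: all instances must agree, and one copy is used.
    if any(out != per_device_out[0] for out in per_device_out):
        raise ValueError("unmapped output differs across function instances")
    return per_device_out[0]


doubled = shmap_model(lambda x: [2 * v for v in x], [1, 2, 3, 4],
                      in_mapped=True, out_mapped=True, axis_size=2)
total = shmap_model(lambda x: [sum(x)], [1, 2, 3, 4],
                    in_mapped=False, out_mapped=False, axis_size=2)
print(doubled, total)
```

In the second call the body never sees a slice, so every instance computes the same sum and the caller receives a single copy of it.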

See the shmap JEP for examples of unmapped inputs and outputs. For comparison, in vmap unmapped inputs/outputs are indicated by using in_axes / out_axes of None (rather than an int).

Here are reasons we like unmapped inputs and outputs for shmap:

  • Same expressiveness as pjit. Anything pjit can do, the shmap escape hatch should be able to do too. Or else we’d have a lacking escape hatch! If we didn’t have unmapped outputs in shmap then we couldn’t express the same batch-parallel loss function computations as pjit.

  • Closed-over inputs. Closed-over inputs essentially correspond to unmapped inputs, and…

  • Closure under transposition. Once we have unmapped inputs, it’s natural to be able to transpose to unmapped outputs.

So unmapped outputs are both canonical and useful!