Efficient transposition of replication-inducing collectives#
mattjj@, dougalm@
August 2023
Motivation#
We have an efficiency problem in automatically transposing shmaps containing
certain collectives. The issue arises with psum and all_gather, specifically
when the output of the collective is returned to the caller as an unmapped
output. And it's not an edge case: for example, it arises when applying grad
to a shmap-based batch data parallel neural network loss function which uses
psum to compute the total loss.
We've known about this problem for some time. An analogous issue exists with
pmap, though it's been worked around by keeping grad inside pmap rather than
outside. A primary goal of the incomplete avals-with-names work was to address
a version of this transpose efficiency problem. This doc draws on those ideas,
while extending and revising them to handle more cases and to be much easier to
land. Indeed the solution proposed here only affects the shmap implementation.
The rest of the system need not be changed (yet).
The main purpose of this doc is to define this transpose efficiency problem and propose an easy-to-land solution.
This doc is not about:

- logical axis names on arrays (the only axis names here are just like in shmap and OG pmap);
- changing autodiff semantics (all the numbers and (non)errors are staying the same, we're just making things more efficient);
- allowing user code to reflect on any new information, or really affecting user code at all.
Problem: efficient transpose of psum or all_gather depends on whether cotangents are invariant across devices#
Consider this semi-realistic example, meant to resemble a replicated-parameter batch data parallel loss function:
devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch')
  return global_loss
Notice the out_specs=P(), which indicates an unmapped output. If you're not
familiar with the notion of unmapped outputs, see the appendix at the bottom of
this document.
Most of the details in the loss example aren't important. All that matters for
our purposes is that we're applying psum (or rather
pmean = lambda x, name: psum(x, name) / psum(1, name)) at the end. So a
distilled version looks like this:
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())
We even simplified notation by suppressing the mesh
argument. In the examples to
follow it can be inferred from context.
What does the transpose look like? Writing t to mean function transpose, we
could evaluate t(f1)(ybar) for any ybar efficiently by applying the function
¿f1_transpose? below:
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
But that's not the transpose we currently get as t(f1).
Instead, the current recipe for transposition is roughly that we switch
in_specs and out_specs, do some division rescaling for unmapped outputs, and
transpose the body. Because psum is its own transpose (as an all-reduce sum),
we end up producing this transpose:
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))
This transpose gets the numbers right, but it's wasteful. We know statically
from the transpose's in_specs=P() that ybar has the same value for each
function instance, i.e. that its value is device-invariant for devices along
the mesh axis named i, and yet we apply a psum to it! That uses expensive
communication just to multiply the value on each device by 8. (Here 8 refers to
the size of axis i. The division by 8 comes from the original function's
out_specs=P(); it and the trivial psum basically cancel each other out.)
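To make the waste concrete, we can simulate the per-device arithmetic with plain Python lists. This is a toy sketch: `psum_sim` is a hypothetical stand-in for the real collective, modeling an 8-device mesh axis as one list entry per device.

```python
# A toy model of an 8-device mesh axis: one list entry per device.
# `psum_sim` is a hypothetical stand-in for the real collective,
# used only to illustrate the arithmetic.

AXIS_SIZE = 8

def psum_sim(vals):
    # All-reduce sum: every device ends up with the sum over the axis.
    total = sum(vals)
    return [total] * AXIS_SIZE

# A device-invariant cotangent: every device already holds the same value.
ybar = [3.0] * AXIS_SIZE

# The current transpose recipe: divide by the axis size, then psum.
rescaled = [v / AXIS_SIZE for v in ybar]
result = psum_sim(rescaled)

# The psum multiplied each (identical) entry by 8 and the division by 8
# cancelled it: expensive communication just to compute the identity.
assert result == ybar
```

In other words, on device-invariant input the psum-then-rescale pair is a round trip through the network that computes nothing.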
What are we doing wrong? We're not exploiting the fact that the cotangents
ybar corresponding to f1's unmapped outputs are guaranteed to be
device-invariant; instead, we're defensively psumming them as if they weren't,
because psum's transpose can't be sure given the local information it has.
Sometimes the psum is necessary, as in transposing f2 with respect to its
first argument:
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))
Intuitively, if our transpose machinery could tell the difference between Example 1 and Example 2, we could do better by avoiding the psum and division where possible.
The inefficient examples can be even smaller. Consider transposing this cursed identity function:
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())

# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i'), P(), P())
...
It keeps getting bigger the more we transpose. How embarrassing!
And psum isn't the only culprit. Something analogous holds true for
all_gather:
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
This program is a bit artificial. Why do an all_gather and feed the result
into an unmapped output, rather than skipping the all_gather in the body and
just using out_specs=P('i') to collect the results? But even though it's
cooked-up, this example nevertheless exhibits a transpose which unnecessarily
performs communication (we could have just performed a non-communicating
slice), analogous to Example 1 for psum.
Also analogously to the psum examples, the defensive psum_scatter is
necessary in some cases:
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))
So how do we avoid these inefficient transposes?
Solutions#
Here are two solution ideas. They aren't mutually exclusive. But (spoilers) the second one is better, and it's all we need.
Partial solution "Psum": build the ability to express a psum into out_specs#
This solution is a bit of a strawperson because it would offer only an awkward way to write programs. And it wouldn't even fix everything! But it's worth considering, if only to motivate a more complete solution.
Example 4 above is artificial because we could have just used out_specs
instead
of an all_gather
in the body:
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
The f4_better version doesn't have any transposition problems, since the
transpose problems arise from collectives in the body.
Analogously, we could fix Example 1 by extending out_specs
so that they can
express summing:
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), out_specs=P(sum='i'))
So offering psums built into out_specs fixes the transpose problem of
Example 1. But it doesn't fully fix the cursed identity transpose in Example 3:
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# How it would transpose with the Psum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
It's an improvement since the program doesn't continue to get bigger as we keep transposing, but we're still doing wasteful communication.
Full solution: statically track device-varying vs device-invariant intermediates, plus new primitives#
This solution has two components:

- track when values are guaranteed to be device-invariant vs device-varying over particular mesh axes, and
- decompose psum into a two-step process, introducing a new pbroadcast primitive, and introduce new primitives for all_gather and its transposes.
Morally, the tracking of device-invariant vs device-varying information is a type-level consideration. But for the expedience of our first implementation, we don't need to literally add the information to abstract values or jaxpr types. Before we get to implementation, we'll first introduce the idea using types.
Also to follow is a discussion of making the user API convenient and backward compatible. But to first introduce the idea, we'll ignore convenience and instead write code that is as explicit as possible.
Tracking device invariance in avals (a.k.a. avals-with-names, revived)#
We can sometimes tell from static information alone that the values of some
intermediate variables in the body of a shmap are guaranteed to be invariant
along a mesh axis, in the sense that the function instances (and their
corresponding devices) along the mesh axis must all be computing with the same
value. We'll call such values device-invariant. For values that are not
device-invariant, we'll say they're device-varying, though really we mean
potentially device-varying from the point of view of the type system.
To encode device variance in types, we'll extend the syntax of types for
arrays. We'll write things like x:f32[3,4]{i} to indicate that x is
(potentially) device-varying along mesh axis i (and device-invariant over any
other mesh axes of the shmap). More generally, we'll say the grammar for array
type syntax is something like
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
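As a sketch of what this grammar could look like as data, here is a minimal Python representation. VariantArray is a hypothetical illustration, not JAX's actual ShapedArray class.

```python
# A sketch of an array type extended with a device variance component.
# `VariantArray` is a hypothetical stand-in, not JAX's real aval class.
from dataclasses import dataclass

@dataclass(frozen=True)
class VariantArray:
    dtype: str             # e.g. 'f32'
    shape: tuple           # e.g. (3, 4)
    varying: frozenset     # mesh axis names along which the value may vary

    def __str__(self):
        axes = ','.join(sorted(self.varying))
        dims = ','.join(map(str, self.shape))
        return f"{self.dtype}[{dims}]{{{axes}}}"

# x:f32[3,4]{i} -- potentially device-varying along mesh axis 'i',
# device-invariant along every other mesh axis.
x = VariantArray('f32', (3, 4), frozenset({'i'}))
assert str(x) == 'f32[3,4]{i}'
```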
We'll also update the typing rules to handle device variance types:

- for first-order primitives other than collectives
  - for multi-arity primitives, the operand device variance types must be equal where shapes must be equal, e.g. mul x:f32[s1]{r1} y:f32[s2]{r2} requires r1 == r2 in addition to s1 == s2
  - the output device variance type must be the same as the operand(s)
- for higher-order primitives
  - we just instantiate any type variables, including the device variance type (and checking types for equality checks their device variance types are equal)
  - (when performing type inference, e.g. for branches of a cond, we take the union of the sets of axis names in device variance types)
- for first-order collectives
  - a collective can either accept a device-varying or device-invariant input (along a mesh axis corresponding to its axis name parameter); it's an error to pass a device-invariant operand to a collective which accepts device-varying operands and vice-versa
  - a collective can either produce a device-varying or device-invariant output
  - see the table below

As a side benefit, whatever logic implements this type checking can subsume
shmap's "static analysis" check for whether a shmap body function is
compatible with any unmapped out_specs.
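The multi-arity equality rule and the union rule for type inference can be sketched as small helper functions. `check_multiarity` and `join_branches` are hypothetical names used only for illustration, with device variance types modeled as frozensets of axis names.

```python
# Sketches of two typing rules, with device variance types modeled as
# frozensets of mesh axis names. Hypothetical helper names for illustration.

def check_multiarity(*variance_types):
    # For multi-arity primitives like `mul`, operand device variance types
    # must be equal (r1 == r2), and the output inherits that same type.
    first = variance_types[0]
    for r in variance_types[1:]:
        if r != first:
            raise TypeError(f"device variance mismatch: {r} != {first}")
    return first

def join_branches(*variance_types):
    # For type inference, e.g. across `cond` branches, take the union of
    # the axis-name sets.
    out = frozenset()
    for r in variance_types:
        out |= r
    return out

assert check_multiarity(frozenset({'i'}), frozenset({'i'})) == frozenset({'i'})
assert join_branches(frozenset({'i'}), frozenset({'j'})) == frozenset({'i', 'j'})
```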
Here's a table summarizing the device variance typing for collective primitives:

| Name | Device variance type | Example | Lowers to HLO | Transpose |
| --- | --- | --- | --- | --- |
| psum2 | `Varies -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | pbroadcast |
| pbroadcast | `Invariant -> Varies` | `y:f32[3]{i} = pbroadcast(x:f32[3], 'i')` | no-op (no communication) | psum |
| all_to_all | `Varies -> Varies` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` | `AllToAll` (communication) | all_to_all |
| axis_index | `() -> Varies` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a |
| psum_scatter | `Varies -> Varies` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | all_gather |
| all_gather | `Varies -> Varies` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | psum_scatter |
| pscatter | `Invariant -> Varies` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | slice on `axis_index` (no communication) | all_gather_invariant |
| all_gather_invariant | `Varies -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | pscatter |
There are some surprising things here!
We introduced several new primitives, including:

- pbroadcast, which interestingly lowers to a no-op
- all_gather_invariant, which lowers to the same thing as all_gather but has a different device variance type (essentially all_gather has a pbroadcast fused into it, whereas all_gather_invariant does not)
- pscatter, which is the dual (transpose) of all_gather_invariant

And all_gather has a device-varying result.
Intuitively, the reason to introduce pbroadcast (other than to make the typing
rules work) is so that psum can transpose to a physical no-op. The reason we
need all_gather to have a device-varying result is so that we can transpose it
to psum_scatter; if we instead left it with a device-invariant result, we
might need a downstream pbroadcast, and that composition would transpose to an
inefficient psum followed by slicing / pscatter. So instead we have a
pbroadcast "fused into" the all_gather, thus allowing for an efficient
transpose into psum_scatter. We provide all_gather_invariant and its
transpose pscatter mainly for completeness; it's unlikely users will need it
(it corresponds to the situation in Example 4, which is easy to write
differently using out_specs).
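The transpose pairs described above can be collected into a simple mapping. This is just a summary sketch of the relationships named in this doc; note each pair is mutually inverse, which is what makes repeated transposition stable.

```python
# The transpose relationships between the collectives, as a plain mapping.
# A summary sketch of the pairs described in this doc.
TRANSPOSE = {
    'psum':                 'pbroadcast',
    'pbroadcast':           'psum',
    'all_gather':           'psum_scatter',
    'psum_scatter':         'all_gather',
    'all_gather_invariant': 'pscatter',
    'pscatter':             'all_gather_invariant',
}

# Transposing twice gets back the original primitive, so transposed
# programs don't keep growing.
for prim, t_prim in TRANSPOSE.items():
    assert TRANSPOSE[t_prim] == prim
```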
Interestingly, the psum and pbroadcast transpose pair corresponds to the
psum_idrev and id_psumrev that users introduced while training LLMs with
pmap.
How this system solves the inefficient transpose examples#
Consider again the simplified motivating example:
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y
With these new rules, the transpose is:
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar
where evaluating the pbroadcast application involves no communication or FLOPs
at all; it's a no-op. Notice that if we keep transposing, the body does not
grow in size; indeed t(t(f1)) == f1. Efficiency achieved!
And we wouldn't mess up the other examples either, so long as we pbroadcast to
make the types check where needed:
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device-varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
Intuitively, in Example 1 we now only have "half the original psum", whereas in Example 2 we get both "halves". For Example 3 we never need any operations in the body at all.
For the all_gather examples, Example 4 would need to use
all_gather_invariant to have an efficient transpose (though it'd be better to
use out_specs instead of the collective in the body):
# Example 4 rewritten with explicit all_gather_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar
For Example 5, using the device-varying all_gather works as we'd want:
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar
How to make the API convenient for users (and backward compatible)#
But what user wants to write pbroadcasts? And what developer wants to break
lots of existing user code involving psums which are not fed into unmapped
outputs? Not me!
Instead we can automatically insert the pbroadcasts. It's a bit analogous to
how we do automatic rank promotion at the jax.numpy layer, inserting broadcasts
to avoid rank mismatch errors in binary operators. But it's much simpler since
we don't need to contend with shape tuples. The typical rule is: whenever we
see a multi-arity operation where the operands disagree in their device
variance types, take the union of operands' device variance types' axis name
sets and insert pbroadcasts to lift each operand to the resulting device
variance type.
Automatically inserting pbroadcasts just before they're needed may mean we
apply the same pbroadcast to the same operand multiple times, creating common
subexpressions. When we transpose, those could turn into a sum-of-psums rather
than a psum-of-sums. We'll rely on the compiler to clean that up as
appropriate. If it's a problem then we could add some simple memoization to
the pbroadcast insertion pass.
The user API for all_gather will mean all_gather_p by default (not
all_gather_invariant_p), covering the common case and meaning no pbroadcasts
must be inserted.
We can provide an option on shmap to disable this automatic insertion of
pbroadcasts, in which case it'll be up to the user to ensure type-correctness.
This explicit option may be appealing to some who want to be explicit about
where the psums occur in the backward pass.
How to implement the solution#
The key to making the implementation lightweight is that we aren't going to add these types to avals or jaxprs. At least, not at first. That can be expensive because it requires updating the rest of JAX, e.g. all consumers of avals and jaxprs may need to handle the new types. We're not falling for that again!
Instead we're going to keep these extended types as metadata internal to
shmap, just like the current "replication checking for out_specs" machinery
is internal to shmap. Indeed this solution amounts to a relatively small
extension to that existing machinery: it was already tracking the same
information; now we're just adding the pbroadcasts.
We have at least two options for where to perform the pbroadcast insertion:

1. just before transposition, in the transpose rule, where we have a jaxpr of the computation to be transposed;
2. in every shmap body, whether eagerly executed or staged out, like the current "replication checking for out_specs" machinery.

The former may end up being easier since we only have to handle the jaxpr case,
and only linear primitives. But we'll start by trying the latter so the
implementation here is a strict revision/extension to the existing
replication-checking logic.
Appendix: defining and motivating maps with unmapped inputs and outputs#
For concreteness, we'll mostly focus on shmap, though these same ideas apply
to e.g. pmap and probably xmap.
An argument/input is unmapped along a mesh axis when the corresponding entry
of in_specs doesn't mention that mesh axis's name. Logically it means that
each function instance along that mesh axis gets the same value for the
argument. To the caller, each operand is sliced according to the mesh axes over
which the operand is mapped, whereas there is no slicing for mesh axes over
which the operand is unmapped.
An output is unmapped along a mesh axis when the corresponding entry of
out_specs doesn't mention that mesh axis's name. Logically it means each
function instance along that mesh axis must return the same value. To the
caller, each result of the shmap is formed by concatenating the return values
of every function instance along which the outputs are mapped, whereas for mesh
axes over which the output is unmapped only one copy of the value is used.
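These caller-side semantics can be modeled with plain Python lists. This is a toy sketch with hypothetical helper names, assuming a 4-device mesh axis; each list entry stands for one function instance's return value.

```python
# A toy model of the caller's view of mapped vs unmapped outputs along one
# mesh axis with 4 devices. Hypothetical helper names, for illustration only.

def caller_view_mapped(per_device_results):
    # Mapped output: concatenate the per-instance return values along the axis.
    out = []
    for r in per_device_results:
        out.extend(r)
    return out

def caller_view_unmapped(per_device_results):
    # Unmapped output: every instance must return the same value;
    # only one copy is used.
    first = per_device_results[0]
    assert all(r == first for r in per_device_results)
    return first

assert caller_view_mapped([[0], [1], [2], [3]]) == [0, 1, 2, 3]
assert caller_view_unmapped([7.0, 7.0, 7.0, 7.0]) == 7.0
```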
See the shmap JEP for examples of unmapped inputs and outputs. For comparison,
in vmap unmapped inputs/outputs are indicated by using in_axes/out_axes of
None (rather than an int).
Here are reasons we like unmapped inputs and outputs for shmap:

- Same expressiveness as pjit. Anything pjit can do, the shmap escape hatch should be able to do too. Or else we'd have a lacking escape hatch! If we didn't have unmapped outputs in shmap then we couldn't express the same batch-parallel loss function computations as pjit.
- Closed-over inputs. Closed-over inputs essentially correspond to unmapped inputs, and...
- Closure under transposition. Once we have unmapped inputs, it's natural to be able to transpose to unmapped outputs.

So unmapped outputs are both canonical and useful!