Introduction to parallel programming
This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.
The tutorial covers three modes of parallel computation:
1. Automatic parallelism via jax.jit(): The compiler chooses the optimal computation strategy (a.k.a. “the compiler takes the wheel”).
2. Semi-automated parallelism using jax.jit() and jax.lax.with_sharding_constraint().
3. Fully manual parallelism using jax.experimental.shard_map.shard_map(): shard_map enables per-device code and explicit communication collectives.
Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.
If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: Runtime > Change runtime type > Hardware accelerator > TPU v2 (which provides eight devices to work with).
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Key concept: Data sharding
Key to all of the distributed computation approaches below is the concept of data sharding, which describes how data is laid out on the available devices.
How can JAX understand how the data is laid out across devices? JAX’s datatype, the jax.Array immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The jax.Array object is designed with distributed data and computation in mind. Every jax.Array has an associated jax.sharding.Sharding object, which describes which shard of the global data is required by each global device. When you create a jax.Array from scratch, you also need to create its Sharding.
In the simplest cases, arrays are sharded on a single device, as demonstrated below:
import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices()
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}
arr.sharding
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0))
For a more visual representation of the storage layout, the jax.debug module provides some helpers to visualize the sharding of an array. For example, jax.debug.visualize_array_sharding() displays how the array is stored in the memory of a single device:
jax.debug.visualize_array_sharding(arr)
TPU 0
To create an array with a non-trivial sharding, you can define a jax.sharding specification for the array and pass this to jax.device_put().
Here, define a NamedSharding, which specifies an N-dimensional grid of devices with named axes, where jax.sharding.Mesh allows for precise device placement:
from jax.sharding import PartitionSpec as P
mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'))
Passing this Sharding object to jax.device_put(), you can obtain a sharded array:
arr_sharded = jax.device_put(arr, sharding)
print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
[[ 0. 1. 2. 3. 4. 5. 6. 7.]
[ 8. 9. 10. 11. 12. 13. 14. 15.]
[16. 17. 18. 19. 20. 21. 22. 23.]
[24. 25. 26. 27. 28. 29. 30. 31.]]
TPU 0  TPU 1  TPU 2  TPU 3
TPU 6  TPU 7  TPU 4  TPU 5
The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the devices.
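Each entry in the PartitionSpec names the mesh axis that the corresponding array dimension is split over, and None leaves that dimension unsharded. As a small sketch (not part of the original example), P('x', None) splits only the rows across the 'x' axis of the mesh and replicates the columns:
# Shard the rows over the 'x' mesh axis and replicate the columns
# (illustrative sketch; reuses the 2x4 mesh and `arr` defined above).
row_sharding = jax.sharding.NamedSharding(mesh, P('x', None))
jax.debug.visualize_array_sharding(jax.device_put(arr, row_sharding))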
1. Automatic parallelism via jit
Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a jax.jit()-compiled function! In JAX, you only need to specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.
The XLA compiler behind jit includes heuristics for optimizing computations across multiple devices. In the simplest of cases, those heuristics boil down to computation follows data.
To demonstrate how auto-parallelization works in JAX, below is an example that uses a jax.jit()-decorated staged-out function: it’s a simple element-wise function, where the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:
@jax.jit
def f_elementwise(x):
    return 2 * jnp.sin(x) + 1
result = f_elementwise(arr_sharded)
print("shardings match:", result.sharding == arr_sharded.sharding)
shardings match: True
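As a quick check (not in the original example), you can also visualize the result to confirm that each output shard stayed on the device that held the corresponding input shard:
jax.debug.visualize_array_sharding(result)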
As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.
Here, you sum along the leading axis of x, and visualize how the result values are stored across multiple devices (with jax.debug.visualize_array_sharding()):
@jax.jit
def f_contract(x):
    return x.sum(axis=0)
result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
TPU 0,6 TPU 1,7 TPU 2,4 TPU 3,5
[48. 52. 56. 60. 64. 68. 72. 76.]
The result is partially replicated: that is, the first two elements of the array are replicated on devices 0 and 6, the second two on 1 and 7, and so on.
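You can confirm this replication programmatically by listing which slice of the result each device holds (a small sketch using jax.Array.addressable_shards, not part of the original example); devices that share the same index hold the same data:
# Each shard reports the device it lives on and the slice of the global
# array it covers; repeated indices indicate replication.
for shard in result.addressable_shards:
    print(shard.device, shard.index)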
2. Semi-automated sharding with constraints
If you’d like to have some control over the sharding used within a particular computation, JAX offers the with_sharding_constraint() function. You can use jax.lax.with_sharding_constraint() (in place of jax.device_put()) together with jax.jit() for more control over how the compiler constrains the distribution of intermediate values and outputs.
For example, suppose that within f_contract above, you’d prefer the output not to be partially replicated, but rather to be fully sharded across the eight devices:
@jax.jit
def f_contract_2(x):
    out = x.sum(axis=0)
    mesh = jax.make_mesh((8,), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, P('x'))
    return jax.lax.with_sharding_constraint(out, sharding)
result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
TPU 0 TPU 1 TPU 2 TPU 3 TPU 6 TPU 7 TPU 4 TPU 5
[48. 52. 56. 60. 64. 68. 72. 76.]
This gives you a function with the particular output sharding you’d like.
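If you want to double-check it, printing the result’s sharding should show the array spread over all eight devices along the 'x' axis (a quick sanity check, not in the original example):
print(result.sharding)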
3. Manual parallelism with shard_map
In the automatic parallelism methods explored above, you can write a function as if you’re operating on the full dataset, and jit will split that computation across multiple devices. By contrast, with jax.experimental.shard_map.shard_map() you write the function that will handle a single shard of data, and shard_map will construct the full function.
shard_map works by mapping a function across a particular mesh of devices (shard_map maps over shards). In the example below:
- As before, jax.sharding.Mesh allows for precise device placement, with the axis names parameter for logical and physical axis names.
- The in_specs argument determines the shard sizes.
- The out_specs argument identifies how the blocks are assembled back together.
Note: jax.experimental.shard_map.shard_map() code can work inside jax.jit() if you need it.
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh((8,), ('x',))
f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))
arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1. , 2.682942 , 2.818595 , 1.28224 , -0.513605 ,
-0.9178486 , 0.44116896, 2.3139732 , 2.9787164 , 1.824237 ,
-0.08804226, -0.99998045, -0.07314599, 1.8403342 , 2.9812148 ,
2.3005757 , 0.42419332, -0.92279506, -0.50197446, 1.2997544 ,
2.8258905 , 2.6733112 , 0.98229736, -0.69244075, -0.81115675,
0.7352965 , 2.525117 , 2.912752 , 1.5418116 , -0.32726777,
-0.97606325, 0.19192469], dtype=float32)
The function you write only “sees” a single shard of the data, which you can check by printing the device local shape:
x = jnp.arange(32)
print(f"global shape: {x.shape=}")
def f(x):
    print(f"device local shape: {x.shape=}")
    return x * 2
y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global shape: x.shape=(32,)
device local shape: x.shape=(4,)
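The per-shard outputs are then reassembled according to out_specs, so the returned array has the global shape again. A quick check (not part of the original example):
# The reassembled result has the global shape and equals 2 * x.
print(f"global output shape: {y.shape=}")
print(jnp.allclose(y, 2 * x))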
Because each of your functions only “sees” the device-local part of the data, aggregation-like functions require some extra thought. For example, here’s what a shard_map of a jax.numpy.sum() looks like:
def f(x):
    return jnp.sum(x, keepdims=True)
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([ 6, 22, 38, 54, 70, 86, 102, 118], dtype=int32)
Your function f operates separately on each shard, and the resulting summation reflects this.
If you want to sum across shards, you need to explicitly request it using collective operations like jax.lax.psum():
def f(x):
    sum_in_shard = x.sum()
    return jax.lax.psum(sum_in_shard, 'x')
shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Array(496, dtype=int32)
Because the output no longer has a sharded dimension, set out_specs=P() (recall that the out_specs argument identifies how the blocks are assembled back together in shard_map).
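Other collectives follow the same pattern. For instance, here is a small sketch (not part of the original example) that uses jax.lax.psum’s sibling jax.lax.pmean to compute the mean across shards instead of the sum:
def f_mean(x):
    # Average the per-shard sums across the 'x' mesh axis.
    return jax.lax.pmean(x.sum(), 'x')

shard_map(f_mean, mesh=mesh, in_specs=P('x'), out_specs=P())(x)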
Comparing the three approaches
With these concepts fresh in mind, let’s compare the three approaches for a simple neural network layer.
Start by defining your canonical function like this:
@jax.jit
def layer(x, weights, bias):
    return jax.nn.sigmoid(x @ weights + bias)
import numpy as np
rng = np.random.default_rng(0)
x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))
layer(x, weights, bias)
Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)
You can automatically run this in a distributed manner using jax.jit() and passing appropriately sharded data.
If you shard the leading axis of both x and weights in the same way, then the matrix multiplication will automatically happen in parallel:
mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)
layer(x_sharded, weights_sharded, bias)
Array([0.02138912, 0.893112 , 0.59892005, 0.97742504], dtype=float32)
Alternatively, you can use jax.lax.with_sharding_constraint() in the function to automatically distribute unsharded inputs:
@jax.jit
def layer_auto(x, weights, bias):
    x = jax.lax.with_sharding_constraint(x, sharding)
    weights = jax.lax.with_sharding_constraint(weights, sharding)
    return layer(x, weights, bias)
layer_auto(x, weights, bias) # pass in unsharded inputs
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)
Finally, you can do the same thing with shard_map, using jax.lax.psum() to indicate the cross-shard collective required for the matrix product:
from functools import partial
@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
    return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)
layer_sharded(x, weights, bias)
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32)
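All three versions compute the same layer output, up to small floating-point differences that come from the sharded reductions. A quick check (not part of the original example):
# Compare the distributed results against the single-device reference.
print(jnp.allclose(layer(x, weights, bias), layer_auto(x, weights, bias)))
print(jnp.allclose(layer(x, weights, bias), layer_sharded(x, weights, bias)))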
Next steps
This tutorial serves as a brief introduction to sharded and parallel computation in JAX.
To learn about each SPMD method in depth, check out these docs: