Distributed arrays and automatic parallelization#

Open in Colab Open in Kaggle

This tutorial discusses parallelism via jax.Array, the unified array object model available in JAX v0.4.1 and newer.

import os

import functools
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

⚠️ WARNING: The notebook requires 8 devices to run.

if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run")

Intro and a quick example#

By reading this tutorial notebook, you’ll learn about jax.Array, a unified datatype for representing arrays, even with physical storage spanning multiple devices. You’ll also learn about how using jax.Arrays together with jax.jit can provide automatic compiler-based parallelization.

Before we think step by step, here’s a quick example. First, we’ll create a jax.Array sharded across multiple devices:

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
# Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, sharding.reshape(4, 2))
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Next, we’ll apply a computation to it and visualize how the result values are stored across multiple devices too:

z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

The evaluation of the jnp.sin application was automatically parallelized across the devices on which the input values (and output values) are stored:

# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached 
5 loops, best of 5: 9.69 ms per loop
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
5 loops, best of 5: 1.86 ms per loop

Now let’s look at each of these pieces in more detail!

Sharding describes how array values are laid out in memory across devices#

Sharding basics, and the PositionalSharding subclass#

To parallelize computation across multiple devices, we first must lay out input data across multiple devices.

In JAX, Sharding objects describe distributed memory layouts. They can be used with jax.device_put to produce a value with distributed layout.

For example, here’s a value with a single-device Sharding:

import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192))
jax.debug.visualize_array_sharding(x)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚         TPU 0         β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Here, we’re using the jax.debug.visualize_array_sharding function to show where the value x is stored in memory. All of x is stored on a single device, so the visualization is pretty boring!

But we can shard x across multiple devices by using jax.device_put and a Sharding object. First, we make a numpy.ndarray of Devices using mesh_utils.create_device_mesh, which takes hardware topology into account for the Device order:

from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))

Then, we create a PositionalSharding and use it with device_put:

from jax.sharding import PositionalSharding

sharding = PositionalSharding(devices)

x = jax.device_put(x, sharding.reshape(8, 1))
jax.debug.visualize_array_sharding(x)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚         TPU 0         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 1         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 2         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 3         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 6         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 7         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 4         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 5         β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Here sharding is a PositionalSharding which acts like an array with sets of devices as elements:

sharding
PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}])

By writing PositionalSharding(ndarray_of_devices), we fix the device order and the initial shape. Then we can reshape it:

sharding.reshape(8, 1)
PositionalSharding([[{TPU 0}]
                    [{TPU 1}]
                    [{TPU 2}]
                    [{TPU 3}]
                    [{TPU 6}]
                    [{TPU 7}]
                    [{TPU 4}]
                    [{TPU 5}]])
sharding.reshape(4, 2)
PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]])

To use device_put with a data array x, we can reshape the sharding into a shape that is congruent with x.shape, meaning a shape with the same length as x.shape and where each element evenly divides the corresponding element of x.shape:

def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:
  return (len(x_shape) == len(sharding_shape) and
          all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape)))

For example, we can reshape sharding to have shape (4, 2), then use it in a device_put:

sharding = sharding.reshape(4, 2)
print(sharding)
PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]])
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Here y represents the same value as x, but its shards (i.e. slices) are stored in different devices’ memories.

Different PositionalSharding shapes result in different distributed layouts (i.e. shardings) of the result:

sharding = sharding.reshape(1, 8)
print(sharding)
PositionalSharding([[{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]])
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”
β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚
β”‚ TPU 0 β”‚ TPU 1 β”‚ TPU 2 β”‚ TPU 3 β”‚ TPU 6 β”‚ TPU 7 β”‚ TPU 4 β”‚ TPU 5 β”‚
β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜

In some cases, we don’t just want to store each slice of x in a single device’s memory; we might want to replicate some slices, meaning storing copies of a slice’s values in multiple devices’ memories.

With PositionalSharding, we can express replication by calling the reducer method replicate:

sharding = sharding.reshape(4, 2)
print(sharding.replicate(axis=0, keepdims=True))
PositionalSharding([[{TPU 0, 2, 4, 6} {TPU 1, 3, 5, 7}]])
y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚TPU 0,2,4,6β”‚TPU 1,3,5,7β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Here the visualization shows that x is sharded two ways along its second dimension (and not sharded along the first dimension), and each of those shards is replicated four ways (i.e. stored in four device memories).

The replicate method is analogous to the familiar NumPy array reduction methods like .sum() and .prod(). It operates along an axis performing a set union. So if sharding has shape (4, 2), then sharding.replicate(0, keepdims=True) has shape (1, 2), and sharding.replicate(1, keepdims=True) has shape (4, 1). Unlike analogous NumPy methods, keepdims=True is actually the default, so reduced-over axes aren’t squeezed:

print(sharding.replicate(0).shape)
print(sharding.replicate(1).shape)
(1, 2)
(4, 1)
y = jax.device_put(x, sharding.replicate(1))
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚        TPU 0,1        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚        TPU 2,3        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚        TPU 6,7        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚        TPU 4,5        β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

NamedSharding gives a way to express shardings with names#

So far we’ve worked with PositionalSharding, but there are alternative ways to express shardings. In fact, Sharding is an interface, and any class that implements that interface can be used with functions like device_put.

Another convenient way to express sharding is with the NamedSharding:

from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

P = PartitionSpec

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

We can define a helper function to make things simpler:

devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec)
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Here, we use P('a', 'b') to express that the first and second axes of x should be sharded over the device mesh axes 'a' and 'b', respectively. We can easily switch to P('b', 'a') to shard the axes of x over different devices:

y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚ TPU 0 β”‚ TPU 2 β”‚ TPU 6 β”‚ TPU 4 β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚ TPU 1 β”‚ TPU 3 β”‚ TPU 7 β”‚ TPU 5 β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜
# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚        TPU 0,1        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚        TPU 2,3        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚        TPU 6,7        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚        TPU 4,5        β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Here, because P('a', None) doesn’t mention the Mesh axis name 'b', we get replication over the axis 'b'. The None here is just acting as a placeholder to line up against the second axis of the value x, without expressing sharding over any mesh axis. (As a shorthand, trailing Nones can be omitted, so that P('a', None) means the same thing as P('a'). But it doesn’t hurt to be explicit!)

To shard only over the second axis of x, we can use a None placeholder in the PartitionSpec:

y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚TPU 0,2,4,6β”‚TPU 1,3,5,7β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚TPU 0,1β”‚TPU 2,3β”‚TPU 6,7β”‚TPU 4,5β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜

For a fixed mesh, we can even partition one logical axis of x over multiple device mesh axes:

y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚         TPU 0         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 1         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 2         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 3         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 6         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 7         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 4         β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         TPU 5         β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Using NamedSharding makes it easy to define a device mesh once and give its axes names, then just refer to those names in PartitionSpecs for each device_put as needed.

Computation follows data sharding and is automatically parallelized#

With sharded input data, the compiler can give us parallel computation. In particular, functions decorated with jax.jit can operate over sharded arrays without copying data onto a single device. Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary.

For example, the simplest computation is an elementwise one:

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
x = jax.device_put(x, sharding.reshape(4, 2))
print('input sharding:')
jax.debug.visualize_array_sharding(x)

y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)
input sharding:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
output sharding:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Here for the elementwise operation jnp.sin the compiler chose the output sharding to be the same as the input. Moreover, the compiler automatically parallelized the computation, so that each device computed its output shard from its input shard in parallel.

In other words, even though we wrote the jnp.sin computation as if a single machine were to execute it, the compiler splits up the computation for us and executes it on multiple devices.

We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs:

y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)

w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)
lhs sharding:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚        TPU 0,1        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚        TPU 2,3        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚        TPU 6,7        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚        TPU 4,5        β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
rhs sharding:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚TPU 0,2,4,6β”‚TPU 1,3,5,7β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
out sharding:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Here the compiler chose the output sharding so that it could maximally parallelize the computation: without needing communication, each device already has the input shards it needs to compute its output shard.

How can we be sure it’s actually running in parallel? We can do a simple timing experiment:

x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚         TPU 0         β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z))
True
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
5 loops, best of 5: 19.3 ms per loop
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
5 loops, best of 5: 3.25 ms per loop

Even copying a sharded Array produces a result with the sharding of the input:

w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

So computation follows data placement: when we explicitly shard data with jax.device_put, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of JAX’s policy of following explicit device placement.

When explicit shardings disagree, JAX errors#

But what if two arguments to a computation are explicitly placed on different sets of devices, or with incompatible device orders? In these ambiguous cases, an error is raised:

import textwrap
from termcolor import colored

def print_exception(e):
  name = colored(f'{type(e).__name__}', 'red')
  print(textwrap.fill(f'{name}: {str(e)}'))
sharding1 = PositionalSharding(jax.devices()[:4])
sharding2 = PositionalSharding(jax.devices()[4:])

y = jax.device_put(x, sharding1.reshape(2, 2))
z = jax.device_put(x, sharding2.reshape(2, 2))
try: y + z
except ValueError as e: print_exception(e)
ValueError: Devices of all `Array` inputs and outputs should
be the same. Got array device ids [0, 1, 2, 3] on platform TPU and
another array's device ids [4, 5, 6, 7] on platform TPU
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]

sharding1 = PositionalSharding(devices)
sharding2 = PositionalSharding(permuted_devices)

y = jax.device_put(x, sharding1.reshape(4, 2))
z = jax.device_put(x, sharding2.reshape(4, 2))
try: y + z
except ValueError as e: print_exception(e)
ValueError: Devices of all `Array` inputs and outputs should
be the same. Got array device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform
TPU and another array's device ids [0, 1, 2, 3, 6, 7, 4, 5] on
platform TPU

We say arrays that have been explicitly placed or sharded with jax.device_put are committed to their device(s), and so won’t be automatically moved. See the device placement FAQ for more information.

When arrays are not explicitly placed or sharded with jax.device_put, they are placed uncommitted on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.

For example, the output of jnp.zeros, jnp.arange, and jnp.array are uncommitted:

y = jax.device_put(x, sharding1.reshape(4, 2))
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')
no error!

Constraining shardings of intermediates in jitted code#

While the compiler will attempt to decide how a function’s intermediate values and outputs should be sharded, we can also give it hints using jax.lax.with_sharding_constraint. Using jax.lax.with_sharding_constraint is much like jax.device_put, except we use it inside staged-out (i.e. jit-decorated) functions:

sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2))
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4))
  return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚ TPU 0 β”‚ TPU 1 β”‚ TPU 2 β”‚ TPU 3 β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚ TPU 6 β”‚ TPU 7 β”‚ TPU 4 β”‚ TPU 5 β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β”‚       β”‚       β”‚       β”‚       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, sharding.replicate())
  return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  TPU 0   β”‚  TPU 1   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 2   β”‚  TPU 3   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 6   β”‚  TPU 7   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚  TPU 4   β”‚  TPU 5   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚  TPU 0,1,2,3,4,5,6,7  β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β”‚                       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

By adding with_sharding_constraint, we’ve constrained the sharding of the output. In addition to respecting the annotation on a particular intermediate, the compiler will use annotations to decide shardings for other values.

It’s often a good practice to annotate the outputs of computations, for example based on how the values are ultimately consumed.

Examples: neural networks#

⚠️ WARNING: The following is meant to be a simple demonstration of automatic sharding propagation with jax.Array, but it may not reflect best practices for real examples. For instance, real examples may require more use of with_sharding_constraint.

We can use jax.device_put and jax.jit’s computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:

import jax
import jax.numpy as jnp
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.maximum(outputs, 0)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
def init_layer(key, n_in, n_out):
    k1, k2 = jax.random.split(key)
    W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
    b = jax.random.normal(k2, (n_out,))
    return W, b

def init_model(key, layer_sizes, batch_size):
    key, *keys = jax.random.split(key, len(layer_sizes))
    params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

    key, *keys = jax.random.split(key, 3)
    inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
    targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

    return params, (inputs, targets)

layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)

8-way batch data parallelism#

sharding = PositionalSharding(jax.devices()).reshape(8, 1)
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, sharding.replicate())
loss_jit(params, batch)
Array(23.469475, dtype=float32)
step_size = 1e-5

for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

print(loss_jit(params, batch))
10.760101
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
5 loops, best of 5: 26.3 ms per loop
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
5 loops, best of 5: 122 ms per loop

4-way batch data parallelism and 2-way model tensor parallelism#

sharding = sharding.reshape(4, 2)
batch = jax.device_put(batch, sharding.replicate(1))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
β”Œβ”€β”€β”€β”€β”€β”€β”€β”
β”‚TPU 0,1β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€
β”‚TPU 2,3β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€
β”‚TPU 4,5β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€
β”‚TPU 6,7β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β”€β”€β”
β”‚TPU 0,1β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€
β”‚TPU 2,3β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€
β”‚TPU 4,5β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€
β”‚TPU 6,7β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”˜
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

W1 = jax.device_put(W1, sharding.replicate())
b1 = jax.device_put(b1, sharding.replicate())

W2 = jax.device_put(W2, sharding.replicate(0))
b2 = jax.device_put(b2, sharding.replicate(0))

W3 = jax.device_put(W3, sharding.replicate(0).T)
b3 = jax.device_put(b3, sharding.replicate())

W4 = jax.device_put(W4, sharding.replicate())
b4 = jax.device_put(b4, sharding.replicate())

params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
jax.debug.visualize_array_sharding(W2)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚TPU 0,2,4,6β”‚TPU 1,3,5,7β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
jax.debug.visualize_array_sharding(W3)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                       β”‚
β”‚      TPU 0,2,4,6      β”‚
β”‚                       β”‚
β”‚                       β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                       β”‚
β”‚      TPU 1,3,5,7      β”‚
β”‚                       β”‚
β”‚                       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
print(loss_jit(params, batch))
10.760103
step_size = 1e-5

for _ in range(30):
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.752466
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚TPU 0,2,4,6β”‚TPU 1,3,5,7β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β”‚           β”‚           β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                       β”‚
β”‚      TPU 0,2,4,6      β”‚
β”‚                       β”‚
β”‚                       β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                       β”‚
β”‚      TPU 1,3,5,7      β”‚
β”‚                       β”‚
β”‚                       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
10 loops, best of 10: 30.5 ms per loop

Sharp bits#

Generating random numbers#

JAX comes with a functional, deterministic random number generator. It underlies the various sampling functions in the jax.random module, such as jax.random.uniform.

JAX’s random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices.

However, the existing stable RNG implementation is not automatically partitionable, for historical reasons.

Consider the following example, where a function draws random uniform numbers and adds them to the input, elementwise:

@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)
  return x + numbers

key = jax.random.key(42)
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24), x_sharding)

On a partitioned input, the function f produces output that is also partitioned:

jax.debug.visualize_array_sharding(f(key, x))
β”Œβ”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”
β”‚ TPU 0 β”‚ TPU 1 β”‚ TPU 2 β”‚ TPU 3 β”‚ TPU 4 β”‚ TPU 5 β”‚ TPU 6 β”‚ TPU 7 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜

But if we inspect the compiled computation for f on this partitioned input, we see that it does involve some communication:

f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? True

One way to work around this is to configure JAX with the experimental upgrade flag jax_threefry_partitionable. With the flag on, the β€œcollective permute” operation is now gone from the compiled computation:

jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? False

The output is still partitioned:

jax.debug.visualize_array_sharding(f(key, x))
β”Œβ”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”
β”‚ TPU 0 β”‚ TPU 1 β”‚ TPU 2 β”‚ TPU 3 β”‚ TPU 4 β”‚ TPU 5 β”‚ TPU 6 β”‚ TPU 7 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜

One caveat to the jax_threefry_partitionable option, however, is that the random values produced may be different than without the flag set, even though they were generated by the same random key:

jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
print()

jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x))
Stable:
[ 0.72503686  1.8532515   2.983416    3.083253    4.0332246   5.4782867
  6.1720605   7.6900277   8.602836    9.810046   10.861367   11.907651
 12.330483   13.456195   14.808557   15.960099   16.067581   17.739723
 18.335474   19.46401    20.390276   21.116539   22.858128   23.223194  ]

Partitionable:
[ 0.48870957  1.6797972   2.6162715   3.561016    4.4506445   5.585866
  6.0748096   7.775133    8.698959    9.818634   10.350306   11.87282
 12.925881   13.86013    14.477554   15.818481   16.711355   17.586697
 18.073738   19.777622   20.404566   21.119123   22.026257   23.63918   ]

In jax_threefry_partitionable mode, the JAX PRNG remains deterministic, but its implementation is new (and under development). The random values generated for a given key will be the same at a given JAX version (or a given commit on the main branch), but may vary across releases.