jax.Array migration

yashkatariya@

TL;DR

JAX switched its default array implementation to the new jax.Array as of version 0.4.1. This guide explains the reasoning behind this, the impact it might have on your code, and how to (temporarily) switch back to the old behavior.

What’s going on?

jax.Array is a unified array type that subsumes the DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. If your code doesn’t mention DeviceArray vs ShardedDeviceArray vs GlobalDeviceArray, no changes are needed. But code that depends on details of these separate classes may need to be tweaked to work with the unified jax.Array.

After the migration is complete, jax.Array will be the only type of array in JAX.

This doc explains how to migrate existing codebases to jax.Array. For more information on using jax.Array and JAX parallelism APIs, see the Distributed arrays and automatic parallelization tutorial.

How to enable jax.Array?

You can enable jax.Array by:

  • setting the shell environment variable JAX_ARRAY to something true-like (e.g., 1);

  • setting the boolean flag jax_array to something true-like if your code parses flags with absl;

  • using this statement at the top of your main file:

    import jax
    jax.config.update('jax_array', True)
    

How do I know if jax.Array broke my code?

The easiest way to tell if jax.Array is responsible for any problems is to disable jax.Array and see if the issues go away.

How can I disable jax.Array for now?

Through March 15, 2023, it will be possible to disable jax.Array by:

  • setting the shell environment variable JAX_ARRAY to something falsey (e.g., 0);

  • setting the boolean flag jax_array to something falsey if your code parses flags with absl;

  • using this statement at the top of your main file:

    import jax
    jax.config.update('jax_array', False)
    

Why create jax.Array?

JAX currently has three array types: DeviceArray, ShardedDeviceArray, and GlobalDeviceArray. jax.Array merges these three types and cleans up JAX’s internals while adding new parallelism features.

We also introduce a new Sharding abstraction that describes how a logical Array is physically sharded out across one or more devices, such as TPUs or GPUs. The change also upgrades, simplifies and merges the parallelism features of pjit into jit. Functions decorated with jit will be able to operate over sharded arrays without copying data onto a single device.

Features you get with jax.Array:

  • C++ pjit dispatch path

  • Op-by-op parallelism (even if the array is distributed across multiple devices on multiple hosts)

  • Simpler batch data parallelism with pjit/jit.

  • Ways to create Shardings that do not necessarily consist of a mesh and a partition spec, so you can use the full flexibility of OpSharding, or any other Sharding you want.

  • And many more.
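To make the Sharding abstraction a little more concrete, here is a minimal sketch (our own illustration, not part of the original feature list) that puts two Sharding types side by side: jax.sharding.SingleDeviceSharding, which involves no mesh or PartitionSpec at all, and a NamedSharding over every available device. It adapts to however many devices jax.devices() reports; on a CPU-only machine you can simulate several devices by setting XLA_FLAGS=--xla_force_host_platform_device_count=8 before JAX starts.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P, SingleDeviceSharding

devices = jax.devices()
n = len(devices)

# A Sharding with no mesh or PartitionSpec: the whole array lives on one device.
x_single = jax.device_put(jnp.arange(8, dtype=jnp.float32),
                          SingleDeviceSharding(devices[0]))

# A mesh-based Sharding that splits dimension 0 across every device.
mesh = Mesh(np.array(devices), ("data",))
x_named = jax.device_put(jnp.arange(8 * n, dtype=jnp.float32),
                         NamedSharding(mesh, P("data")))

# Both values are plain jax.Array objects; only `.sharding` differs.
# devices_indices_map shows which slice of the array each device holds.
for arr in (x_single, x_named):
    print(type(arr.sharding).__name__,
          arr.sharding.devices_indices_map(arr.shape))
```

Either value can be passed to the same jit-compiled function; the sharding travels with the array.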

Example:

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np

# A 2-D array, so that `@` below is a genuine matrix product.
x = jnp.arange(16).reshape(8, 2)

# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
# Shard dimension 0 over mesh axis 'x' (4 ways); replicate over 'y'.
sharding = jax.sharding.NamedSharding(mesh, P('x'))

sharded_x = jax.device_put(x, sharding)

# `matmul_sharded_x` and `sin_sharded_x` are computed directly from the sharded
# data: `jit` is able to operate over a sharded array without copying data to a
# single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)

# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)

# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)

What issues can arise when jax.Array is switched on?

New public type named jax.Array

All isinstance(..., jnp.DeviceArray) or isinstance(..., jax.xla.DeviceArray) checks, and other variants of DeviceArray, should be switched to isinstance(..., jax.Array).

Since jax.Array can represent DA, SDA and GDA, you can differentiate those three types in jax.Array via:

  • x.is_fully_addressable and len(x.sharding.device_set) == 1 – this means that jax.Array is like a DA

  • x.is_fully_addressable and len(x.sharding.device_set) > 1 – this means that jax.Array is like an SDA

  • not x.is_fully_addressable – this means that jax.Array is like a GDA and spans multiple processes

For ShardedDeviceArray, you can replace isinstance(..., pxla.ShardedDeviceArray) with isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1.

In general it is not possible to differentiate a ShardedDeviceArray on 1 device from any other kind of single-device Array.
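The checks above can be bundled into a small helper. The sketch below is hypothetical (legacy_kind and is_sda_like are names made up for this example, not JAX APIs); it only uses x.is_fully_addressable and x.sharding.device_set, and it runs with however many local devices are available. As noted above, an array on a single device is always reported as DeviceArray-like.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def legacy_kind(x: jax.Array) -> str:
    """Hypothetical helper: name the legacy type a jax.Array most resembles."""
    if not x.is_fully_addressable:
        # Some shards live on devices owned by other processes: GDA-like.
        return "GlobalDeviceArray-like"
    if len(x.sharding.device_set) > 1:
        # Every shard is local, but the data spans several devices: SDA-like.
        return "ShardedDeviceArray-like"
    # Exactly one local device: DA-like (a one-device SDA also lands here).
    return "DeviceArray-like"


def is_sda_like(x) -> bool:
    """Hypothetical replacement for isinstance(x, pxla.ShardedDeviceArray)."""
    return (isinstance(x, jax.Array) and x.is_fully_addressable
            and len(x.sharding.device_set) > 1)


devices = jax.devices()
single = jnp.ones(4)  # placed on the default device
mesh = Mesh(np.array(devices), ("d",))
spread = jax.device_put(jnp.ones(2 * len(devices)), NamedSharding(mesh, P("d")))
print(legacy_kind(single), legacy_kind(spread))
```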

GDA’s API name changes

GDA’s local_shards and local_data have been deprecated.

Please use addressable_shards and addressable_data, which are compatible with both jax.Array and GDA.
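A minimal sketch of the renamed accessors on a jax.Array built just for this example (the mesh axis name "d" is arbitrary): addressable_shards yields Shard objects carrying the device, the index (a tuple of slices into the global array) and the per-device data, while addressable_data(i) returns the data of the i-th addressable shard as a single-device array.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(np.array(devices), ("d",))
x = jax.device_put(jnp.arange(4 * len(devices)), NamedSharding(mesh, P("d")))

# Was: x.local_shards  ->  now: x.addressable_shards
for shard in x.addressable_shards:
    # Each Shard records its device, its slice of the global array, and its data.
    print(shard.device, shard.index, shard.data)

# Was: x.local_data(0)  ->  now: x.addressable_data(0)
first_block = x.addressable_data(0)
```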

Creating jax.Array

All JAX functions will output jax.Array when the jax_array flag is True. If you were using GlobalDeviceArray.from_callback or make_sharded_device_array or make_device_array functions to explicitly create the respective JAX data types, you will need to switch them to use jax.make_array_from_callback() or jax.make_array_from_single_device_arrays().

For GDA:

GlobalDeviceArray.from_callback(shape, mesh, pspec, callback) can become jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback) in a 1:1 switch.
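A runnable sketch of that switch, assuming the data comes from an in-memory NumPy array (a stand-in for whatever per-shard loading your callback actually did). The callback receives a tuple of slices identifying the part of the global array a device needs, and returns that part.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(np.array(devices).reshape(len(devices), 1), ("x", "y"))
pspec = P("x", "y")
shape = (2 * len(devices), 4)

# Placeholder data source; a real callback might read each slice from disk.
global_data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)


def callback(index):
    # `index` is a tuple of slices selecting this device's shard.
    return global_data[index]


# Was: GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)
arr = jax.make_array_from_callback(shape, NamedSharding(mesh, pspec), callback)
```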

If you were using the raw GDA constructor to create GDAs, then do this:

GlobalDeviceArray(shape, mesh, pspec, buffers) can become jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers).
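A sketch of the constructor-style path: build one single-device array per addressable device, each holding exactly the slice that the target sharding assigns to it, then assemble them. Here sharding.addressable_devices_indices_map supplies the device-to-slice mapping; the NumPy source data is only a placeholder for your existing buffers.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(np.array(devices), ("x",))
sharding = NamedSharding(mesh, P("x"))
shape = (3 * len(devices), 2)
global_data = np.arange(np.prod(shape), dtype=np.int32).reshape(shape)

# One single-device array per device, holding the slice the sharding gives it
# (these play the role of the old GDA `buffers`).
buffers = [
    jax.device_put(global_data[idx], device)
    for device, idx in sharding.addressable_devices_indices_map(shape).items()
]

# Was: GlobalDeviceArray(shape, mesh, pspec, buffers)
arr = jax.make_array_from_single_device_arrays(shape, sharding, buffers)
```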

For SDA:

make_sharded_device_array(aval, sharding_spec, device_buffers, indices) can become jax.make_array_from_single_device_arrays(shape, sharding, device_buffers).

Which sharding to use depends on why you were creating the SDAs:

If the SDA was created as an input to pmap, the sharding can be jax.sharding.PmapSharding(devices, sharding_spec).

If the SDA was created as an input to pjit, the sharding can be jax.sharding.NamedSharding(mesh, pspec).

Breaking change for pjit after switching to jax.Array for host local inputs

If you are exclusively using GDA arguments to pjit, you can skip this section! 🎉

With jax.Array enabled, all inputs to pjit must be globally shaped. This is a breaking change from the previous behavior where pjit would concatenate process-local arguments into a global value; this concatenation no longer occurs.

Why are we making this breaking change? Each array now says explicitly how its local shards fit into a global whole, rather than leaving it implicit. The more explicit representation also unlocks additional flexibility, for example the use of non-contiguous meshes with pjit, which can improve efficiency on some TPU models.

Running multi-process pjit computation and passing host-local inputs when jax.Array is enabled can lead to an error similar to this:

Example:

Mesh = {'x': 2, 'y': 2, 'chips': 2}, host-local input shape == (4,), and pspec = P(('x', 'y', 'chips'))

Since pjit doesn’t lift host local shapes to global shapes with jax.Array, you get the following error:

Note: You will only see this error if your host local shape is smaller than the shape of the mesh.

ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4

The error makes sense: dimension 0 cannot be sharded 8 ways when its size is only 4.
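The arithmetic behind the error can be spelled out in plain Python. shard_count is a hypothetical helper, and the two-process layout at the end is an assumption chosen for illustration: it shows that concatenating the same (4,) slice from two processes, which is what pjit used to do implicitly, yields a dimension that does divide by 8.

```python
import math


def shard_count(mesh_shape, axes):
    # Hypothetical helper: a dimension sharded over several mesh axes is split
    # into (product of those axis sizes) pieces.
    return math.prod(mesh_shape[a] for a in axes)


mesh_shape = {"x": 2, "y": 2, "chips": 2}
pieces = shard_count(mesh_shape, ("x", "y", "chips"))  # 8

host_local_dim0 = 4
print(host_local_dim0 % pieces == 0)  # False: 4 cannot be split 8 ways

# Assumption for illustration: the mesh spans 2 processes, each holding a (4,)
# slice. Concatenated, the global dimension is 2 * 4 = 8, which divides evenly.
num_processes = 2
global_dim0 = num_processes * host_local_dim0
print(global_dim0 % pieces == 0)  # True
```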

How can you migrate if you still pass host local inputs to pjit? We are providing transitional APIs to help you migrate:

Note: You don’t need these utilities if you run your pjitted computation on a single process.

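from jax.experimental import multihost_utils
from jax.experimental.pjit import pjit

# `f`, `mesh`, `local_inputs`, `in_pspecs` and `out_pspecs` are your own.
global_inps = multihost_utils.host_local_array_to_global_array(
    local_inputs, mesh, in_pspecs)

# PartitionSpecs passed to pjit are resolved against the active mesh.
with mesh:
    global_outputs = pjit(f, in_shardings=in_pspecs,
                          out_shardings=out_pspecs)(global_inps)

local_outs = multihost_utils.global_array_to_host_local_array(
    global_outputs, mesh, out_pspecs)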

host_local_array_to_global_array is a type cast that looks at a value with only local shards and changes its local shape to the shape that pjit would have previously assumed if that value was passed before the change.
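A single-process sketch of the round trip. With one process the host-local and global shapes coincide, so the cast leaves the shape unchanged; in a multi-process run with the batch sharded over a mesh axis that spans all processes, the global shape would instead stack every process's local rows. The mesh axis name "data" and the batch contents are placeholders.

```python
import jax
import numpy as np
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(np.array(devices), ("data",))
batch_pspec = P("data")

# Placeholder host-local batch: 2 rows per device, 3 features.
local_batch = np.arange(2 * len(devices) * 3, dtype=np.float32).reshape(-1, 3)

global_batch = multihost_utils.host_local_array_to_global_array(
    local_batch, mesh, batch_pspec)

# ... run the jit/pjit computation on global_batch here ...

back_to_local = multihost_utils.global_array_to_host_local_array(
    global_batch, mesh, batch_pspec)
```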

Passing in fully replicated inputs (i.e., the same shape on each process, with P(None) as in_shardings) is still supported. In this case you do not have to use host_local_array_to_global_array, because the shape is already global.

key = jax.random.PRNGKey(1)

with mesh:
    # host_local_array_to_global_array is not required here: with a mesh
    # active, in_shardings=None marks the input as fully replicated, like P(None).
    pjit(f, in_shardings=None, out_shardings=None)(key)

    # Mixing inputs: replicate `key`, shard the converted `local_inp` on 'data'.
    global_inp = multihost_utils.host_local_array_to_global_array(
        local_inp, mesh, P('data'))
    global_out = pjit(f, in_shardings=(P(None), P('data')),
                      out_shardings=...)(key, global_inp)

FROM_GDA and jax.Array

If you were passing FROM_GDA in pjit's in_axis_resources argument (now called in_shardings), you no longer need to pass anything there with jax.Array: jax.Array follows computation-follows-sharding semantics, so pjit uses the sharding each input array already carries.

For example:

pjit(f, in_shardings=FROM_GDA, out_shardings=...) can be replaced by pjit(f, out_shardings=...)
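A minimal sketch of computation-follows-sharding, using jax.jit (which, per this guide, now carries pjit's parallelism features) rather than pjit: no input shardings are given, and the compiled elementwise function uses the sharding that x already carries, so the output is expected to come back sharded the same way. The mesh axis name "x" is arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(np.array(devices), ("x",))
x = jax.device_put(jnp.arange(4 * len(devices), dtype=jnp.float32),
                   NamedSharding(mesh, P("x")))

# No in_shardings (and no FROM_GDA): the input sharding is taken from x itself.
f = jax.jit(lambda a: a * 2 + 1)
y = f(x)
print(y.sharding)
```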

If you have PartitionSpecs mixed in with FROM_GDA for inputs like NumPy arrays, use host_local_array_to_global_array to convert those inputs to jax.Array.

For example:

If you had this:

pjitted_f = pjit(
    f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
    out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)

then you can replace it with:

pjitted_f = pjit(f, out_shardings=...)

# array1 and array3 are the jax.Arrays that replace gda1 and gda2.
array2, array4 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))

pjitted_f(array1, array2, array3, array4)

live_buffers replaced with live_arrays

The live_buffers attribute on a JAX Device has been deprecated. Please use jax.live_arrays() instead, which is compatible with jax.Array.
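For example, a rough tally of live arrays could be written as follows. This is a sketch: with no argument, jax.live_arrays() lists the live arrays of the default backend, and nbytes is each array's logical size, not its per-device memory footprint.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.float32)

# jax.live_arrays() replaces the deprecated per-Device live_buffers.
live = jax.live_arrays()

# nbytes is the logical (global) size of each array.
total_bytes = sum(a.nbytes for a in live)
print(f"{len(live)} live arrays, {total_bytes} bytes in total")
```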

Handling of host-local inputs to pjit, such as batches

If you are passing host-local inputs to pjit in a multi-process environment, please use multihost_utils.host_local_array_to_global_array to convert them to a global jax.Array, and then pass that to pjit.

The most common example of such a host local input is a batch of input data.

This will work for any host local input (not just a batch of input data).

from jax.experimental import multihost_utils

batch = multihost_utils.host_local_array_to_global_array(
    batch, mesh, batch_partition_spec)

See the pjit section above for more details about this change and more examples.

RecursionError: Recursively calling jit

This happens when jax.Array is disabled in one part of your code and enabled in another. For example, if you use third-party code that disables jax.Array, receive a DeviceArray from that library, and then enable jax.Array in your own library and pass that DeviceArray to JAX functions, you will get a RecursionError.

This error should go away when jax.Array is enabled by default so that all libraries return jax.Array unless they explicitly disable it.