jax.make_array_from_callback


jax.make_array_from_callback(shape, sharding, data_callback)

Returns a jax.Array built from data fetched via data_callback.

data_callback is called to fetch the data for each addressable shard of the returned jax.Array. Because data_callback must return concrete arrays (not traced values), make_array_from_callback has limited compatibility with JAX transformations such as jit() or vmap().
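As a rough illustration of the eager, per-shard behavior described above, the sketch below records every index the callback is asked for and then feeds the resulting array into a jitted function. It assumes only that some devices are available (one CPU device is enough); the names requested and cb are hypothetical, chosen for illustration. The array is created outside jit(), and only the finished jax.Array crosses into the traced function.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Use whatever devices are available; the leading dimension is split across them.
devices = np.array(jax.devices())
n = devices.size
sharding = NamedSharding(Mesh(devices, ("x",)), P("x"))

shape = (2 * n, 3)
global_data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

requested = []  # hypothetical: records every index passed to the callback

def cb(index):
    # The callback runs eagerly (outside any trace), so ordinary NumPy code
    # and Python side effects such as appending to a list are fine here.
    requested.append(index)
    return global_data[index]

arr = jax.make_array_from_callback(shape, sharding, cb)

# The result is a regular jax.Array and can be passed into a jitted function.
total = jax.jit(lambda a: a.sum())(arr)
print(len(requested), float(total))
```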

Parameters:
  • shape (Shape) – Shape of the jax.Array.

  • sharding (Sharding) – A Sharding instance which describes how the jax.Array is laid out across devices.

  • data_callback (Callable[[Index | None], ArrayLike]) – Callback that takes an index into the global array value (a tuple of slice objects selecting one shard) and returns the data of the global array value at that index. The data can be returned as any array-like object, e.g. a numpy.ndarray.
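Because data_callback only ever sees an index, it does not need the full global array in memory. The sketch below assumes a hypothetical global array defined element-wise as value(i, j) = i * ncols + j and builds only the requested shard by resolving each slice with slice.indices(). It also prints the index each local device is assigned, using Sharding.addressable_devices_indices_map. The callback name lazy_cb is illustrative, and the padding of short indices is a defensive assumption, not documented behavior.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
sharding = NamedSharding(Mesh(devices, ("x",)), P("x", None))
shape = (4 * devices.size, 5)

def lazy_cb(index):
    # Defensive: treat any missing trailing dimensions as full slices.
    index = tuple(index) + (slice(None),) * (len(shape) - len(index))
    # Resolve each slice against the global shape into concrete ranges.
    ranges = [np.arange(*s.indices(dim)) for s, dim in zip(index, shape)]
    rows, cols = np.meshgrid(*ranges, indexing="ij")
    # Hypothetical global definition: element (i, j) equals i * ncols + j.
    return (rows * shape[1] + cols).astype(np.float32)

arr = jax.make_array_from_callback(shape, sharding, lazy_cb)

# Show which slice of the global array each local device received.
for device, idx in sharding.addressable_devices_indices_map(shape).items():
    print(device, idx)
```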

Return type:

ArrayImpl

Returns:

A jax.Array whose addressable shards hold the data returned by data_callback.
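To see what the returned array holds, one option is to walk its addressable_shards: each shard exposes its device, the index (slice of the global array) it covers, and its on-device data. The sketch below assumes a simple 1-D sharding over all available devices; the lambda callback and the linspace data are illustrative choices only.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
sharding = NamedSharding(Mesh(devices, ("x",)), P("x"))
shape = (3 * devices.size,)
global_data = np.linspace(0.0, 1.0, num=shape[0], dtype=np.float32)

arr = jax.make_array_from_callback(shape, sharding, lambda idx: global_data[idx])

print(type(arr).__name__, arr.shape, arr.dtype)
for shard in arr.addressable_shards:
    # Each shard records its device, its slice of the global array, and the
    # on-device data for that slice.
    print(shard.device, shard.index, shard.data.shape)
```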

Example

>>> import math
>>> import jax
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> input_shape = (8, 8)
>>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
>>> # This mesh requires exactly 8 devices (a 2 x 4 grid).
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
...
>>> def cb(index):
...  # index is a tuple of slices selecting one shard of the global array
...  return global_input_data[index]
...
>>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
>>> arr.addressable_data(0).shape
(4, 2)
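The example above needs eight devices. On a machine without them, one possible approach, assuming a CPU backend is available, is to ask XLA for eight virtual host devices through the XLA_FLAGS environment variable before JAX initializes, and then build the mesh from jax.devices("cpu"). This is a sketch for local experimentation, not a substitute for real multi-device hardware.

```python
import math
import os

# Assumption: request 8 virtual CPU devices. This must be set before JAX
# initializes its backends, so it comes before `import jax`.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax  # imported after setting XLA_FLAGS on purpose
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

input_shape = (8, 8)
global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
cpu_devices = jax.devices("cpu")[:8]
global_mesh = Mesh(np.array(cpu_devices).reshape(2, 4), ("x", "y"))
inp_sharding = NamedSharding(global_mesh, P("x", "y"))

def cb(index):
    return global_input_data[index]

arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
print(arr.addressable_data(0).shape)  # (4, 2): 8 rows / 2, 8 cols / 4
```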