jax.numpy.unravel_index



jax.numpy.unravel_index(indices, shape)

Converts a flat index or array of flat indices into a tuple of coordinate arrays.

LAX-backend implementation of numpy.unravel_index().

Unlike NumPy's implementation of unravel_index, which raises a ValueError for negative or out-of-bounds indices, this implementation accepts negative indices and clips out-of-bounds indices into the valid range.

Original docstring below.

Parameters:
  • indices (array_like) – An integer array whose elements are indices into the flattened version of an array of dimensions shape. Before NumPy 1.6.0, this function accepted just one index value.

  • shape (tuple of ints) –

    The shape of the array to use for unraveling indices.

    Changed in NumPy 1.16.0: Renamed from dims to shape.
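A minimal sketch of the two parameters in use, assuming a (3, 4) array in the default row-major (C) order. A scalar index yields a tuple of scalar coordinates, while an array of indices yields one coordinate array per dimension, each shaped like indices. The variable names are illustrative only.

```python
import jax.numpy as jnp

shape = (3, 4)  # 12 elements, so valid flat indices are 0..11

# A scalar flat index gives a tuple of scalar coordinates: 6 == 1 * 4 + 2.
row, col = jnp.unravel_index(6, shape)

# An array of flat indices gives one coordinate array per dimension,
# each with the same shape as the input indices.
flat = jnp.array([[0, 5], [7, 11]])
rows, cols = jnp.unravel_index(flat, shape)

print(int(row), int(col))  # 1 2
print(rows.tolist())       # [[0, 1], [1, 2]]
print(cols.tolist())       # [[0, 1], [3, 3]]

# Row-major check: recombining the coordinates recovers the flat indices.
print(bool((rows * shape[1] + cols == flat).all()))  # True
```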

Returns:

unraveled_coords – Each array in the tuple has the same shape as the indices array.

Return type:

tuple of ndarray
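The JAX-specific behavior described at the top of this page can be seen in a short sketch, again assuming a (3, 4) shape. In-range negative indices count back from the end of the flattened array, and indices outside the valid range are clipped rather than raising an error as NumPy would. Because no error needs to be raised for a bad index, the call also works on traced values inside jax.jit; the helpers coords and argmax_2d are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

shape = (3, 4)  # 12 elements, so valid flat indices are 0..11

def coords(flat_index):
    # Hypothetical helper: convert the tuple of 0-d arrays into Python ints.
    return tuple(int(c) for c in jnp.unravel_index(flat_index, shape))

print(coords(-1))   # (2, 3): -1 counts back from the end, like Python indexing
print(coords(12))   # (2, 3): past the end, clipped to the last element
print(coords(-13))  # (0, 0): before the start, clipped to the first element

@jax.jit
def argmax_2d(x):
    # jnp.argmax without an axis returns an index into the flattened array;
    # x.shape is static under jit, so it can be passed as the shape argument.
    return jnp.unravel_index(jnp.argmax(x), x.shape)

x = jnp.array([[1, 3, 2, 4], [8, 7, 0, 6]])
r, c = argmax_2d(x)
print(int(r), int(c))  # 1 0, the position of the maximum value 8
```

A common use, as in argmax_2d above, is converting the flat result of a reduction such as jnp.argmax back into per-axis coordinates.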