jax.numpy.unravel_index
jax.numpy.unravel_index(indices, shape)

Converts a flat index or array of flat indices into a tuple of coordinate arrays.
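A minimal pure-Python sketch of this conversion for a single integer index in row-major (C) order. The name `unravel_index_clipped` is hypothetical, and the handling of negative and out-of-bounds indices is an assumed model of the JAX-specific behaviour (negatives counted back from the end, out-of-range values clipped to the first or last element). It does not reproduce JAX's actual LAX implementation, which operates on whole arrays.

```python
from typing import Tuple


def unravel_index_clipped(index: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical model of unravel_index for one flat integer index.

    Assumptions (not taken from JAX's source): every entry of `shape` is
    positive, indices in [-size, -1] count back from the end like Python
    negative indexing, and anything outside [-size, size - 1] is clipped
    to the first or last element.
    """
    size = 1
    for d in shape:
        size *= d
    # Clip into [-size, size - 1], then map negatives onto [0, size - 1].
    k = min(max(index, -size), size - 1) % size
    coords = []
    # Row-major (C) order: the last axis varies fastest, so peel it off first.
    for d in reversed(shape):
        k, c = divmod(k, d)
        coords.append(c)
    return tuple(reversed(coords))


print(unravel_index_clipped(22, (7, 6)))   # (3, 4)
print(unravel_index_clipped(-1, (7, 6)))   # (6, 5), the last element
print(unravel_index_clipped(42, (7, 6)))   # (6, 5), clipped from past the end
print(unravel_index_clipped(-43, (7, 6)))  # (0, 0), clipped from before the start
```

Working on one Python int keeps the divmod loop visible; an array version would apply the same divmod steps elementwise.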
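LAX-backend implementation of unravel_index(). Unlike NumPy's implementation of unravel_index, negative indices are accepted (counted back from the end of the flattened array, as in Python indexing) and out-of-bounds indices are clipped to the first or last element.

Original docstring below.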
unravel_index(indices, shape, order='C')
Parameters
- indices (array_like) – An integer array whose elements are indices into the flattened version of an array of dimensions shape. Before NumPy 1.6.0, this function accepted just one index value.
- shape (tuple of ints) – The shape of the array to use for unraveling indices.
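For a shape (d_0, ..., d_{n-1}) in row-major (C) order, "indices into the flattened version of an array" can be stated precisely with the strides s_i below (the symbols k, c_i, d_i, s_i are introduced here for illustration). The relation holds for 0 ≤ k < d_0 ⋯ d_{n-1}; the second line checks it against the 1621 example for shape (6, 7, 8, 9), whose strides are (504, 72, 9, 1).

```latex
\begin{aligned}
s_i &= \prod_{j=i+1}^{n-1} d_j \quad (s_{n-1} = 1), \qquad
k = \sum_{i=0}^{n-1} c_i \, s_i, \qquad
c_i = \left\lfloor \frac{k}{s_i} \right\rfloor \bmod d_i, \\
1621 &= 3 \cdot 504 + 1 \cdot 72 + 4 \cdot 9 + 1 \cdot 1
\quad\Rightarrow\quad (c_0, c_1, c_2, c_3) = (3, 1, 4, 1).
\end{aligned}
```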
Returns
- unraveled_coords (tuple of ndarray) – Each array in the tuple has the same shape as the indices array.
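To illustrate the return-shape rule, a short NumPy snippet (NumPy because the examples in this docstring use np) passes a 2-D array of flat indices. The result holds one array per entry of shape, and each of those arrays has the same shape as the input; the variable names are illustrative only.

```python
import numpy as np

# A 2x2 array of flat indices into an array of shape (7, 6).
flat = np.array([[22, 41], [37, 0]])
rows, cols = np.unravel_index(flat, (7, 6))

# One output array per dimension of `shape`, each with the shape of `flat`.
print(rows.shape, cols.shape)  # (2, 2) (2, 2)
print(rows.tolist())           # [[3, 6], [6, 0]]
print(cols.tolist())           # [[4, 5], [1, 0]]
```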
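See also
- ravel_multi_index: the inverse operation, converting a tuple of coordinate arrays into flat indices.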
Examples
>>> np.unravel_index([22, 41, 37], (7,6))
(array([3, 6, 6]), array([4, 5, 1]))
>>> np.unravel_index([31, 41, 13], (7,6), order='F')
(array([3, 6, 6]), array([4, 5, 1]))
>>> np.unravel_index(1621, (6,7,8,9))
(3, 1, 4, 1)
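The JAX signature at the top of this page takes only indices and shape, with no order argument. A column-major (order='F') result can still be recovered from a row-major unravel: unravel over the reversed shape, then reverse the tuple of coordinates. The sketch below checks this identity with NumPy against the order='F' example. `unravel_index_fortran` is a hypothetical helper name, and applying the same rearrangement to jax.numpy.unravel_index is an assumption not tested here.

```python
import numpy as np


def unravel_index_fortran(indices, shape):
    """Hypothetical helper: column-major (order='F') unravel via C order.

    Unravelling over the reversed shape in C order and reversing the
    resulting tuple yields the same coordinates as order='F'.
    """
    coords_c = np.unravel_index(indices, tuple(reversed(shape)))
    return tuple(reversed(coords_c))


idx = [31, 41, 13]
via_c = unravel_index_fortran(idx, (7, 6))
direct = np.unravel_index(idx, (7, 6), order='F')
# Both give row coordinates [3, 6, 6] and column coordinates [4, 5, 1].
print(all(np.array_equal(a, b) for a, b in zip(via_c, direct)))  # True
```

This works because in F order the first axis varies fastest. That is exactly the C-order layout of the reversed shape.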