jax.numpy.unravel_index

jax.numpy.unravel_index(indices, shape)
Converts a flat index or array of flat indices into a tuple of coordinate arrays.
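For C (row-major) order, the conversion is repeated integer division by the dimension sizes, starting from the last axis. The pure-Python sketch below shows that arithmetic for a single in-bounds, non-negative index. The name `unravel_c` is hypothetical. It is not JAX's implementation, which operates on whole arrays.

```python
def unravel_c(flat_index, shape):
    """Convert one flat row-major (C-order) index into a coordinate tuple.

    Hypothetical helper for illustration; assumes 0 <= flat_index < prod(shape).
    """
    coords = []
    remaining = flat_index
    # Walk axes from last to first: in C order the last axis varies fastest.
    for size in reversed(shape):
        remaining, coord = divmod(remaining, size)
        coords.append(coord)
    # Coordinates were collected last-axis-first, so reverse them.
    return tuple(reversed(coords))


print(unravel_c(22, (7, 6)))          # (3, 4)
print(unravel_c(1621, (6, 7, 8, 9)))  # (3, 1, 4, 1)
```

These results agree with the first and third NumPy examples in the docstring.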

LAX-backend implementation of unravel_index(). NumPy's unravel_index raises an error for negative or out-of-bounds indices. This version accepts negative indices and clips out-of-bounds indices instead.
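The description above does not spell out how a negative index maps to coordinates. The pure-Python sketch below shows one consistent reading, under two assumptions. First, indices in [-size, -1] wrap like Python sequence indexing, so -1 is the last element. Second, any index beyond either end clamps to the first or last element. The helper `unravel_clipped` is hypothetical. Check its behavior against jax.numpy.unravel_index itself before relying on it.

```python
import math


def unravel_clipped(flat_index, shape):
    """Sketch of C-order unravel with negative wrapping and clipping.

    Hypothetical helper; the rules are an assumption: indices in [-size, -1]
    wrap Python-style, and anything outside [-size, size - 1] is clamped.
    """
    size = math.prod(shape)
    # Clamp into the range that Python-style indexing would accept.
    clamped = min(max(flat_index, -size), size - 1)
    # Wrap negative indices so that -1 refers to the last element.
    if clamped < 0:
        clamped += size
    coords = []
    # Row-major unravel: divide by the last axis first.
    for dim in reversed(shape):
        clamped, coord = divmod(clamped, dim)
        coords.append(coord)
    return tuple(reversed(coords))


print(unravel_clipped(-1, (7, 6)))    # (6, 5): wraps to the last element
print(unravel_clipped(100, (7, 6)))   # (6, 5): clipped to the last element
print(unravel_clipped(-100, (7, 6)))  # (0, 0): clipped to the first element
```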

Original docstring below.

unravel_index(indices, shape, order='C')

Returns
unraveled_coords : tuple of ndarray

Each array in the tuple has the same shape as the indices array.

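See also

ravel_multi_index : Converts a tuple of coordinate arrays into an array of flat indices (the inverse of unravel_index).

Examples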

>>> np.unravel_index([22, 41, 37], (7,6))
(array([3, 6, 6]), array([4, 5, 1]))
>>> np.unravel_index([31, 41, 13], (7,6), order='F')
(array([3, 6, 6]), array([4, 5, 1]))
>>> np.unravel_index(1621, (6,7,8,9))
(3, 1, 4, 1)
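The order='F' example differs only in which axis varies fastest. In Fortran (column-major) order the first axis changes fastest, so the divisions run from the first axis to the last. The pure-Python sketch below uses a hypothetical `unravel_f` helper and a matching hypothetical `ravel_f` inverse, written in the spirit of ravel_multi_index, to check the round trip. The JAX signature at the top of this page lists only indices and shape, so this sketch illustrates the NumPy semantics only.

```python
def unravel_f(flat_index, shape):
    """Convert one flat column-major (Fortran-order) index to coordinates.

    Hypothetical helper; assumes 0 <= flat_index < prod(shape).
    """
    coords = []
    remaining = flat_index
    # In Fortran order the first axis varies fastest, so walk axes forward.
    for size in shape:
        remaining, coord = divmod(remaining, size)
        coords.append(coord)
    return tuple(coords)


def ravel_f(coords, shape):
    """Hypothetical inverse of unravel_f: coordinates to a flat F-order index."""
    flat, stride = 0, 1
    for coord, size in zip(coords, shape):
        flat += coord * stride
        stride *= size
    return flat


for i in (31, 41, 13):
    c = unravel_f(i, (7, 6))
    print(i, c, ravel_f(c, (7, 6)))
# 31 (3, 4) 31
# 41 (6, 5) 41
# 13 (6, 1) 13
```

The coordinates match the order='F' NumPy example: (3, 6, 6) along the first axis and (4, 5, 1) along the second.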