jax.numpy.argwhere

jax.numpy.argwhere(a)

Find the indices of array elements that are non-zero, grouped by element.

LAX-backend implementation of numpy.argwhere().

Original docstring below.

Parameters

a (array_like) – Input data.

Returns

index_array – Indices of the non-zero elements, grouped by element: each row holds the full index of one non-zero element. The array has shape (N, a.ndim), where N is the number of non-zero elements.

Return type

(N, a.ndim) ndarray
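
Example

A minimal sketch of how the result is laid out, assuming the function is called eagerly (outside jax.jit), since the output shape depends on the values in a. The input values and the names a and idx are illustrative only and do not come from the original docstring.

```python
import jax.numpy as jnp

# Illustrative 2-D input with three non-zero entries.
a = jnp.array([[0, 1, 0],
               [2, 0, 3]])

# Each row of idx is the (row, column) index of one non-zero element,
# listed in row-major order.
idx = jnp.argwhere(a)

print(idx.shape)  # (3, 2): N = 3 non-zero elements, a.ndim = 2
print(idx)
# [[0 1]
#  [1 0]
#  [1 2]]
```

Each row of idx can be used to look up the corresponding element, for example a[idx[0, 0], idx[0, 1]] is the first non-zero value of a.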