jax.numpy.argwhere

jax.numpy.argwhere(a, *, size=None)

Find the indices of array elements that are non-zero, grouped by element.
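A minimal sketch of the basic behaviour, assuming a standard JAX install; the array x is an arbitrary illustrative input. Each output row is the complete index of one non-zero element:

```python
import jax.numpy as jnp

# Illustrative 2-D input; the non-zero entries sit at (0, 1), (1, 0) and (1, 2).
x = jnp.array([[0, 3, 0],
               [5, 0, 7]])

# One row per non-zero element, holding its (row, col) index,
# listed in row-major (C) order.
idx = jnp.argwhere(x)
print(idx)
# [[0 1]
#  [1 0]
#  [1 2]]
```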

LAX-backend implementation of numpy.argwhere().

Because the size of the output of argwhere depends on the values in the input, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which sets the size of the leading dimension of the output; it must be given as a static value for jnp.argwhere to be traced. If size is specified, the output has exactly size rows: only the first size non-zero elements (in row-major order) are reported, and if there are fewer non-zero elements than size, the remaining rows are filled with zeros.
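A minimal sketch of using size under jax.jit. The wrapper name first_nonzero_indices is hypothetical, chosen only for illustration; size is marked static with static_argnames so that the output shape is known at trace time:

```python
from functools import partial

import jax
import jax.numpy as jnp

# `size` fixes the output shape, so it must be a static (trace-time) value.
@partial(jax.jit, static_argnames="size")
def first_nonzero_indices(a, size):  # hypothetical helper name
    # Always returns `size` rows: missing rows are zero-filled,
    # extra non-zero elements beyond `size` are dropped.
    return jnp.argwhere(a, size=size)

a = jnp.array([0, 2, 0, 4])
print(first_nonzero_indices(a, size=4))
# [[1]
#  [3]
#  [0]
#  [0]]
print(first_nonzero_indices(a, size=1))
# [[1]]
```

Because the padding rows are zeros, they cannot be told apart from a genuine index 0 by value alone. One option is to track the number of valid rows separately, for example with jnp.count_nonzero(a).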

Original docstring below.

Parameters

a (array_like) – Input data.

Returns

index_array – Indices of elements that are non-zero, grouped by element. This array has shape (N, a.ndim), where N is the number of non-zero items (or size, if it is specified).

Return type

(N, a.ndim) ndarray
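A short sketch contrasting "grouped by element" with jnp.nonzero, which groups the same indices by dimension. It assumes a 2-D input x chosen only for illustration. Stacking the per-axis arrays from jnp.nonzero as columns gives the (N, a.ndim) array that argwhere returns:

```python
import jax.numpy as jnp

# Illustrative input with non-zero entries at (0, 0), (1, 1), (2, 0), (2, 1).
x = jnp.array([[1, 0],
               [0, 1],
               [1, 1]])

# nonzero groups indices by dimension: one 1-D array per axis.
rows, cols = jnp.nonzero(x)

# argwhere groups them by element: one row per non-zero entry,
# so the result has shape (N, x.ndim) with N = 4 here.
idx = jnp.argwhere(x)
assert idx.shape == (4, x.ndim)

# Using the per-axis arrays as columns reproduces the argwhere result.
assert (idx == jnp.stack([rows, cols], axis=1)).all()
```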