jax.numpy.argwhere

jax.numpy.argwhere(a, *, size=None, fill_value=None)

Find the indices of array elements that are non-zero, grouped by element.

LAX-backend implementation of numpy.argwhere().
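Here is a minimal sketch of the basic behavior, assuming a recent JAX release with the default 32-bit integer indices. Each row of the result is the full coordinate of one non-zero element, listed in row-major order. This is what "grouped by element" means. By contrast, jnp.nonzero returns one index array per dimension.

```python
import jax.numpy as jnp

x = jnp.array([[0, 1, 0],
               [2, 0, 3]])

# One row per non-zero element: (row, column) coordinates in row-major order.
idx = jnp.argwhere(x)
# [[0 1]
#  [1 0]
#  [1 2]]

# jnp.nonzero groups the same indices by dimension instead:
# (Array([0, 1, 1]), Array([1, 0, 2])). Stacking them column-wise
# reproduces the argwhere layout.
by_dim = jnp.nonzero(x)
stacked = jnp.stack(by_dim, axis=1)
```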

Because the size of the output of argwhere depends on the data, the function cannot be used under jax.jit unless its output size is fixed in advance. The JAX version adds the optional size argument for this purpose. To use jnp.argwhere within JAX transformations such as jax.jit, size must be specified statically, as a concrete Python integer rather than a traced value.
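The sketch below shows one way to call jnp.argwhere inside jax.jit. The helper name nonzero_coords is hypothetical. Marking size as a static argument keeps the output shape fixed at trace time. If the input has more non-zero elements than size, only the first size of them are returned.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `size` is static: each distinct value triggers a separate trace with a
# fixed output shape of (size, x.ndim).
@partial(jax.jit, static_argnames="size")
def nonzero_coords(x, size):
    return jnp.argwhere(x, size=size)


x = jnp.array([0, 5, 0, 7, 9])

first_two = nonzero_coords(x, size=2)  # third non-zero entry (index 4) is dropped
# [[1]
#  [3]]

all_three = nonzero_coords(x, size=3)
# [[1]
#  [3]
#  [4]]
```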

Original docstring below.

Parameters:
  • a (array_like) – Input data.

  • size (int, optional) – If specified, the indices of the first size non-zero elements will be returned. If there are fewer results than size indicates, the return value will be padded with fill_value.

  • fill_value (array_like, optional) – When size is specified and there are fewer than size non-zero elements, the remaining rows of the output will be filled with fill_value, which defaults to zero.
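The example below shows how padding works. The default fill_value of zero is also a valid index, so padded rows can be mistaken for real coordinates. Passing a sentinel such as -1 makes the padding easy to spot. The value -1 is only an illustrative choice, not a library convention. jnp.count_nonzero gives the number of real rows.

```python
import jax.numpy as jnp

x = jnp.array([[0, 4],
               [0, 0]])  # a single non-zero element, at (0, 1)

# Default fill_value of 0: padded rows look like the valid index (0, 0).
padded = jnp.argwhere(x, size=3)
# [[0 1]
#  [0 0]
#  [0 0]]

# A sentinel that can never be a valid index makes padding unambiguous.
padded_neg = jnp.argwhere(x, size=3, fill_value=-1)
# [[ 0  1]
#  [-1 -1]
#  [-1 -1]]

# Number of leading rows that are real coordinates rather than padding.
valid = jnp.count_nonzero(x)
```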

Returns:

index_array – Indices of elements that are non-zero. Indices are grouped by element. This array will have shape (N, a.ndim), where N is the number of non-zero items (or size, if it is specified).

Return type:

(N, a.ndim) ndarray
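To illustrate the (N, a.ndim) shape, the sketch below builds a 3-D array with two non-zero entries. It then uses the returned coordinates to gather those values back. Indexing with tuple(idx.T) is one common idiom for this, not the only one.

```python
import jax.numpy as jnp

# 2x3x4 array of zeros with two non-zero entries.
a = jnp.zeros((2, 3, 4)).at[0, 2, 1].set(1.0).at[1, 0, 3].set(2.0)

idx = jnp.argwhere(a)
# shape (2, 3): N = 2 non-zero items, a.ndim = 3 columns
# [[0 2 1]
#  [1 0 3]]

# Transposing gives one index array per axis, which can be used directly
# as an advanced-indexing tuple to read the non-zero values.
values = a[tuple(idx.T)]
# [1. 2.]
```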