jax.numpy.argwhere
- jax.numpy.argwhere(a, *, size=None, fill_value=None)
Find the indices of array elements that are non-zero, grouped by element.
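"Grouped by element" means each row of the result holds the full index of one non-zero element. By contrast, jnp.nonzero returns one index array per dimension. The short sketch below compares the two on an illustrative 2-D array x. It assumes only the standard jax.numpy functions shown.

```python
import jax.numpy as jnp

# Illustrative input: non-zero entries at (0, 1), (1, 0) and (1, 2).
x = jnp.array([[0, 1, 0],
               [2, 0, 3]])

# argwhere: one row per non-zero element, each row is (row, col).
rows_per_element = jnp.argwhere(x)

# nonzero: a tuple with one index array per dimension.
per_dimension = jnp.nonzero(x)

# Stacking the per-dimension arrays as columns recovers argwhere's layout.
stacked = jnp.stack(per_dimension, axis=1)

print(rows_per_element.tolist())  # [[0, 1], [1, 0], [1, 2]]
```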
LAX-backend implementation of numpy.argwhere(). Because the size of the output of argwhere is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which must be specified statically for jnp.argwhere to be used within JAX transformations such as jax.jit. Original docstring below.
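The sketch below calls jnp.argwhere inside jax.jit. It marks size as a static argument so that the output shape, (size, x.ndim), is known at trace time. The wrapper name first_nonzero_indices is hypothetical and exists only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical helper. Without size, tracing jnp.argwhere(x) under jax.jit
# fails, because the output shape would depend on the values in x.
# Marking size as static fixes the output shape at (size, x.ndim).
@partial(jax.jit, static_argnames="size")
def first_nonzero_indices(x, size):
    return jnp.argwhere(x, size=size)

x = jnp.array([0, 3, 0, 5, 7])             # non-zeros at indices 1, 3, 4
result = first_nonzero_indices(x, size=2)  # keep only the first two
print(result.tolist())                     # [[1], [3]]
```

Each distinct value of size triggers a separate compilation, because static arguments become part of the cache key.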
- Parameters
a (array_like) – Input data.
size (int, optional) – If specified, the indices of the first size non-zero elements will be returned. If there are fewer results than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
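With the default fill_value of zero, padded rows look like genuine indices. The sketch below passes fill_value=-1 so that the padding can be told apart from real results. The input array is illustrative.

```python
import jax.numpy as jnp

# Illustrative input with a single non-zero element, at (1, 1).
x = jnp.array([[0, 0],
               [0, 4]])

# size=3 requests three rows, so two rows of padding are added.
default_pad = jnp.argwhere(x, size=3)                # padding rows are 0
custom_pad = jnp.argwhere(x, size=3, fill_value=-1)  # padding rows are -1

# With a negative fill value, real rows are the ones with index >= 0.
num_valid = (custom_pad[:, 0] >= 0).sum()

print(default_pad.tolist())  # [[1, 1], [0, 0], [0, 0]]
print(custom_pad.tolist())   # [[1, 1], [-1, -1], [-1, -1]]
```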
- Returns
index_array – Indices of elements that are non-zero. Indices are grouped by element. This array will have shape (N, a.ndim), where N is the number of non-zero items, or N = size when size is specified.
- Return type
(N, a.ndim) ndarray
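The shape rule (N, a.ndim) can be checked directly. The sketch below uses illustrative arrays of different ranks. It also shows that an all-zero input gives N = 0, and that passing size fixes N.

```python
import jax.numpy as jnp

a1 = jnp.array([0, 2, 0, 4])                     # 1-D, two non-zeros
a3 = jnp.zeros((2, 3, 4)).at[1, 2, 3].set(1.0)   # 3-D, one non-zero
empty = jnp.zeros((2, 3))                        # 2-D, no non-zeros

print(jnp.argwhere(a1).shape)          # (2, 1): N=2, a1.ndim=1
print(jnp.argwhere(a3).shape)          # (1, 3): N=1, a3.ndim=3
print(jnp.argwhere(a3).tolist())       # [[1, 2, 3]]
print(jnp.argwhere(empty).shape)       # (0, 2): N=0, empty.ndim=2
print(jnp.argwhere(a3, size=4).shape)  # (4, 3): N equals size
```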