jax.numpy.nonzero
- jax.numpy.nonzero(a, *, size=None, fill_value=None)
Return the indices of the elements that are non-zero.
LAX-backend implementation of numpy.nonzero().
Because the size of the output of nonzero is data-dependent, the function is typically not compatible with JIT. The JAX version adds the optional size argument, which must be specified statically for jnp.nonzero to be used within some of JAX’s transformations.
Original docstring below.
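The following is a minimal sketch of why size matters under jax.jit. The helper names first_three_nonzero and all_nonzero are hypothetical, and size=3 is chosen only because the example array happens to contain three non-zero entries. In the JAX versions we are aware of, omitting size inside jit fails at trace time with jax.errors.ConcretizationTypeError; treat that exact exception type as an assumption about your installed version.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5, 7])

# Eager (un-jitted) call: size may be omitted because the output length
# can be computed from the concrete values of x.
(idx_eager,) = jnp.nonzero(x)

# Under jit the output shape must be known at trace time, so size is
# passed as a static Python int. (Hypothetical helper name.)
@jax.jit
def first_three_nonzero(v):
    return jnp.nonzero(v, size=3)

(idx_jit,) = first_three_nonzero(x)

# A size smaller than the number of non-zeros keeps only the first ones.
(idx_two,) = jnp.nonzero(x, size=2)

# Without size, tracing fails because the output shape depends on the data.
# (Hypothetical helper name; exception type assumed for recent JAX versions.)
@jax.jit
def all_nonzero(v):
    return jnp.nonzero(v)

try:
    all_nonzero(x)
    size_required = False
except jax.errors.ConcretizationTypeError:
    size_required = True

print(idx_eager, idx_jit, idx_two, size_required)
```

Here idx_eager and idx_jit both hold the indices 1, 3, 4, while idx_two holds only 1, 3.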
Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.
To group the indices by element, rather than dimension, use argwhere, which returns a row for each non-zero element.
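A small sketch contrasting the per-dimension layout of nonzero with the per-element layout of argwhere on a 2-D array, and showing the row-major ordering described above. The array values are arbitrary illustrations.

```python
import jax.numpy as jnp

a = jnp.array([[1, 0, 2],
               [0, 3, 0]])

# nonzero: one index array per dimension, scanned in row-major (C) order.
rows, cols = jnp.nonzero(a)

# argwhere: one row per non-zero element, i.e. the same indices stacked.
coords = jnp.argwhere(a)

# The two layouts are transposes of each other.
same_indices = bool(jnp.array_equal(coords, jnp.stack([rows, cols], axis=1)))

# Indexing with the tuple from nonzero gathers the non-zero values.
values = a[jnp.nonzero(a)]

print(rows, cols, coords, values, same_indices)
```

Because the scan is row-major, rows is [0, 0, 1] and cols is [0, 2, 1], so the gathered values come out as 1, 2, 3.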
Note
When called on a zero-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).
Deprecated since version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.
- Parameters:
a (array_like) – Input array.
size (int, optional) – If specified, the indices of the first size non-zero elements will be returned. If there are fewer non-zero elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer non-zero elements than size, the remaining positions are filled with fill_value, which defaults to zero.
- Returns:
tuple_of_arrays – Indices of elements that are non-zero.
- Return type:
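A brief sketch of how size and fill_value interact when the array has fewer non-zero entries than requested. With the default fill_value of zero, padding is indistinguishable from a genuine index 0, so this example uses -1 as a sentinel; that choice is illustrative, not something the API requires.

```python
import jax.numpy as jnp

x = jnp.array([0, 4, 0, 0, 9])  # only two non-zero entries

# size=4 asks for four indices; the two missing slots get the default fill (0).
(padded_default,) = jnp.nonzero(x, size=4)

# A sentinel fill_value of -1 keeps padding distinguishable from index 0.
(padded_sentinel,) = jnp.nonzero(x, size=4, fill_value=-1)

# Mask off the padding before using the indices, e.g. to count real hits.
valid = padded_sentinel >= 0
n_valid = int(valid.sum())

print(padded_default, padded_sentinel, n_valid)
```

Here padded_default holds 1, 4, 0, 0 and padded_sentinel holds 1, 4, -1, -1, so only the sentinel version lets the mask recover the two real indices.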