- jax.numpy.nonzero(a, *, size=None, fill_value=None)
Return the indices of the elements that are non-zero.
LAX-backend implementation of
Because the size of the output of nonzero is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument which must be specified statically for jnp.nonzero to be used within some of JAX’s transformations.
Original docstring below.
Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.
To group the indices by element, rather than dimension, use argwhere, which returns a row for each non-zero element.
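To illustrate the two layouts, a small sketch comparing jnp.nonzero with jnp.argwhere on a 2-D array (the array values are arbitrary):

```python
import jax.numpy as jnp

a = jnp.array([[0, 2, 0],
               [4, 0, 6]])

rows, cols = jnp.nonzero(a)  # one index array per dimension
pairs = jnp.argwhere(a)      # one (row, col) row per non-zero element

# Row-major order visits (0, 1), (1, 0), (1, 2):
# rows -> [0, 1, 1], cols -> [1, 0, 2]
# pairs -> [[0, 1], [1, 0], [1, 2]]

# Both hold the same indices: stacking nonzero's output column-wise
# reproduces argwhere.
assert (jnp.stack((rows, cols), axis=-1) == pairs).all()

# The per-dimension tuple can index `a` directly to gather the values.
values = a[rows, cols]       # -> [2, 4, 6]
```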
When called on a zero-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).
Deprecated since version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.
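Following the deprecation note, a sketch that promotes a scalar with jnp.atleast_1d before calling jnp.nonzero. This avoids relying on how a particular JAX version handles 0-d input, which this page does not pin down.

```python
import jax.numpy as jnp

scalar = jnp.array(5)  # a zero-d array

# Promote to 1-D explicitly instead of relying on implicit 0-d handling.
(idx,) = jnp.nonzero(jnp.atleast_1d(scalar))
# idx -> [0]: the single element is non-zero

(idx_zero,) = jnp.nonzero(jnp.atleast_1d(jnp.array(0)))
# idx_zero has shape (0,): there are no non-zero elements
```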
- Parameters
a (array_like) – Input array.
size (int, optional) – If specified, the indices of the first size non-zero elements will be returned. If there are fewer non-zero elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
- Returns
tuple_of_arrays – Indices of elements that are non-zero.
- Return type