- jax.numpy.nonzero(a, *, size=None, fill_value=None)
Return the indices of the elements that are non-zero.
LAX-backend implementation of numpy.nonzero().
Because the size of the output of nonzero is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the output arrays; it must be specified statically for jnp.nonzero to be compiled with non-static operands. If specified, the first size nonzero elements will be returned; if there are fewer nonzero elements than size indicates, the result will be padded with fill_value, which defaults to zero.
fill_value may be a scalar, or a tuple specifying the fill value in each dimension.
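A minimal sketch of how size and fill_value interact with jax.jit. The input array x and the helper name padded_nonzero are illustrative only; size is passed through static_argnames so that the output shape is fixed at trace time, which is the requirement described above.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

x = jnp.array([[0, 3, 0],
               [4, 0, 5]])

# Eager (un-jitted) call: x is concrete, so size may be omitted and the
# output length (3 nonzero elements here) is computed from the data.
rows, cols = jnp.nonzero(x)

# Hypothetical helper: under jit the output shape must be static, so size
# is marked static. Slots past the last nonzero element get fill_value.
@functools.partial(jax.jit, static_argnames="size")
def padded_nonzero(a, size):
    return jnp.nonzero(a, size=size, fill_value=-1)

rows4, cols4 = padded_nonzero(x, size=4)  # 3 real entries + 1 padded slot
rows2, cols2 = padded_nonzero(x, size=2)  # only the first 2 nonzero elements

# fill_value may also be a tuple giving one fill value per dimension.
rows_t, cols_t = jax.jit(
    lambda a: jnp.nonzero(a, size=4, fill_value=(-1, -2))
)(x)

print(np.asarray(rows4).tolist(), np.asarray(cols4).tolist())
```

Because the output length no longer depends on the data, the jitted function compiles once per distinct size value; callers must then treat the padded entries (here -1) as "no element".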
Original docstring below.
Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.
To group the indices by element, rather than dimension, use argwhere, which returns a row for each non-zero element.
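A short sketch contrasting the per-dimension tuple returned by nonzero with the per-element rows returned by argwhere. The array x is illustrative; stacking the tuple along a new last axis is one way to relate the two layouts.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.array([[0, 3, 0],
               [4, 0, 5]])

# nonzero groups indices by dimension: one index array per axis.
idx = jnp.nonzero(x)  # (rows, cols)

# The tuple can be used directly as an index; the selected values
# come out in row-major (C-style) order.
values = x[idx]

# argwhere groups indices by element instead: one row per nonzero element.
by_element = jnp.argwhere(x)

# Stacking the per-dimension arrays column-wise gives the same layout.
stacked = jnp.stack(idx, axis=1)
```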
When called on a zero-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).
Deprecated since version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.
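Following the deprecation note, a sketch that promotes a zero-d input explicitly before calling nonzero. Whether jnp.nonzero accepts a zero-d argument directly may vary between JAX versions, so this sketch avoids relying on that; the scalars used are illustrative.

```python
import jax.numpy as jnp
import numpy as np

s = jnp.array(7)  # zero-d, nonzero value
z = jnp.array(0)  # zero-d, zero value

# Promote to shape (1,) explicitly instead of relying on implicit handling.
(idx,) = jnp.nonzero(jnp.atleast_1d(s))       # one index: position 0
(idx_zero,) = jnp.nonzero(jnp.atleast_1d(z))  # no nonzero elements: empty
```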
- Parameters
a (array_like) – Input array.
- Returns
tuple_of_arrays – Indices of elements that are non-zero.
- Return type
tuple