jax.numpy.nonzero(a, *, size=None)

Return the indices of the elements that are non-zero.

LAX-backend implementation of nonzero().

Because the size of the output of nonzero depends on the data, the function cannot be traced under jax.jit unless the output size is fixed in advance. The JAX version adds the optional size argument for this purpose. It sets the length of each returned index array and must be a static value for jnp.nonzero to be traced. If size is given, only the first size nonzero elements are returned. If there are fewer nonzero elements than size, the index arrays are zero-padded.
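The sketch below shows the size argument inside a jitted function. The helper name first_nonzero_indices and the bound size=3 are hypothetical choices for illustration. The first call has fewer nonzero elements than size, so the result is zero-padded. The second call has more, so the result is truncated.

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_nonzero_indices(x):
    # size=3 is a hypothetical, hard-coded bound: it fixes the output shape
    # so that jnp.nonzero can be traced under jax.jit.
    return jnp.nonzero(x, size=3)

x = jnp.array([0, 5, 0, 7])
(idx,) = first_nonzero_indices(x)
print(idx)  # [1 3 0]: the two real indices, then one zero of padding

# More nonzero elements than size: only the first 3 indices are kept.
(idx_trunc,) = first_nonzero_indices(jnp.array([1, 2, 3, 4]))
print(idx_trunc)  # [0 1 2]
```

Because padding uses index 0, a padded entry cannot be told apart from a genuine nonzero at position 0 by the index alone. Checking x[idx] != 0, or comparing against jnp.count_nonzero(x), is one way to tell them apart.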

Original docstring below.

Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.
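As an illustration with an arbitrary 3×3 array, a 2-D input yields one index array per dimension. Elements are visited in row-major order. The returned tuple can be used directly to gather the nonzero values.

```python
import jax.numpy as jnp

a = jnp.array([[3, 0, 0],
               [0, 4, 0],
               [5, 6, 0]])

# One index array per dimension: row indices first, then column indices.
rows, cols = jnp.nonzero(a)
print(rows)  # [0 1 2 2]
print(cols)  # [0 1 0 1]

# Indexing with the tuple gathers the nonzero values in row-major order.
values = a[rows, cols]
print(values)  # [3 4 5 6]
```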

To group the indices by element, rather than dimension, use argwhere, which returns a row for each non-zero element.
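The difference between the two groupings can be sketched with the same example array. jnp.argwhere returns a single array with one row per nonzero element. Stacking the per-dimension arrays from jnp.nonzero as columns gives the same values.

```python
import jax.numpy as jnp

a = jnp.array([[3, 0, 0],
               [0, 4, 0],
               [5, 6, 0]])

by_dim = jnp.nonzero(a)    # tuple of 2 arrays, each of shape (4,)
by_elem = jnp.argwhere(a)  # one array of shape (4, 2): a row per element
print(by_elem)
# [[0 0]
#  [1 1]
#  [2 0]
#  [2 1]]

# Stacking the per-dimension index arrays as columns reproduces argwhere.
same = bool(jnp.array_equal(by_elem, jnp.stack(by_dim, axis=-1)))
print(same)  # True
```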


When called on a 0-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).

Deprecated since version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.
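The explicit form recommended by the deprecation note might look like the following minimal sketch. It promotes the 0-d input with jnp.atleast_1d before calling jnp.nonzero, so it does not rely on the implicit, deprecated behavior. Whether a given JAX version accepts a 0-d input directly is not assumed here.

```python
import jax.numpy as jnp

x = jnp.array(7)  # a 0-d array

# Promote to 1-d explicitly rather than relying on implicit promotion.
(idx,) = jnp.nonzero(jnp.atleast_1d(x))
print(idx)  # [0]

# A zero scalar has no nonzero elements, so the index array is empty.
(idx_zero,) = jnp.nonzero(jnp.atleast_1d(jnp.array(0)))
print(idx_zero.shape)  # (0,)
```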

Parameters

a (array_like) – Input array.

Returns

tuple_of_arrays – Indices of elements that are non-zero.

Return type