jax.numpy.nonzero

jax.numpy.nonzero(a, *, size=None, fill_value=None)

Return the indices of the elements that are non-zero.

LAX-backend implementation of numpy.nonzero().

Because the size of the output of nonzero is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.nonzero to be used within some of these transformations.
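A minimal sketch of using nonzero under jax.jit by fixing the output length with size. The wrapper name first_nonzero_indices is hypothetical; size is marked static via static_argnames because it must be a concrete Python int at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `size` determines the output shape, so it must be static for jit.
@partial(jax.jit, static_argnames="size")
def first_nonzero_indices(x, size):
    # With a fixed `size`, the output shape no longer depends on the data,
    # so the function can be traced and compiled.
    (idx,) = jnp.nonzero(x, size=size)
    return idx


x = jnp.array([0, 3, 0, 5, 7])
print(first_nonzero_indices(x, size=2))  # [1 3]
```

Calling jnp.nonzero inside a jitted function without size fails, because the number of non-zero elements is not known when the function is traced.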

Original docstring below.

Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.
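For a 2-D input, the tuple holds one index array per axis, in row-major order. The arrays can be used directly to gather the non-zero values, as this small sketch shows:

```python
import jax.numpy as jnp

a = jnp.array([[1, 0, 0],
               [0, 2, 3]])

# One index array per dimension: row indices, then column indices.
rows, cols = jnp.nonzero(a)
print(rows)           # [0 1 1]
print(cols)           # [0 1 2]

# Fancy indexing with the tuple recovers the non-zero values.
print(a[rows, cols])  # [1 2 3]
```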

To group the indices by element, rather than dimension, use argwhere, which returns a row for each non-zero element.
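The relationship between the two layouts can be sketched by stacking the per-dimension arrays from nonzero as columns, which gives the same result as argwhere:

```python
import jax.numpy as jnp

a = jnp.array([[1, 0, 0],
               [0, 2, 3]])

# argwhere: one row per non-zero element, one column per dimension.
print(jnp.argwhere(a))
# [[0 0]
#  [1 1]
#  [1 2]]

# The same layout, built by stacking nonzero's per-dimension arrays.
print(jnp.stack(jnp.nonzero(a), axis=-1))
```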

Note

When called on a zero-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).

Deprecated since NumPy version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.
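A short sketch of the explicit form recommended above, promoting a zero-dimensional array with atleast_1d before calling nonzero:

```python
import jax.numpy as jnp

scalar = jnp.array(5)  # a zero-dimensional array

# Promote to 1-D explicitly instead of relying on the deprecated implicit behavior.
(idx,) = jnp.nonzero(jnp.atleast_1d(scalar))
print(idx)  # [0]
```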

Parameters:
  • a (array_like) – Input array.

  • size (int, optional) – If specified, the indices of the first size non-zero elements will be returned. If there are fewer non-zero elements than size indicates, the return value will be padded with fill_value.

  • fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
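The padding and truncation behavior of size and fill_value can be illustrated on a 1-D input with two non-zero entries:

```python
import jax.numpy as jnp

x = jnp.array([0, 4, 0, 6])  # two non-zero entries, at indices 1 and 3

# Pad to a fixed length of 4; padding uses the default fill_value of 0.
(padded,) = jnp.nonzero(x, size=4)
print(padded)     # [1 3 0 0]

# A sentinel such as -1 separates padding from a genuine index 0.
(sentinel,) = jnp.nonzero(x, size=4, fill_value=-1)
print(sentinel)   # [ 1  3 -1 -1]

# If size is smaller than the number of non-zeros, the result is truncated.
(truncated,) = jnp.nonzero(x, size=1)
print(truncated)  # [1]
```

With the default fill_value of 0, padded entries are indistinguishable from a real index 0. A sentinel like -1 avoids that ambiguity, but note that using -1 directly as an index selects the last element, so padded entries should be masked before gathering.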

Returns:

tuple_of_arrays – Indices of elements that are non-zero.

Return type:

tuple