# jax.numpy.nonzero

jax.numpy.nonzero(a)

Return the indices of the elements that are non-zero.

LAX-backend implementation of nonzero(). At present, JAX does not support JIT-compilation of jax.numpy.nonzero() because its output shape is data-dependent.
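The restriction above can be illustrated with a minimal sketch. It assumes the data-dependent call is made eagerly (outside jax.jit), and it uses a mask with jnp.where as one possible fixed-shape alternative inside a jitted function. The helper name sum_of_nonzero_squares is hypothetical and chosen only for this illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])

# Eager call: the length of each index array depends on the values in x.
rows, cols = jnp.nonzero(x)

# Hypothetical helper: a jit-friendly computation over the non-zero entries.
# The mask has the same shape as the input, so every shape is static.
@jax.jit
def sum_of_nonzero_squares(a):
    mask = a != 0
    return jnp.sum(jnp.where(mask, a * a, 0))

total = sum_of_nonzero_squares(x)
```

Whether a mask can stand in for explicit indices depends on the computation; reductions and elementwise updates usually can be written this way.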

Original docstring below.

Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.

To group the indices by element, rather than dimension, use argwhere, which returns a row for each non-zero element.
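The difference between the two groupings can be sketched as follows, assuming eager execution and that jnp.argwhere is available in the installed JAX version. Stacking the per-dimension arrays along a new last axis should give the same rows as argwhere.

```python
import jax.numpy as jnp

x = jnp.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])

# Grouped by dimension: a tuple with one index array per axis of x.
per_dim = jnp.nonzero(x)

# Grouped by element: one row of coordinates per non-zero element.
per_elem = jnp.argwhere(x)

# Stacking the per-dimension arrays column-wise reproduces the per-element layout.
stacked = jnp.stack(per_dim, axis=1)
```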

Note

When called on a zero-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).

Deprecated since version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.

Parameters

a (array_like) – Input array.

Returns

tuple_of_arrays – Indices of elements that are non-zero.

Return type

tuple

See also

flatnonzero()

Return indices that are non-zero in the flattened version of the input array.

ndarray.nonzero()

Equivalent ndarray method.

count_nonzero()

Counts the number of non-zero elements in the input array.
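A short sketch of the three related functions listed above, assuming eager execution and that the JAX counterparts (jnp.flatnonzero, the array method nonzero, and jnp.count_nonzero) mirror the NumPy behavior described here:

```python
import jax.numpy as jnp

x = jnp.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])

# Indices of the non-zero entries in the flattened (row-major) array.
flat_idx = jnp.flatnonzero(x)

# Method form; returns the same tuple of per-dimension index arrays as jnp.nonzero(x).
method_idx = x.nonzero()

# Number of non-zero entries; a scalar, so its shape does not depend on the data.
n = jnp.count_nonzero(x)
```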

Notes

While the nonzero values can be obtained with a[nonzero(a)], it is recommended to use a[a.astype(bool)] or a[a != 0] instead, which will correctly handle 0-d arrays.

Examples

>>> x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])
>>> x
array([[3, 0, 0],
       [0, 4, 0],
       [5, 6, 0]])
>>> np.nonzero(x)
(array([0, 1, 2, 2]), array([0, 1, 0, 1]))

>>> x[np.nonzero(x)]
array([3, 4, 5, 6])
>>> np.transpose(np.nonzero(x))
array([[0, 0],
       [1, 1],
       [2, 0],
       [2, 1]])


A common use for nonzero is to find the indices of an array where a condition is True. Given an array a, the condition a > 3 is a boolean array, and since False is interpreted as 0, np.nonzero(a > 3) yields the indices of a where the condition is true.

>>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> a > 3
array([[False, False, False],
       [ True,  True,  True],
       [ True,  True,  True]])
>>> np.nonzero(a > 3)
(array([1, 1, 1, 2, 2, 2]), array([0, 1, 2, 0, 1, 2]))


Using this result to index a is equivalent to using the mask directly:

>>> a[np.nonzero(a > 3)]
array([4, 5, 6, 7, 8, 9])
>>> a[a > 3]  # prefer this spelling
array([4, 5, 6, 7, 8, 9])


nonzero can also be called as a method of the array.

>>> (a > 3).nonzero()
(array([1, 1, 1, 2, 2, 2]), array([0, 1, 2, 0, 1, 2]))