jax.numpy.flatnonzero

jax.numpy.flatnonzero(a)

Return the indices of the non-zero elements in the flattened version of a.

LAX-backend implementation of numpy.flatnonzero(). The original NumPy docstring follows.

This is equivalent to np.nonzero(np.ravel(a))[0].
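
The equivalence can be checked directly. The following is a minimal sketch, assuming jax.numpy is imported as jnp; the array values and the names flat_idx and via_nonzero are illustrative only and do not come from the original docstring.

```python
import jax.numpy as jnp

# Illustrative 2-D input; flattened it reads [0, 3, 0, 5, 0, 7].
a = jnp.array([[0, 3, 0],
               [5, 0, 7]])

# Indices of the non-zero entries of the flattened array.
flat_idx = jnp.flatnonzero(a)

# The same result, spelled out as nonzero() applied to ravel().
via_nonzero = jnp.nonzero(jnp.ravel(a))[0]

print(flat_idx)     # positions 1, 3 and 5 of the flattened array
print(via_nonzero)  # same positions
```

Both calls run eagerly here. The number of returned indices depends on the values in a, not only on its shape.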

Parameters

a (array_like) – Input data.

Returns

res – Output array, containing the indices of the elements of a.ravel() that are non-zero.

Return type

ndarray

See also

nonzero()

Return the indices of the non-zero elements of the input array.

ravel()

Return a 1-D array containing the elements of the input array.

Examples

>>> import numpy as np
>>> x = np.arange(-2, 3)
>>> x
array([-2, -1,  0,  1,  2])
>>> np.flatnonzero(x)
array([0, 1, 3, 4])

Use the indices of the non-zero elements as an index array to extract these elements:

>>> x.ravel()[np.flatnonzero(x)]
array([-2, -1,  1,  2])
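
For a multi-dimensional input, the returned indices refer to positions in a.ravel(), not to rows and columns. The following sketch, assuming jax.numpy is imported as jnp and using illustrative array values, shows one way to extract the non-zero values and to map the flat indices back to coordinates with jnp.unravel_index.

```python
import jax.numpy as jnp

# Illustrative 2-D input with three non-zero entries.
a = jnp.array([[0, 3, 0],
               [5, 0, 7]])

# Flat indices into a.ravel().
idx = jnp.flatnonzero(a)

# Use the flat indices to pull out the non-zero values.
values = a.ravel()[idx]

# Convert flat indices back to (row, column) coordinates in a.
rows, cols = jnp.unravel_index(idx, a.shape)

print(values)      # the non-zero values 3, 5, 7
print(rows, cols)  # rows 0, 1, 1 and columns 1, 0, 2
```

Indexing a[rows, cols] would return the same values as a.ravel()[idx].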