jax.numpy.flatnonzero

jax.numpy.flatnonzero(a)

Return the indices of the non-zero elements in the flattened version of a.

LAX-backend implementation of flatnonzero().

Original docstring below.

This is equivalent to np.nonzero(np.ravel(a))[0].
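A minimal sketch that checks this equivalence on a small 2-D input and compares the result with NumPy's own flatnonzero. The example runs eagerly, not under jax.jit, because the length of the result depends on the values in a. The input values are illustrative only.

```python
import jax.numpy as jnp
import numpy as np

a = jnp.array([[0, 3, 0],
               [5, 0, 7]])

# Flat indices of the non-zero entries of a (a.ravel() is [0, 3, 0, 5, 0, 7]).
idx = jnp.flatnonzero(a)

# The documented equivalence: nonzero() of the flattened array, first element of the tuple.
idx_ref = jnp.nonzero(jnp.ravel(a))[0]

# NumPy's flatnonzero on the same data, for comparison.
idx_np = np.flatnonzero(np.asarray(a))

print(idx)  # [1 3 5]
assert jnp.array_equal(idx, idx_ref)
assert (np.asarray(idx) == idx_np).all()
```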

Parameters

a (array_like) – Input data.

Returns

res – Output array containing the indices of the elements of a.ravel() that are non-zero.

Return type

ndarray
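Because the returned indices address a.ravel(), they can be used directly to gather the non-zero values, or converted back to per-axis coordinates with jnp.unravel_index. This is an illustrative sketch, run eagerly like the one above; the input array is made up for the example.

```python
import jax.numpy as jnp

a = jnp.array([[0.0, 2.5],
               [0.0, -1.0]])

idx = jnp.flatnonzero(a)                      # indices into a.ravel(): 1 and 3
vals = a.ravel()[idx]                         # the non-zero values: 2.5 and -1.0
rows, cols = jnp.unravel_index(idx, a.shape)  # 2-D coordinates: (0, 1) and (1, 1)
```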