jax.numpy.flatnonzero(a, *, size=None, fill_value=None)

Return the indices of nonzero elements in a flattened array.

JAX implementation of numpy.flatnonzero().

jnp.flatnonzero(a) is equivalent to jnp.nonzero(jnp.ravel(a))[0]. For a full discussion of the parameters to this function, refer to jax.numpy.nonzero().

Parameters:

  • a (ArrayLike) – N-dimensional array.

  • size (int | None) – optional static integer specifying the number of nonzero entries to return. See jax.numpy.nonzero() for more discussion of this parameter.

  • fill_value (None | ArrayLike | tuple[ArrayLike, ...]) – optional padding value when size is specified. Defaults to 0. See jax.numpy.nonzero() for more discussion of this parameter.


Returns:

Array containing the indices of each nonzero value in the flattened array.

Return type:

Array

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 8]])
>>> jnp.flatnonzero(x)
Array([1, 3, 5], dtype=int32)

This is equivalent to calling nonzero() on the flattened array, and extracting the first entry in the resulting tuple:

>>> jnp.nonzero(x.ravel())[0]
Array([1, 3, 5], dtype=int32)

The returned indices can be used to extract nonzero entries from the flattened array:

>>> indices = jnp.flatnonzero(x)
>>> x.ravel()[indices]
Array([5, 6, 8], dtype=int32)