jax.numpy.flatnonzero

jax.numpy.flatnonzero(a, *, size=None, fill_value=None)

Return the indices of nonzero elements in the flattened array.

JAX implementation of numpy.flatnonzero().

jnp.flatnonzero(a) is equivalent to jnp.nonzero(jnp.ravel(a))[0]. For a full discussion of the parameters to this function, refer to jax.numpy.nonzero().

Parameters:
  • a (ArrayLike) – N-dimensional array.

  • size (int | None) – optional static integer specifying the number of nonzero entries to return. See jax.numpy.nonzero() for more discussion of this parameter.

  • fill_value (None | ArrayLike | tuple[ArrayLike, ...]) – optional padding value when size is specified. Defaults to 0. See jax.numpy.nonzero() for more discussion of this parameter.
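The following is a minimal sketch of how size and fill_value interact, assuming the truncation and padding behavior described for jax.numpy.nonzero(): when there are more nonzero entries than size, the trailing indices are dropped; when there are fewer, the result is padded with fill_value. The variable names and the helper first_two_nonzero are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 8]])

# More slots than nonzero entries: the extra slots are padded with fill_value.
padded = jnp.flatnonzero(x, size=5, fill_value=-1)

# Fewer slots than nonzero entries: the result is truncated at the end.
truncated = jnp.flatnonzero(x, size=2)

# Under jax.jit the output shape must be known at trace time, so size must be
# a static Python int; omitting it inside a jitted function raises an error.
@jax.jit
def first_two_nonzero(a):
    # Hypothetical helper: fixed-size result, safe to compile.
    return jnp.flatnonzero(a, size=2)
```

Here padded holds [1, 3, 5, -1, -1] and truncated holds [1, 3]. A fill_value outside the valid index range, such as -1, makes padded slots easy to tell apart from real indices.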

Returns:

Array containing the indices of each nonzero value in the flattened array.

Return type:

Array
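Because the result indexes the flattened array, it is one-dimensional regardless of the input's rank. As a sketch, per-axis coordinates can be recovered with jax.numpy.unravel_index(), which should agree with jax.numpy.nonzero() applied to the original, unflattened input:

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 8]])

flat_idx = jnp.flatnonzero(x)   # 1-D: one entry per nonzero element, shape (3,)

# Map positions in x.ravel() back to (row, column) coordinates in x.
rows, cols = jnp.unravel_index(flat_idx, x.shape)

# The same coordinates, computed directly on the 2-D array.
ref_rows, ref_cols = jnp.nonzero(x)
```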

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 8]])
>>> jnp.flatnonzero(x)
Array([1, 3, 5], dtype=int32)

This is equivalent to calling jnp.nonzero() on the flattened array and extracting the first entry of the resulting tuple:

>>> jnp.nonzero(x.ravel())[0]
Array([1, 3, 5], dtype=int32)

The returned indices can be used to extract nonzero entries from the flattened array:

>>> indices = jnp.flatnonzero(x)
>>> x.ravel()[indices]
Array([5, 6, 8], dtype=int32)
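Together, the indices and the extracted values form a simple coordinate-style sparse representation of x. As an illustrative sketch (the reconstruction step is not part of flatnonzero itself), the dense array can be rebuilt using JAX's functional .at[...].set(...) update:

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 8]])

flat = x.ravel()
indices = jnp.flatnonzero(x)   # positions of the nonzeros in `flat`
values = flat[indices]          # the nonzero values themselves

# Scatter the values into a zero array of the flattened shape,
# then restore the original 2-D shape.
rebuilt = jnp.zeros_like(flat).at[indices].set(values).reshape(x.shape)
```

Since only nonzero positions are written, rebuilt matches x exactly.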