jax.numpy.flatnonzero
- jax.numpy.flatnonzero(a, *, size=None, fill_value=None)
Return indices of nonzero elements in a flattened array.
JAX implementation of numpy.flatnonzero().
jnp.flatnonzero(a) is equivalent to nonzero(ravel(a))[0]. For a full discussion of the parameters to this function, refer to jax.numpy.nonzero().
- Parameters:
a (ArrayLike) – N-dimensional array.
size (int | None) – optional static integer specifying the number of nonzero entries to return. See jax.numpy.nonzero() for more discussion of this parameter.
fill_value (ArrayLike | tuple[ArrayLike, ...] | None) – optional padding value when size is specified. Defaults to 0. See jax.numpy.nonzero() for more discussion of this parameter.
- Returns:
Array containing the indices of each nonzero value in the flattened array.
- Return type:
Array
See also
jax.numpy.nonzero()
Examples
>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 8]])
>>> jnp.flatnonzero(x)
Array([1, 3, 5], dtype=int32)
This is equivalent to calling nonzero() on the flattened array, and extracting the first entry in the resulting tuple:
>>> jnp.nonzero(x.ravel())[0]
Array([1, 3, 5], dtype=int32)
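The flat indices are related to the per-axis indices that jnp.nonzero() returns for the original, unflattened array. A minimal sketch, assuming the default row-major (C-order) flattening used by ravel(), converts between the two with jax.numpy.unravel_index(); this is an illustration of the relationship, not part of the flatnonzero API:

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 8]])

# Flat (row-major, C-order) positions of the nonzero entries.
flat = jnp.flatnonzero(x)

# unravel_index turns flat positions back into one index array per axis.
rows, cols = jnp.unravel_index(flat, x.shape)

# For comparison: jnp.nonzero on the unflattened 2-D array.
ref_rows, ref_cols = jnp.nonzero(x)

print(rows, cols)     # row and column of each nonzero entry
print(x[rows, cols])  # the nonzero values themselves
```

The unraveled (rows, cols) pair matches what jnp.nonzero(x) returns directly, so flatnonzero is mainly convenient when a single flat index per element is what you want.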
The returned indices can be used to extract nonzero entries from the flattened array:
>>> indices = jnp.flatnonzero(x)
>>> x.ravel()[indices]
Array([5, 6, 8], dtype=int32)
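Because JAX needs output shapes to be known at trace time, using flatnonzero inside jax.jit generally requires the static size argument. The sketch below, reusing the same x, shows padding with fill_value when size exceeds the number of nonzero entries, truncation when it is smaller, and a jitted wrapper that marks size as static via static_argnames. The helper name padded_flatnonzero and the choice of -1 as the fill value are illustrative assumptions:

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 8]])

# x has 3 nonzero entries. With size=5 the result is padded to length 5
# using fill_value (-1 here instead of the default 0).
padded = jnp.flatnonzero(x, size=5, fill_value=-1)

# With size=2 only the first two flat indices are kept.
truncated = jnp.flatnonzero(x, size=2)


# Hypothetical helper: size is marked static so the output shape is fixed
# at trace time, which is what allows flatnonzero to run under jax.jit.
@partial(jax.jit, static_argnames="size")
def padded_flatnonzero(a, size):
    return jnp.flatnonzero(a, size=size, fill_value=-1)


print(padded, truncated)
print(padded_flatnonzero(x, size=5))
print(padded_flatnonzero(jnp.array([0, 0, 7]), size=2))
```

Each distinct static size triggers a separate compilation. Also note that a padding value of -1 is convenient for masking but would wrap around to the last element if the padded indices were used directly, as in x.ravel()[padded], so padded slots should be masked out before gathering.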