jax.numpy.flatnonzero
- jax.numpy.flatnonzero(a, *, size=None, fill_value=None)
Return the indices of the elements that are non-zero in the flattened version of a.
LAX-backend implementation of numpy.flatnonzero(). Because the size of the output of nonzero is data-dependent, the function is generally not compatible with JIT. The JAX version adds the optional size argument, which must be specified statically for jnp.flatnonzero to be used within JAX transformations such as jax.jit. Original docstring below.
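As a minimal sketch (assuming a JAX version whose jnp.flatnonzero accepts the size and fill_value keywords shown in the signature above), the function can be used inside jax.jit once size is a static Python int. The wrapper name first_nonzero_indices and the fill value -1 are illustrative choices, not part of the API.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `size` is marked static so the output shape is known at trace time.
# `first_nonzero_indices` is a hypothetical helper, not part of jax.numpy.
@partial(jax.jit, static_argnames="size")
def first_nonzero_indices(x, size):
    # Always returns exactly `size` indices into x.ravel();
    # unused slots are padded with -1 (an illustrative fill value).
    return jnp.flatnonzero(x, size=size, fill_value=-1)


x = jnp.array([[0, 3, 0],
               [5, 0, 7]])
print(first_nonzero_indices(x, size=4))  # [ 1  3  5 -1]
print(first_nonzero_indices(x, size=2))  # [1 3]
```

Without size, the same call inside jax.jit typically fails with a concretization error, because the output length would depend on the values of x. Since size is static, each distinct value of size triggers a separate compilation.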
This is equivalent to np.nonzero(np.ravel(a))[0].
- Parameters:
a (array_like) – Input data.
size (int, optional) – If specified, the indices of the first size non-zero elements will be returned. If there are fewer non-zero elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of non-zero elements, the remaining positions of the output will be filled with fill_value, which defaults to zero.
- Returns:
res – Output array containing the indices of the elements of a.ravel() that are non-zero.
- Return type:
ndarray
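The equivalence with np.nonzero(np.ravel(a))[0] can be checked on a small 2-D array. This sketch, with arbitrary example values, also shows that the returned indices refer to positions in a.ravel() rather than (row, column) pairs, and how the default zero fill_value pads the output when size exceeds the number of non-zero elements.

```python
import numpy as np
import jax.numpy as jnp

a = np.array([[0.0, 1.5],
              [0.0, -2.0],
              [4.0, 0.0]])

# NumPy reference: indices of non-zero entries in the flattened array.
expected = np.nonzero(np.ravel(a))[0]
result = jnp.flatnonzero(a)
np.testing.assert_array_equal(result, expected)
print(result)  # [1 3 4]

# size larger than the number of non-zero elements: padded with the
# default fill value (zero).
padded = jnp.flatnonzero(a, size=5)
print(padded)  # [1 3 4 0 0]

# size smaller than the number of non-zero elements: only the first
# `size` indices are kept.
truncated = jnp.flatnonzero(a, size=2)
print(truncated)  # [1 3]
```

Note that with the default fill value, a padded 0 is indistinguishable from a genuine index 0; passing an explicit fill_value such as -1 avoids that ambiguity.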