jax.numpy.flatnonzero(a, *, size=None, fill_value=None)

Return the indices of the elements that are non-zero in the flattened version of a.
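A minimal sketch of the basic behavior, assuming a small 2-D integer input: the returned indices refer to positions in the row-major (C-order) flattening of a, not to row/column coordinates. Mapping them back with jnp.unravel_index is an illustrative extra step, not something flatnonzero does itself.

```python
import jax.numpy as jnp

# A 2-D input; flatnonzero indexes into its row-major flattening [0, 3, 0, 5, 0, 7].
a = jnp.array([[0, 3, 0],
               [5, 0, 7]])

idx = jnp.flatnonzero(a)
print(idx.tolist())  # [1, 3, 5]

# Optional: convert the flat indices back to (row, col) coordinates.
rows, cols = jnp.unravel_index(idx, a.shape)
print(rows.tolist(), cols.tolist())  # [0, 1, 1] [1, 0, 2]
```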

LAX-backend implementation of flatnonzero().

Because the size of the output of nonzero is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the output array; it must be specified statically for jnp.flatnonzero to be compiled with non-static operands. If specified, the indices of the first size nonzero elements are returned; if there are fewer nonzero elements than size indicates, the result is padded with fill_value, which defaults to zero. fill_value may be a scalar, or a tuple specifying the fill value in each dimension.
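The following sketch shows how size and fill_value interact under jax.jit, assuming a 1-D integer input with three nonzero entries. The helper names first_two, pad_default and pad_custom are hypothetical and exist only for this illustration. In each case size is a plain Python int inside the traced function, so it is static and the output shape is fixed at (size,) regardless of the data.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 2, 0, 4, 6, 0])  # nonzero at flat indices 1, 3, 4

@jax.jit
def first_two(v):
    # Fewer slots than nonzero entries: only the first two indices are kept.
    return jnp.flatnonzero(v, size=2)

@jax.jit
def pad_default(v):
    # More slots than nonzero entries: the tail is padded with the default 0.
    return jnp.flatnonzero(v, size=5)

@jax.jit
def pad_custom(v):
    # Same, but padded with -1, which cannot be mistaken for a real index.
    return jnp.flatnonzero(v, size=5, fill_value=-1)

print(first_two(x).tolist())    # [1, 3]
print(pad_default(x).tolist())  # [1, 3, 4, 0, 0]
print(pad_custom(x).tolist())   # [1, 3, 4, -1, -1]
```

Note that the default padding value 0 is also a valid index, so when padding is possible a sentinel such as -1 makes the padded slots easy to tell apart from genuine results.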

Original docstring below.

This is equivalent to np.nonzero(np.ravel(a))[0].
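A quick check of the stated equivalence, assuming a small NumPy input (any array_like is accepted). np.nonzero returns a tuple with one index array per dimension; since the input is raveled to 1-D, the tuple has a single entry, hence the trailing [0].

```python
import numpy as np
import jax.numpy as jnp

a = np.array([[1, 0, 2],
              [0, 0, 3]])

# NumPy reference: nonzero on the raveled array returns a 1-tuple; take [0].
expected = np.nonzero(np.ravel(a))[0]

# jnp.flatnonzero and the jnp spelling of the same expression agree with it.
got = jnp.flatnonzero(a)
also = jnp.nonzero(jnp.ravel(a))[0]

assert np.array_equal(np.asarray(got), expected)
assert np.array_equal(np.asarray(also), expected)
print(expected.tolist())  # [0, 2, 5]
```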

Parameters
a (array_like) – Input data.

Returns
res – Output array, containing the indices of the elements of a.ravel() that are non-zero.

Return type