jax.numpy.flatnonzero

jax.numpy.flatnonzero(a, *, size=None)

Return indices that are non-zero in the flattened version of a.

LAX-backend implementation of flatnonzero().

Because the size of the output of flatnonzero is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the output array. It must be specified statically for jnp.flatnonzero to be traced. If it is specified, the indices of the first size non-zero elements are returned. If there are fewer non-zero elements than size indicates, the result is padded with fill_value, which defaults to zero.
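The sketch below shows the size argument at work under jax.jit. It assumes a JAX version whose flatnonzero accepts size as in the signature above. The helper name first_two_nonzero is hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_two_nonzero(a):
    # size is a Python constant, so the output shape (2,) is known at trace time.
    return jnp.flatnonzero(a, size=2)

x = jnp.array([0, 2, 0, 5, 9])
print(first_two_nonzero(x))  # [1 3]: only the first two non-zero indices are kept

# With fewer non-zero elements than size, the result is padded (default pad value 0).
print(jnp.flatnonzero(jnp.array([0, 4, 0]), size=3))  # [1 0 0]
```

Because the default padding value is 0, a padded entry cannot be told apart from a genuine hit at flat index 0 by looking at the output alone. Compare against a separate count of non-zero elements if that distinction matters.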

Original docstring below.

This is equivalent to np.nonzero(np.ravel(a))[0].
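A quick check of this equivalence on a small 2-D array (the values are arbitrary):

```python
import numpy as np
import jax.numpy as jnp

x = np.array([[0, 3, 0],
              [5, 0, 7]])

# x.ravel() is [0, 3, 0, 5, 0, 7]; the non-zero entries sit at flat indices 1, 3, 5.
jax_idx = jnp.flatnonzero(x)
np_idx = np.nonzero(np.ravel(x))[0]
print(jax_idx)  # [1 3 5]
print(np_idx)   # [1 3 5]
```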

Parameters

a (array_like) – Input data.

Returns

res – Output array, containing the indices of the elements of a.ravel() that are non-zero.

Return type

ndarray
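Since the returned indices refer to positions in a.ravel(), they can be used directly to gather the non-zero values from the flattened array, or converted back to N-dimensional coordinates. A minimal sketch, with an arbitrary example array:

```python
import jax.numpy as jnp

x = jnp.array([[0.0, 1.5],
               [0.0, -2.0]])

idx = jnp.flatnonzero(x)               # flat indices into x.ravel()
vals = x.ravel()[idx]                  # gather the non-zero values
rows, cols = jnp.unravel_index(idx, x.shape)  # back to (row, col) coordinates

print(idx)   # [1 3]
print(vals)  # [ 1.5 -2. ]
```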