jax.numpy.flatnonzero

jax.numpy.flatnonzero(a, *, size=None, fill_value=None)

Return the indices of the non-zero elements in the flattened version of a.

LAX-backend implementation of numpy.flatnonzero().

Because the size of the output of nonzero is data-dependent, the function cannot normally be used under JIT. The JAX version adds the optional size argument, which must be specified statically for jnp.flatnonzero to be used within some of JAX's transformations, such as jax.jit.
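As a minimal sketch of the size requirement, the hypothetical helper first_nonzero_indices below wraps jnp.flatnonzero in jax.jit. It assumes size is passed as a Python int and marks it static via static_argnames, so the output has a fixed shape that jax.jit can trace; fill_value=-1 is an arbitrary choice for this illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `size` is static, so the output shape is known at trace time.
@partial(jax.jit, static_argnames="size")
def first_nonzero_indices(x, size):
    # Flat indices of the first `size` non-zero entries, padded with -1.
    return jnp.flatnonzero(x, size=size, fill_value=-1)


x = jnp.array([[0, 3], [5, 0]])  # flattened: [0, 3, 5, 0]
idx = first_nonzero_indices(x, size=3)  # -> [1, 2, -1]
```

Without size, the same jitted call would fail, because the number of non-zero elements is only known once the data is available.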

Original docstring below.

This is equivalent to np.nonzero(np.ravel(a))[0].
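The equivalence can be checked directly. This short sketch compares jnp.flatnonzero against np.nonzero(np.ravel(a))[0] on a small 2-D array, then uses the returned indices to gather the non-zero values from the flattened input.

```python
import numpy as np
import jax.numpy as jnp

a = np.array([[1, 0, 2], [0, 0, 3]])  # flattened: [1, 0, 2, 0, 0, 3]

jax_idx = jnp.flatnonzero(a)          # indices into the flattened array
np_idx = np.nonzero(np.ravel(a))[0]   # NumPy equivalent from the docstring

# The indices refer to a.ravel(), so they can gather the non-zero values.
values = jnp.ravel(a)[jax_idx]        # -> [1, 2, 3]
```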

Parameters:
  • a (array_like) – Input data.

  • size (int, optional) – If specified, the indices of the first size non-zero elements will be returned. If there are fewer non-zero elements than size indicates, the return value will be padded with fill_value.

  • fill_value (array_like, optional) – When size is specified and there are fewer non-zero elements than size, the remaining entries will be filled with fill_value, which defaults to zero.
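The following sketch illustrates how size and fill_value interact on a small 1-D input: when size exceeds the number of non-zero elements the result is padded, and when it is smaller only the first size indices are kept.

```python
import jax.numpy as jnp

a = jnp.array([0, 7, 0, 4, 9])  # non-zero at flat indices 1, 3, 4

# size larger than the number of non-zeros: padded with the default fill value 0.
padded_default = jnp.flatnonzero(a, size=5)                 # -> [1, 3, 4, 0, 0]

# Same, but padded with an explicit sentinel.
padded_custom = jnp.flatnonzero(a, size=5, fill_value=-1)   # -> [1, 3, 4, -1, -1]

# size smaller than the number of non-zeros: only the first `size` indices remain.
truncated = jnp.flatnonzero(a, size=2)                      # -> [1, 3]
```

Because the default fill value 0 is also a valid index, a sentinel such as -1 (an arbitrary choice here) makes padded entries easier to tell apart from a genuine non-zero element at index 0.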

Returns:

res – Output array, containing the indices of the elements of a.ravel() that are non-zero.

Return type:

ndarray