jax.numpy.nonzero

jax.numpy.nonzero(a, *, size=None, fill_value=None)

Return indices of nonzero elements of an array.

JAX implementation of numpy.nonzero().

Because the size of the output of nonzero depends on the data, the function is not compatible with jax.jit() and other JAX transformations by default. The JAX version adds the optional size argument, which must be specified statically for jnp.nonzero to be used within JAX transformations.
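Outside of any transformation, and without a size argument, jnp.nonzero is expected to agree with numpy.nonzero. The short check below is a sketch that compares the two on a small 2-D array; the array values are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.array([[0, 3, 0],
                 [4, 0, 5]])

# Without size, the output length is determined by the data, as in NumPy.
np_result = np.nonzero(x_np)
jnp_result = jnp.nonzero(jnp.asarray(x_np))

# Compare as Python lists to ignore the int64 (NumPy) vs int32 (JAX) dtype difference.
for np_idx, jnp_idx in zip(np_result, jnp_result):
    print(np_idx.tolist(), jnp_idx.tolist())
```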

Parameters:
  • a (ArrayLike) – N-dimensional array.

  • size (int | None) – optional static integer specifying the number of nonzero entries to return. If there are more nonzero elements than the specified size, then indices will be truncated at the end. If there are fewer nonzero elements than the specified size, then indices will be padded with fill_value, which defaults to zero.

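  • fill_value (None | ArrayLike | tuple[ArrayLike, ...]) – optional padding value used when size is specified. Either a scalar, applied to every output array, or a tuple of length a.ndim giving one padding value per dimension. Defaults to 0.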
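To illustrate how size and a tuple-valued fill_value interact, the sketch below requests more entries than there are nonzero elements and pads the row and column index arrays with different values. The padding values -1 and -2 are arbitrary choices for this example.

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 7]])

# Only 3 entries are nonzero, so 2 padded slots are appended to each output.
# The tuple supplies one padding value per dimension: rows first, then columns.
rows, cols = jnp.nonzero(x, size=5, fill_value=(-1, -2))
print(rows.tolist())  # [0, 1, 1, -1, -1]
print(cols.tolist())  # [1, 0, 2, -2, -2]
```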

Returns:

A tuple of length a.ndim of JAX arrays, one per dimension of a, each containing the indices of the nonzero values along that dimension.

Return type:

tuple[Array, ...]
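The returned tuple holds one index array per dimension. Stacking those arrays along a new last axis gives one coordinate row per nonzero element, which is the same layout produced by jnp.argwhere. A small sketch of that relationship:

```python
import jax.numpy as jnp

x = jnp.array([[0, 5, 0],
               [6, 0, 7]])

indices = jnp.nonzero(x)
# One index array per dimension of x.
print(len(indices) == x.ndim)  # True

# Each row of `pairs` is the (row, col) coordinate of one nonzero element.
pairs = jnp.stack(indices, axis=-1)
print(pairs.tolist())  # [[0, 1], [1, 0], [1, 2]]
print(bool((pairs == jnp.argwhere(x)).all()))  # True
```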


Examples

A one-dimensional array returns a length-1 tuple of indices:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)

A two-dimensional array returns a length-2 tuple of indices:

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))

In either case, the resulting tuple of indices can be used directly to extract the nonzero values:

>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)
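When size is used to make the indices usable under jit, any padded index slots still point at a real element (index (0, 0) with the default fill value), so extracting values with them can pick up spurious entries. The sketch below is one way to mask those slots out; the function name padded_nonzero_values and the bound of 4 are hypothetical choices for this example.

```python
import jax
import jax.numpy as jnp

MAX_NONZERO = 4  # hypothetical static upper bound on the number of nonzero entries

@jax.jit
def padded_nonzero_values(x):
    # Padded slots default to index 0 in every dimension, i.e. element x[0, 0].
    indices = jnp.nonzero(x, size=MAX_NONZERO)
    values = x[indices]
    # Only the first count_nonzero(x) slots correspond to real nonzero entries.
    valid = jnp.arange(MAX_NONZERO) < jnp.count_nonzero(x)
    return jnp.where(valid, values, 0), valid

x = jnp.array([[3, 0, 0],
               [0, 0, 4]])
values, valid = padded_nonzero_values(x)
print(values.tolist())  # [3, 4, 0, 0]
print(valid.tolist())   # [True, True, False, False]
```

Without the mask, the padded slots would repeat x[0, 0] (here 3) rather than read as empty.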

The output of nonzero has a dynamic shape, because the number of returned indices depends on the contents of the input array. As such, calling it without a size argument inside jax.jit() or other JAX transformations raises an error:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)  
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.

This can be addressed by passing a static size parameter to specify the desired output shape:

>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)
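Marking size with static_argnames is one option. Another, sketched below, is to bind size ahead of time with functools.partial so that the jitted callable takes only the array; the name nonzero_3 is hypothetical.

```python
import functools

import jax
import jax.numpy as jnp

# size is fixed when the partial is created, so it is a compile-time constant.
nonzero_3 = jax.jit(functools.partial(jnp.nonzero, size=3))

x = jnp.array([0, 5, 0, 6, 0, 7])
(idx,) = nonzero_3(x)
print(idx.tolist())  # [1, 3, 5]
```

Each distinct value of size still produces its own compiled version, since the output shape differs.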

If size does not match the actual number of nonzero elements, the result is either truncated or padded:

>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)
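Truncation happens silently. One way to detect it inside a jitted function, sketched here under the assumption that the caller wants a flag rather than an error, is to compare the true count of nonzero entries against size. The helper name nonzero_with_overflow is hypothetical.

```python
import jax
import jax.numpy as jnp

def nonzero_with_overflow(x, size):
    (idx,) = jnp.nonzero(x, size=size)
    # True when x holds more nonzero entries than `size`, i.e. indices were dropped.
    truncated = jnp.count_nonzero(x) > size
    return idx, truncated

f = jax.jit(nonzero_with_overflow, static_argnames='size')

x = jnp.array([0, 5, 0, 6, 0, 7])
idx, truncated = f(x, size=2)
print(idx.tolist(), bool(truncated))  # [1, 3] True
```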

You can specify a custom fill value for the padding using the fill_value argument:

>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)
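Padding with len(x) produces indices that are out of bounds for x, which makes them easy to distinguish from real positions. The sketch below pairs that choice with the mode='fill' option of x.at[...].get(), which returns a given value for out-of-bounds indices instead of relying on the default out-of-bounds behavior. The function name nonzero_values and the size of 5 are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def nonzero_values(x):
    # Pad with len(x), an index that does not exist in x.
    (idx,) = jnp.nonzero(x, size=5, fill_value=len(x))
    # Out-of-bounds (padded) indices read as 0 rather than an element of x.
    return x.at[idx].get(mode='fill', fill_value=0)

x = jnp.array([0, 5, 0, 6, 0, 7])
print(nonzero_values(x).tolist())  # [5, 6, 7, 0, 0]
```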