jax.numpy.count_nonzero

jax.numpy.count_nonzero(a, axis=None, keepdims=False)

Counts the number of non-zero values in the array a.

LAX-backend implementation of numpy.count_nonzero().
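Because this is a LAX-backend reimplementation of the NumPy function, it should agree with numpy.count_nonzero on ordinary numeric input. The following is a minimal sketch that compares the two on the same data. The array values are arbitrary and chosen only for illustration.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary example data: two zeros and three non-zero entries.
data = [0, 2, 0, -4, 5]

jax_count = jnp.count_nonzero(jnp.array(data))  # a 0-d JAX array
np_count = np.count_nonzero(np.array(data))     # a Python int

print(int(jax_count), np_count)  # 3 3
```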

Original docstring below.

The word “non-zero” refers to the Python 2.x special method __nonzero__() (renamed __bool__() in Python 3.x), which tests an object’s “truthfulness”. For example, any number is considered truthful if it is nonzero, whereas any string is considered truthful if it is not the empty string. Thus, this function (recursively) counts how many elements in a (and in sub-arrays thereof) have their __nonzero__() or __bool__() method evaluated to True.
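JAX arrays hold only numeric and boolean dtypes, so NumPy's string example has no direct counterpart here. For these dtypes, "truthful" amounts to "not equal to zero". The sketch below shows the edge cases that follow from this rule. -0.0 compares equal to zero and is not counted, while NaN is not equal to zero and is counted. In a boolean array, each True counts once.

```python
import jax.numpy as jnp

floats = jnp.array([0.0, -0.0, 1.5, jnp.nan, -2.0])
bools = jnp.array([True, False, True])

# -0.0 == 0.0, so it is not counted; NaN != 0, so it is.
print(int(jnp.count_nonzero(floats)))  # 3
# For a boolean array, each True is counted once.
print(int(jnp.count_nonzero(bools)))   # 2
```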

Parameters:
  • a (array_like) – The array for which to count non-zeros.

  • axis (int or tuple, optional) – Axis or tuple of axes along which to count non-zeros. Default is None, meaning that non-zeros will be counted along a flattened version of a.

  • keepdims (bool, optional) – If this is set to True, the axes that are counted are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
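The following sketch shows how axis and keepdims change the result on a small 2-D array. Its last step uses keepdims=True so that the per-row counts broadcast against the input, here to average the non-zero entries of each row. That averaging step is only an illustration and assumes every row has at least one non-zero entry.

```python
import jax.numpy as jnp

a = jnp.array([[0, 1, 7, 0],
               [3, 0, 2, 19]])

total = jnp.count_nonzero(a)                   # axis=None: counts over the flattened array
per_column = jnp.count_nonzero(a, axis=0)      # one count per column
per_row = jnp.count_nonzero(a, axis=1)         # one count per row
both_axes = jnp.count_nonzero(a, axis=(0, 1))  # tuple of axes: same as total here

# keepdims=True keeps the reduced axis with size one: shape (2, 1).
row_counts = jnp.count_nonzero(a, axis=1, keepdims=True)

# That shape broadcasts against `a`. Zeros add nothing to a row sum, so this
# is the mean of the non-zero entries per row (assumes no all-zero rows).
row_means = a.sum(axis=1, keepdims=True) / row_counts

print(int(total), per_column, per_row, int(both_axes))  # 5 [1 1 2 1] [2 3] 5
print(row_counts.shape)                                 # (2, 1)
```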

Returns:

count – The number of non-zero values in a along the given axis or axes. If axis is None, the total number of non-zero values in the array is returned.

Return type:

int or array of int
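In JAX the result is a JAX array rather than a Python int. With axis=None it is a 0-d array with an integer dtype, which int() converts to a Python integer. The minimal sketch below checks this. It also calls the function inside jax.jit with the axis fixed as a Python constant. The helper name nnz_per_row is hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[0, 4, 0],
               [1, 0, 6]])

total = jnp.count_nonzero(x)
# A 0-d array with an integer dtype, not a Python int.
print(total.shape)  # ()
print(int(total))   # 3

@jax.jit
def nnz_per_row(m):
    # axis is a Python literal here, so it stays static under tracing.
    return jnp.count_nonzero(m, axis=1)

print(nnz_per_row(x))  # [1 2]
```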