jax.numpy.count_nonzero(a, axis=None, keepdims=False)

Counts the number of non-zero values in the array a.

LAX-backend implementation of count_nonzero(). Original docstring below.

The word “non-zero” is in reference to the Python 2.x built-in method __nonzero__() (renamed __bool__() in Python 3.x) of Python objects that tests an object’s “truthfulness”. For example, any number is considered truthful if it is nonzero, whereas any string is considered truthful if it is not the empty string. Thus, this function (recursively) counts how many elements in a (and in sub-arrays thereof) have their __nonzero__() or __bool__() method evaluated to True.
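For numeric and boolean arrays, the truthfulness rule above comes down to testing each element against zero. The sketch below illustrates this with JAX arrays (the sample values are arbitrary). It compares `jnp.count_nonzero` with an explicit `x != 0` mask that is summed. Note that `-0.0` is falsy and `NaN` is truthy, which matches Python's own `bool()`.

```python
import jax.numpy as jnp

# Float input: 0.0 and -0.0 are "falsy"; every other value, including NaN,
# is "truthy" (bool(float("nan")) is True in Python as well).
x = jnp.array([0.0, -0.0, 1.5, jnp.nan, -3.0])

# Boolean input: True counts as non-zero, False does not.
b = jnp.array([True, False, True])

n_x = jnp.count_nonzero(x)        # 1.5, nan and -3.0 are counted
n_x_mask = jnp.sum(x != 0)        # explicit "element != 0" mask, summed
n_b = jnp.count_nonzero(b)

print(int(n_x), int(n_x_mask), int(n_b))  # 3 3 2
```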

Parameters

  • a (array_like) – The array for which to count non-zeros.

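  • axis (int or tuple, optional) – Axis or tuple of axes along which to count non-zeros. Default is None, meaning that non-zeros will be counted along a flattened version of a.

  • keepdims (bool, optional) – If True, the axes that are counted are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array. Default is False.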
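The sketch below shows how axis and keepdims shape the result. It uses an arbitrary 2x3x4 array holding 0..23, so only its first element is zero. It assumes the standard NumPy meaning of keepdims: reduced axes are kept with size one. With that layout, the counts broadcast back against the input.

```python
import jax.numpy as jnp

# 2x3x4 array holding 0..23; only the first element (0) is zero.
a = jnp.arange(24).reshape(2, 3, 4)

total = jnp.count_nonzero(a)                    # axis=None: count over the flattened array
per_row = jnp.count_nonzero(a, axis=-1)         # single axis -> shape (2, 3)
per_col = jnp.count_nonzero(a, axis=(0, 1))     # tuple of axes -> shape (4,)
kept = jnp.count_nonzero(a, axis=(0, 1), keepdims=True)  # reduced axes kept -> shape (1, 1, 4)

# With keepdims=True the counts broadcast against `a`: here each non-zero
# element is weighted by 1 / (number of non-zeros sharing its last-axis index).
share = (a != 0) / kept                         # shape (2, 3, 4)

print(int(total))               # 23
print(per_row.tolist())         # [[3, 4, 4], [4, 4, 4]]
print(per_col.tolist())         # [5, 6, 6, 6]
print(kept.shape, share.shape)  # (1, 1, 4) (2, 3, 4)
```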


Returns

count – Number of non-zero values in the array along the given axis or axes. If axis is None, the total number of non-zero values in the array is returned.

Return type

int or array of int
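In JAX, the "int" case above comes back as a zero-dimensional integer `jax.Array` rather than a Python `int`. The sketch below checks this with an arbitrary sample array. It does not assert a specific integer dtype, since that depends on whether 64-bit mode is enabled. Because the output shape depends only on the input shape and the static axis argument, the call also works inside `jax.jit`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 0, 2], [0, 0, 3]])

c = jnp.count_nonzero(x)
# With axis=None the result is a zero-dimensional integer jax.Array, not a Python int.
assert isinstance(c, jax.Array) and c.shape == ()
assert jnp.issubdtype(c.dtype, jnp.integer)
n = int(c)  # explicit conversion when a Python int is needed -> 3

# axis is fixed by the closure, so the output shape is static under jit.
count_rows = jax.jit(lambda v: jnp.count_nonzero(v, axis=1))
rows = count_rows(x)
print(rows.tolist())  # [2, 1]
```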

See also

nonzero – Return the coordinates of all the non-zero values.
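The two functions are related: the number of coordinates that `nonzero` returns equals the count from `count_nonzero`. The sketch below checks this using the array from the examples. It calls `jnp.nonzero` eagerly. Under `jax.jit`, `jnp.nonzero` needs a static `size` argument because its output shape depends on the data, whereas `count_nonzero` has no such requirement.

```python
import jax.numpy as jnp

x = jnp.array([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])

# nonzero returns one index array per dimension: (row indices, column indices).
rows, cols = jnp.nonzero(x)

# Each (row, col) pair locates one non-zero element, so the number of pairs
# equals count_nonzero.
n_pairs = rows.size
n = int(jnp.count_nonzero(x))
print(n_pairs, n)  # 5 5
```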


Examples

>>> np.count_nonzero(np.eye(4))
4
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
5
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]], axis=0)
array([1, 1, 1, 1, 1])
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]], axis=1)
array([2, 3])
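The examples above come from the NumPy docstring and pass nested Python lists to `np`. The sketch below is a JAX counterpart. It assumes that `jax.numpy` functions generally expect array inputs rather than lists, so the list is converted with `jnp.asarray` first. The results are converted to plain Python values only for readability.

```python
import jax.numpy as jnp

# jax.numpy functions generally expect arrays, so convert the nested list first.
x = jnp.asarray([[0, 1, 7, 0, 0], [3, 0, 0, 2, 19]])

total_eye = int(jnp.count_nonzero(jnp.eye(4)))   # 4
total_x = int(jnp.count_nonzero(x))              # 5
per_col = jnp.count_nonzero(x, axis=0).tolist()  # [1, 1, 1, 1, 1]
per_row = jnp.count_nonzero(x, axis=1).tolist()  # [2, 3]
print(total_eye, total_x, per_col, per_row)
```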