jax.numpy.nansum

jax.numpy.nansum(a, axis=None, dtype=None, out=None, keepdims=None)

Return the sum of array elements over a given axis, treating Not-a-Number values (NaNs) as zero.

LAX-backend implementation of nansum().

Original docstring below.
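A minimal sketch of the summary above, assuming a recent JAX install. The `jnp.where`/`jnp.isnan` line is only an illustrative equivalent for floating-point inputs, not a statement of how JAX implements `nansum` internally.

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0])

# NaN entries contribute nothing to the sum.
total = jnp.nansum(x)

# Illustrative equivalent (assumption, not JAX's internal code):
# replace NaNs with zero, then take an ordinary sum.
manual = jnp.sum(jnp.where(jnp.isnan(x), 0.0, x))

print(float(total), float(manual))  # 4.0 4.0
```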

In NumPy versions <= 1.9.0, NaN is returned for slices that are all-NaN or empty. In later versions, zero is returned.
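A quick check of the all-NaN and empty cases, assuming a recent JAX release: both reduce to zero, matching the later NumPy behavior. The names `all_nan`, `empty`, and `m` are illustrative.

```python
import jax.numpy as jnp

all_nan = jnp.array([jnp.nan, jnp.nan])
empty = jnp.array([], dtype=jnp.float32)

# Both reductions yield zero rather than NaN.
print(float(jnp.nansum(all_nan)), float(jnp.nansum(empty)))  # 0.0 0.0

# Row-wise reduction: a row that is entirely NaN also sums to zero.
m = jnp.array([[jnp.nan, jnp.nan], [1.0, jnp.nan]])
row_sums = jnp.nansum(m, axis=1)
print(row_sums.tolist())  # [0.0, 1.0]
```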

Parameters
  • a (array_like) – Array containing numbers whose sum is desired. If a is not an array, a conversion is attempted.

  • axis ({int, tuple of int, None}, optional) – Axis or axes along which the sum is computed. The default is to compute the sum of the flattened array.

  • dtype (data-type, optional) – The type of the returned array and of the accumulator in which the elements are summed. By default, the dtype of a is used. An exception is when a has an integer type with less precision than the platform (u)intp. In that case, the default will be either (u)int32 or (u)int64 depending on whether the platform is 32 or 64 bits. For inexact inputs, dtype must be inexact.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original a.
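The sketch below exercises the axis, keepdims, and dtype parameters, assuming a recent JAX install with default 32-bit mode. The `jnp.nan_to_num` call and the variable names are illustrative additions used only to show that a keepdims=True result broadcasts against the input.

```python
import jax.numpy as jnp

a = jnp.array([[1.0, jnp.nan, 2.0],
               [jnp.nan, 4.0, 5.0]])

# axis=0 sums down each column, ignoring NaNs -> shape (3,).
col = jnp.nansum(a, axis=0)

# keepdims=True keeps the reduced axis with size one -> shape (2, 1).
kept = jnp.nansum(a, axis=1, keepdims=True)

# The (2, 1) result broadcasts against the (2, 3) input,
# giving each entry's share of its row total (NaNs mapped to 0 first).
share = jnp.nan_to_num(a) / kept

# dtype sets the accumulator/result type; here an integer input
# is summed into a float32 result.
as_float = jnp.nansum(jnp.array([1, 2, 3]), dtype=jnp.float32)

print(col.tolist())     # [1.0, 4.0, 7.0]
print(kept.shape)       # (2, 1)
print(share.shape)      # (2, 3)
print(as_float.dtype)   # float32
```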

Returns

nansum – A new array holding the result is returned unless out is specified, in which case out is returned. If axis is None, the result is a 0-d array; otherwise the reduced axes are removed from the shape of a, or kept with size one when keepdims is True.
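A short shape check for the return value, assuming a recent JAX install; the input `a` is an arbitrary illustrative array of ones.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4))

# axis=None reduces over every element and returns a 0-d array.
full = jnp.nansum(a)

# A tuple of axes removes each listed axis from the shape.
partial = jnp.nansum(a, axis=(0, 2))

# With keepdims=True the listed axes remain with size one.
partial_kept = jnp.nansum(a, axis=(0, 2), keepdims=True)

print(full.shape, partial.shape, partial_kept.shape)  # () (3,) (1, 3, 1)
```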

Return type

ndarray.