jax.numpy.nansum

jax.numpy.nansum(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None)

Return the sum of array elements over a given axis, treating Not a Numbers (NaNs) as zero.

LAX-backend implementation of numpy.nansum().

Original docstring below.

In NumPy versions <= 1.9.0, NaN is returned for slices that are all-NaN or empty. In later versions, zero is returned.
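The following is a short sketch of the NaN handling. It compares jax.numpy.nansum with an ordinary jax.numpy.sum and with numpy.nansum on the same data, and it assumes a NumPy release later than 1.9.0, so all-NaN and empty inputs are expected to sum to zero. The array values are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0])
total = jnp.nansum(x)  # NaN counted as zero: 1 + 0 + 3 = 4.0
plain = jnp.sum(x)     # an ordinary sum propagates the NaN

# The same data through NumPy, for comparison.
np_total = np.nansum(np.array([1.0, np.nan, 3.0]))

# All-NaN and empty inputs both give zero rather than NaN.
all_nan = jnp.nansum(jnp.array([jnp.nan, jnp.nan]))
empty = jnp.nansum(jnp.array([], dtype=jnp.float32))

print(total, plain, np_total, all_nan, empty)
```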

Parameters:
  • a (array_like) – Array containing numbers whose sum is desired. If a is not an array, a conversion is attempted.

  • axis ({int, tuple of int, None}, optional) – Axis or axes along which the sum is computed. The default is to compute the sum of the flattened array.

  • dtype (data-type, optional) – The type of the returned array and of the accumulator in which the elements are summed. By default, the dtype of a is used. An exception is when a has an integer type with less precision than the platform (u)intp. In that case, the default will be either (u)int32 or (u)int64 depending on whether the platform is 32 or 64 bits. For inexact inputs, dtype must be inexact.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original a.

    If the value is anything but the default, then keepdims will be passed through to the mean or sum methods of sub-classes of ndarray. If the sub-class methods do not implement keepdims, any exceptions will be raised.

  • initial (scalar, optional) – Starting value for the sum. See numpy.ufunc.reduce for details.

  • where (array_like of bool, optional) – Elements to include in the sum. See numpy.ufunc.reduce for details.

  • out (None) – Not supported by JAX; must be left as None.
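The parameters above can be exercised together in a minimal sketch. It assumes JAX's default 32-bit mode (jax_enable_x64 off), uses small values whose sums are exact in float32, and leaves out at its default of None. The arrays and mask are illustrative assumptions, not taken from the original documentation.

```python
import jax.numpy as jnp

a = jnp.array([[1.0, jnp.nan, 3.0],
               [4.0, 5.0, jnp.nan]])

# axis: sum along rows (axis=1) or columns (axis=0); None sums the flattened array.
row_sums = jnp.nansum(a, axis=1)  # [4.0, 9.0]
col_sums = jnp.nansum(a, axis=0)  # [5.0, 5.0, 3.0]
total = jnp.nansum(a)             # 13.0

# keepdims: the reduced axis stays with size one, so the result broadcasts against a.
row_totals = jnp.nansum(a, axis=1, keepdims=True)  # shape (2, 1)
shares = jnp.nan_to_num(a) / row_totals            # shape (2, 3)

# dtype: an integer input accumulated and returned as float32.
ints = jnp.array([1, 2, 3], dtype=jnp.int32)
as_float = jnp.nansum(ints, dtype=jnp.float32)  # 6.0 with dtype float32

# initial: a starting value added to the sum.
with_start = jnp.nansum(a, initial=10.0)  # 10 + 13 = 23.0

# where: only elements whose mask entry is True are summed (NaNs still count as zero).
mask = jnp.array([[True, True, False],
                  [True, False, True]])
masked = jnp.nansum(a, where=mask)  # 1 + 0 + 4 + 0 = 5.0
```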

Returns:

nansum – A new array holding the result is returned unless out is specified, in which case it is returned. The result has the shape of a with the summed axes removed, or kept with size one when keepdims is True.

Return type:

ndarray.
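As a rough illustration of what the result holds, the sketch below reproduces the NaN-as-zero sum with jnp.where and jnp.isnan followed by an ordinary jnp.sum, and checks it against jax.numpy.nansum. nansum_sketch is a hypothetical helper, not part of JAX, and JAX's own implementation may differ in details such as dtype handling. The sketch also assumes a recent JAX release, in which the returned value is a jax.Array.

```python
import jax
import jax.numpy as jnp


def nansum_sketch(a, axis=None):
    """Hypothetical helper (not part of JAX): replace NaNs with 0, then sum."""
    a = jnp.asarray(a)
    return jnp.sum(jnp.where(jnp.isnan(a), 0, a), axis=axis)


x = jnp.array([[1.0, jnp.nan],
               [jnp.nan, 4.0]])

reference = jnp.nansum(x, axis=0)  # [1.0, 4.0], shape (2,)
# axis must be static under jit because it determines the output shape.
sketch = jax.jit(nansum_sketch, static_argnames="axis")(x, axis=0)

print(type(reference), reference.shape, sketch)
```

Zeroing NaNs before a plain sum keeps the computation a single reduction, which is why the same pattern works unchanged under jax.jit.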