jax.numpy.sum

jax.numpy.sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)

Sum of array elements over a given axis.

LAX-backend implementation of sum().

Original docstring below.

Parameters
  • a (array_like) – Elements to sum.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which a sum is performed. The default, axis=None, will sum all of the elements of the input array. If axis is negative it counts from the last to the first axis.

  • dtype (dtype, optional) – The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.

    If the default value is passed, then keepdims will not be passed through to the sum method of sub-classes of ndarray; any non-default value will be. If the sub-class's method does not implement keepdims, an exception will be raised.

  • initial (scalar, optional) – Starting value for the sum. See ~numpy.ufunc.reduce for details.

  • where (array_like of bool, optional) – Elements to include in the sum. See ~numpy.ufunc.reduce for details.
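The parameters above can be illustrated with a minimal sketch. It assumes a JAX version whose jax.numpy.sum accepts initial and where, as in the signature shown; the array a and the variable names are made up for illustration.

```python
import jax.numpy as jnp

# Hypothetical 2x3 input used only for this illustration.
a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# axis=None (the default) sums every element.
total = jnp.sum(a)                      # 21.0

# axis=0 reduces over rows, leaving one sum per column.
col_sums = jnp.sum(a, axis=0)           # [5., 7., 9.]

# A negative axis counts from the last axis, so -1 is axis 1 here.
row_sums = jnp.sum(a, axis=-1)          # [6., 15.]

# keepdims=True keeps the reduced axis with size one, so the result
# broadcasts against the input (here: normalise each row).
row_sums_kept = jnp.sum(a, axis=1, keepdims=True)   # shape (2, 1)
fractions = a / row_sums_kept

# where selects which elements take part in the sum.
masked = jnp.sum(a, where=a > 2.0)      # 3 + 4 + 5 + 6 = 18.0

# initial is the starting value that the elements are added to.
offset = jnp.sum(a, initial=10.0)       # 10 + 21 = 31.0
```

Each row of fractions sums to approximately 1, which is the usual reason to pass keepdims=True.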

Returns

sum_along_axis – An array with the same shape as a, with the specified axis removed. If a is a 0-d array, or if axis is None, a scalar is returned. If an output array is specified, a reference to out is returned.

Return type

ndarray
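
As a rough sketch of the return shapes described above: the arrays x and ints below are hypothetical, and the example does not use out. Note that in JAX the "scalar" result of a full reduction is a 0-d array, which has shape ().

```python
import jax.numpy as jnp

# Hypothetical 3-D input of ones, shape (2, 3, 4).
x = jnp.ones((2, 3, 4), dtype=jnp.float32)

# Reducing one axis removes it from the shape.
shape_one_axis = jnp.sum(x, axis=1).shape        # (2, 4)

# A tuple of axes removes each listed axis.
shape_two_axes = jnp.sum(x, axis=(0, 2)).shape   # (3,)

# axis=None reduces everything to a 0-d array.
full = jnp.sum(x)
shape_full = full.shape                          # ()

# An explicit dtype sets the type of the returned array.
ints = jnp.arange(6).reshape(2, 3)               # values 0..5
as_float = jnp.sum(ints, dtype=jnp.float32)      # 15.0, dtype float32
```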