jax.Array.sum
abstract Array.sum(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)
Sum of the elements of the array over a given axis.
Refer to jax.numpy.sum() for full documentation.
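The short sketch below illustrates how the keyword arguments in the signature above change the result. It assumes a default JAX configuration (64-bit mode disabled) and leaves `out` at its default, since JAX arrays are immutable and are not written in place. The array values are arbitrary illustrations.

```python
import jax.numpy as jnp

# A small 2x3 integer array: [[0, 1, 2], [3, 4, 5]]
x = jnp.arange(6).reshape(2, 3)

total = x.sum()                          # all elements -> 15 (a 0-d array)
col_sums = x.sum(axis=0)                 # reduce over rows -> [3, 5, 7]
row_sums = x.sum(axis=1, keepdims=True)  # keep the reduced axis -> [[3], [12]], shape (2, 1)

# `where` selects which elements participate; `initial` is added to the sum.
mask = x % 2 == 0                          # True at the values 0, 2, 4
even_sum = x.sum(where=mask, initial=100)  # 0 + 2 + 4 + 100 -> 106

# `dtype` sets the dtype used for the result.
float_total = x.sum(dtype=jnp.float32)   # -> 15.0 with dtype float32

# `promote_integers` controls whether narrow integer inputs are widened.
small = jnp.array([100, 100], dtype=jnp.int8)
promoted = small.sum()                      # widened integer dtype, value 200
narrow = small.sum(promote_integers=False)  # result keeps dtype int8 (200 does not fit)
```

Because the method forwards to jax.numpy.sum(), `x.sum(axis=0)` and `jnp.sum(x, axis=0)` give the same result.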