jax.Array.mean
abstract Array.mean(axis=None, dtype=None, out=None, keepdims=False, *, where=None)
Return the mean of array elements along a given axis.
Refer to jax.numpy.mean() for the full documentation.
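The sketch below illustrates how the `axis`, `keepdims`, and `where` parameters in the signature affect the result. It assumes a standard JAX installation; the array `x` and the mask are illustrative values only, and the last line checks that the method form agrees with the `jax.numpy.mean()` function this page refers to.

```python
import jax.numpy as jnp

# Illustrative 2x3 float array.
x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# axis=None (the default) averages over every element.
overall = x.mean()                      # 3.5

# Reduce along axis 0 (down each column) or axis 1 (across each row).
col_means = x.mean(axis=0)              # [2.5, 3.5, 4.5]
row_means = x.mean(axis=1)              # [2.0, 5.0]

# keepdims=True keeps the reduced axis with length 1 (shape (2, 1)),
# so the result broadcasts against x, e.g. to center each row.
row_means_kept = x.mean(axis=1, keepdims=True)
centered = x - row_means_kept           # [[-1, 0, 1], [-1, 0, 1]]

# where= restricts the mean to elements where the mask is True.
mask = x > 2.0
masked_mean = x.mean(where=mask)        # mean of [3, 4, 5, 6] = 4.5

# The method is equivalent to calling jnp.mean on the array.
assert jnp.allclose(x.mean(axis=0), jnp.mean(x, axis=0))
```

Keeping the reduced axis with `keepdims=True` avoids a manual reshape when the mean is combined elementwise with the original array.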