jax.Array.std
abstract Array.std(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, correction=None)
Compute the standard deviation along a given axis. Refer to jax.numpy.std() for full documentation.
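The following is a minimal usage sketch. It assumes a recent JAX release in which this method accepts the same arguments as jax.numpy.std(). It shows the default population standard deviation (ddof=0), a per-axis reduction, Bessel's correction via ddof=1, keepdims, and a boolean where mask. The array values are arbitrary illustrative data.

```python
import jax.numpy as jnp

# Illustrative 2x4 array; the second row is the first row scaled by 2.
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [2.0, 4.0, 6.0, 8.0]])

# Population standard deviation (ddof=0, the default) over all elements.
total = x.std()

# Reduce along axis=1 to get one value per row.
per_row = x.std(axis=1)

# Sample standard deviation: divide by N - ddof, here N - 1.
per_row_sample = x.std(axis=1, ddof=1)

# keepdims=True keeps the reduced axis with length 1, which is handy for broadcasting.
kept = x.std(axis=1, keepdims=True)
standardized = (x - x.mean(axis=1, keepdims=True)) / kept

# Only elements where the mask is True contribute to the statistic.
masked = x.std(axis=1, where=x > 1.5)

# The method form and the jax.numpy function form should agree.
same = bool(jnp.allclose(x.std(axis=0), jnp.std(x, axis=0)))

print(per_row)
print(kept.shape)
```

For the first row, [1, 2, 3, 4], the squared deviations from the mean 2.5 sum to 5. With ddof=0 the result is therefore sqrt(5 / 4). With ddof=1 it is sqrt(5 / 3).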