jax.scipy.stats.mode
- jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)
Compute the mode (most common value) along an axis of an array.
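JAX implementation of scipy.stats.mode().
- Parameters:
  a: input array.
  axis: int or None, default 0. Axis along which to compute the mode.
  nan_policy: str, default 'propagate'. JAX only supports 'propagate'.
  keepdims: bool, default False. If True, the reduced axis is left in the result with size 1.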
- Returns:
  A tuple of arrays, (mode, count). mode is the array of modal values, and count is the number of times the corresponding modal value appears along axis.
- Return type:
  ModeResult
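A minimal usage sketch, assuming a recent JAX release in which mode returns the ModeResult named tuple described above with fields mode and count. The input array x and the jitted wrapper row_mode_jit are illustrative names, not part of the API. The data is chosen so that no row or column contains a tie, because this page does not say how ties are broken.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import mode

# Illustrative input: no row or column has a tie for most common value.
x = jnp.array([[1, 2, 2],
               [3, 3, 1],
               [1, 2, 1]])

# Default axis=0: one modal value (and its count) per column.
col = mode(x)
print(col.mode, col.count)   # [1 2 1] [2 2 2]

# axis=1: one modal value (and its count) per row.
row = mode(x, axis=1)
print(row.mode, row.count)   # [2 3 1] [2 2 2]

# keepdims=True keeps the reduced axis in the result with size 1.
row_kd = mode(x, axis=1, keepdims=True)
print(row_kd.mode.shape)     # (3, 1)

# ModeResult is a named tuple, so it also unpacks as (mode, count).
m, c = mode(x, axis=1)

# axis determines the output shape, so it is fixed as a Python int
# captured by the lambda rather than passed as a traced argument.
row_mode_jit = jax.jit(lambda arr: mode(arr, axis=1))
jit_res = row_mode_jit(x)
```

Closing over axis inside the lambda is one simple way to keep it static under jax.jit; passing it through static_argnames would be an alternative.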