jax.scipy.stats.mode

jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)

Compute the mode (most common value) along an axis of an array.

JAX implementation of scipy.stats.mode().
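A minimal usage sketch, assuming a recent JAX release. With the default axis=0, each column of a 2-D array is reduced to its most common value. The sketch reads the mode and count fields of the returned ModeResult, and the array values are illustrative only.

```python
import jax.numpy as jnp
from jax.scipy import stats

# Illustrative 3x3 integer array.
x = jnp.array([[1, 2, 2],
               [3, 2, 1],
               [3, 4, 1]])

# Default axis=0: one mode (and its count) per column.
result = stats.mode(x)
print(result.mode)   # [3 2 1]
print(result.count)  # [2 2 2]
```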

Parameters:
  • a (jax.typing.ArrayLike) – Array-like input from which to compute the mode.

  • axis (int | None) – Axis along which to compute the mode. Default: 0. If None, the input is flattened before the mode is computed.

  • nan_policy (str) – How to handle NaN inputs. Default: "propagate", which is the only policy JAX supports.

  • keepdims (bool) – If True, reduced axes are left in the result with size 1. Default: False.
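The sketch below shows how axis and keepdims determine the shape of the result. It assumes, as in scipy.stats.mode, that axis=None flattens the input before reducing. The data are chosen so that no row, and not the flattened array either, has a tie for the most common value.

```python
import jax.numpy as jnp
from jax.scipy import stats

x = jnp.array([[1, 2, 2],
               [3, 3, 1],
               [4, 4, 4]])

# axis=1: one mode per row; the reduced axis is dropped.
row_modes = stats.mode(x, axis=1)
print(row_modes.mode.shape)     # (3,)

# keepdims=True keeps the reduced axis as a dimension of size 1.
row_modes_kd = stats.mode(x, axis=1, keepdims=True)
print(row_modes_kd.mode.shape)  # (3, 1)

# axis=None flattens the input first, so the result is a scalar.
flat = stats.mode(x, axis=None)
print(flat.mode, flat.count)    # 4 3
```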

Returns:

A tuple of arrays, (mode, count). mode is the array of modal values, and count is the number of times each modal value appears along the reduced axis.

Return type:

ModeResult
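Because ModeResult is a tuple, it can be unpacked into (mode, count) directly. The following sketch also calls mode inside a jax.jit-compiled function. It assumes that mode traces under jit, which is worth confirming on the JAX version in use. The helper name column_modes is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats


@jax.jit
def column_modes(x):
    # Hypothetical helper: unpack the ModeResult tuple positionally.
    mode, count = stats.mode(x, axis=0)
    return mode, count


x = jnp.array([[5, 7],
               [5, 8],
               [6, 8]])
mode, count = column_modes(x)
print(mode)   # [5 8]
print(count)  # [2 2]
```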