jax.scipy.stats.mode

jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)

LAX-backend implementation of scipy.stats._stats_py.mode().

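Currently, the only supported nan_policy is 'propagate'; the 'omit' and 'raise' options described in the original docstring below are not supported in JAX.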
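Because only 'propagate' is currently supported, one way to approximate SciPy's nan_policy='raise' is to check for NaNs eagerly before calling `mode`. The helper `mode_nan_raise` below is a hypothetical sketch, not part of JAX. Its check converts an array value to a Python bool, so it only works outside `jax.jit`.

```python
import jax.numpy as jnp
from jax.scipy.stats import mode


def mode_nan_raise(a, axis=0, keepdims=False):
    """Hypothetical helper (not part of JAX): emulate nan_policy='raise'.

    bool() forces a concrete value, so call this outside jax.jit.
    """
    a = jnp.asarray(a)
    # Only floating-point and complex dtypes can hold NaN.
    if jnp.issubdtype(a.dtype, jnp.inexact) and bool(jnp.isnan(a).any()):
        raise ValueError("The input contains nan values")
    # No NaNs present, so the supported 'propagate' policy is safe to use.
    return mode(a, axis=axis, keepdims=keepdims)


print(mode_nan_raise(jnp.array([1.0, 2.0, 2.0])).mode)  # 2.0
try:
    mode_nan_raise(jnp.array([1.0, jnp.nan, 1.0]))
except ValueError as err:
    print("rejected:", err)
```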

Original docstring below.

Return an array of the modal (most common) value in the input array.

If there is more than one such value, only one is returned. The count of the modal value (the number of times it occurs) is also returned.
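A minimal usage sketch on 1-D input. Since this is a LAX-backend implementation, it can also be traced under `jax.jit`; `modal_value` is a hypothetical wrapper introduced only for illustration. For the tie case, the sketch checks only that one of the tied values is returned, because the docstring does not say which one.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import mode

x = jnp.array([1, 2, 2, 3])
res = mode(x)
print(res.mode, res.count)  # 2 2

# Tie: 2 and 3 both occur twice; only one of them is returned.
tied = mode(jnp.array([1, 2, 2, 3, 3]))
assert int(tied.mode) in (2, 3) and int(tied.count) == 2


@jax.jit
def modal_value(v):
    # Hypothetical wrapper: axis and keepdims are Python constants here.
    return mode(v).mode


print(modal_value(x))  # 2
```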

Parameters:
  • a (array_like) – Numeric, n-dimensional array of which to find mode(s).

  • axis (int or None, default: 0) – If an int, the axis of the input along which to compute the statistic. The statistic of each axis-slice (e.g. row) of the input will appear in a corresponding element of the output. If None, the input will be raveled before computing the statistic.

  • nan_policy ({'propagate', 'omit', 'raise'}) –

    Defines how to handle input NaNs.

    • propagate: if a NaN is present in the axis slice (e.g. row) along which the statistic is computed, the corresponding entry of the output will be NaN.

    • omit: NaNs will be omitted when performing the calculation. If insufficient data remains in the axis slice along which the statistic is computed, the corresponding entry of the output will be NaN.

    • raise: if a NaN is present, a ValueError will be raised.

  • keepdims (bool, default: False) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
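The sketch below illustrates `axis` and `keepdims` on a small integer array, chosen (as an assumption of this example) so that no slice contains a tie. With `keepdims=True`, the reduced axis stays as size one, so the result broadcasts against the input.

```python
import jax.numpy as jnp
from jax.scipy.stats import mode

x = jnp.array([[1, 1, 2],
               [3, 3, 2],
               [1, 3, 3]])

cols = mode(x, axis=0)        # one mode per column
rows = mode(x, axis=1)        # one mode per row
flat = mode(x, axis=None)     # input is raveled first
print(cols.mode, cols.count)  # [1 3 2] [2 2 2]
print(rows.mode, rows.count)  # [1 3 3] [2 2 2]
print(flat.mode, flat.count)  # 3 4

# keepdims=True keeps the reduced axis with size one, so the result
# broadcasts against x; here it marks where each row's mode occurs.
rows_kd = mode(x, axis=1, keepdims=True)
print(rows_kd.mode.shape)     # (3, 1)
print(x == rows_kd.mode)
```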

Return type:

ModeResult

Returns:

  • mode (ndarray) – Array of modal values.

  • count (ndarray) – Array of counts for each mode.
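A short sketch of consuming the result, assuming `ModeResult` behaves as a named tuple with `mode` and `count` fields (as the Returns list suggests). The relative-frequency computation is only an illustrative use of the two fields.

```python
import jax.numpy as jnp
from jax.scipy.stats import mode

x = jnp.array([4, 4, 7, 4, 9])

# The result can be unpacked like a tuple...
m, c = mode(x)
# ...or accessed by field name.
res = mode(x)
print(res.mode, res.count)  # 4 3

# Illustrative derived quantity: relative frequency of the modal value.
share = res.count / x.shape[0]
print(share)
```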