jax.scipy.stats.rankdata

jax.scipy.stats.rankdata(a, method='average', *, axis=None, nan_policy='propagate')

Compute the rank of data along an array axis.

JAX implementation of scipy.stats.rankdata().

Ranks begin at 1, and the method argument controls how ties are handled.
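
As a minimal sketch of how the tie-handling methods differ, the snippet below ranks the same small array once per supported method. The sample values and the names `methods` and `ranks` are illustrative assumptions, not part of the JAX documentation.

```python
import jax.numpy as jnp
from jax.scipy.stats import rankdata

# Hypothetical sample with one tie: the two 20s occupy sorted positions 2 and 3.
x = jnp.array([10, 20, 20, 30])

# Rank the same data once per supported tie-handling method.
methods = ("average", "min", "max", "dense", "ordinal")
ranks = {m: rankdata(x, method=m) for m in methods}

for name, r in ranks.items():
    print(f"{name:>8}: {r}")
```

For this input, the tied pair receives 2.5 under "average", 2 under "min", 3 under "max", and 2 under "dense" (where the next value is ranked 3 rather than 4), while "ordinal" assigns the two tied elements the distinct ranks 2 and 3.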

Parameters:
  • a (jax.typing.ArrayLike) – array-like input whose elements are ranked.

  • method (str) – str, default="average". Supported methods are "average", "min", "max", "dense", and "ordinal". For details, see the scipy.stats.rankdata() documentation.

  • axis (int | None) – optional integer axis along which to rank. If None (the default), the input array is flattened before ranking.

  • nan_policy (str) – str, JAX’s implementation only supports "propagate".

Returns:

array of ranks along the specified axis.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> from jax.scipy.stats import rankdata
>>> x = jnp.array([10, 30, 20])
>>> rankdata(x)
Array([1., 3., 2.], dtype=float32)
>>> x = jnp.array([1, 3, 2, 3])
>>> rankdata(x)
Array([1. , 3.5, 2. , 3.5], dtype=float32)
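
The effect of the axis parameter can be sketched with a small 2-D input. The array values and the names `flat`, `by_row`, and `by_col` below are illustrative assumptions chosen for this example.

```python
import jax.numpy as jnp
from jax.scipy.stats import rankdata

# Hypothetical 2x3 input with no ties.
x = jnp.array([[30, 10, 20],
               [5, 50, 40]])

flat = rankdata(x)            # axis=None: ranks over all six values, shape (6,)
by_row = rankdata(x, axis=1)  # ranks computed within each row, shape (2, 3)
by_col = rankdata(x, axis=0)  # ranks computed within each column, shape (2, 3)

print(flat)
print(by_row)
print(by_col)
```

With axis left at None, the result is one-dimensional because the input is flattened first. With an integer axis, the result keeps the input's shape and each slice along that axis is ranked independently.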