jax.scipy.stats.rankdata
- jax.scipy.stats.rankdata(a, method='average', *, axis=None, nan_policy='propagate')
Compute the rank of data along an array axis.
JAX implementation of scipy.stats.rankdata(). Ranks begin at 1, and the method argument controls how ties are handled.
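A short sketch comparing the five tie-handling methods on one input that contains a tied pair. The input values are arbitrary, chosen so that every method gives a different result. The comments list the ranks each method assigns under the tie rules described in the scipy.stats.rankdata() documentation.

```python
import jax.numpy as jnp
from jax.scipy.stats import rankdata

# The value 10 appears twice (positions 1 and 3), so those two entries are tied.
x = jnp.array([40, 10, 30, 10])

ranks = {m: rankdata(x, method=m)
         for m in ("average", "min", "max", "dense", "ordinal")}

for m, r in ranks.items():
    print(m, r.tolist())

# Ranks assigned to x = [40, 10, 30, 10]:
#   average: 4, 1.5, 3, 1.5  (the tied pair shares the mean of ranks 1 and 2)
#   min:     4, 1,   3, 1    (the tied pair gets the lowest rank of its group)
#   max:     4, 2,   3, 2    (the tied pair gets the highest rank of its group)
#   dense:   3, 1,   2, 1    (like "min", but distinct values get consecutive ranks, no gaps)
#   ordinal: 4, 1,   3, 2    (every entry gets a distinct rank, in order of appearance)
```

For every element, the "average" rank is the midpoint of the "min" and "max" ranks, because a tied group occupies a run of consecutive ranks.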
- Parameters:
a (ArrayLike) – array-like input data.
method (str) – str, default="average". Supported methods are ("average", "min", "max", "dense", "ordinal"). For details, see the scipy.stats.rankdata() documentation.
axis (int | None) – optional integer. If not specified, the input array is flattened.
nan_policy (str) – str, JAX’s implementation only supports "propagate".
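A minimal sketch of how axis changes the result, using an arbitrary 2-by-3 input. With the default axis=None the ranks come from the flattened array, so ties are counted across rows; with an integer axis each 1-D slice along that axis is ranked independently and the output keeps the input's shape. The ranks in the comments follow from the default "average" method.

```python
import jax.numpy as jnp
from jax.scipy.stats import rankdata

x = jnp.array([[3, 1, 2],
               [1, 1, 2]])

# axis=None (default): x is flattened and ranked as one 1-D array of 6 values.
# The three 1s share ranks 1-3 (mean 2), the two 2s share ranks 4-5 (mean 4.5).
flat = rankdata(x)            # ranks: [6, 2, 4.5, 2, 2, 4.5], shape (6,)

# axis=1: each row is ranked on its own; the output has the same shape as x.
by_row = rankdata(x, axis=1)  # ranks: [[3, 1, 2], [1.5, 1.5, 3]]

# axis=0: each column is ranked on its own.
by_col = rankdata(x, axis=0)  # ranks: [[2, 1.5, 1.5], [1, 1.5, 1.5]]
```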
- Returns:
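array of ranks. If axis is None, these are the ranks of the flattened input; otherwise the result has the same shape as a, ranked along the specified axis.
- Return type:
Array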
Examples
>>> import jax.numpy as jnp
>>> from jax.scipy.stats import rankdata
>>> x = jnp.array([10, 30, 20])
>>> rankdata(x)
Array([1., 3., 2.], dtype=float32)

>>> x = jnp.array([1, 3, 2, 3])
>>> rankdata(x)
Array([1. , 3.5, 2. , 3.5], dtype=float32)