jax.numpy.argpartition



jax.numpy.argpartition(a, kth, axis=-1)

Perform an indirect partition along the given axis.

LAX-backend implementation of numpy.argpartition().

The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you’re only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
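The following is a minimal sketch, assuming a recent JAX release, of the two routes described above. It first calls `jnp.argpartition` with a static `kth`, wrapped in a `jax.jit`-compiled helper (the name `partition_indices` is hypothetical) that marks `kth` as static. It then calls `jax.lax.top_k` directly for the case where only the k smallest or largest entries are needed. Because `lax.top_k` selects the largest values, the sketch negates the input to obtain the smallest.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([5.0, 1.0, 4.0, 2.0, 3.0])

# kth must be a static Python int; under jit, declare it static so it is not traced.
# partition_indices is a hypothetical helper name used only for illustration.
@partial(jax.jit, static_argnames="kth")
def partition_indices(x, kth):
    return jnp.argpartition(x, kth)

kth = 2
idx = partition_indices(x, kth)
pivot = x[idx[kth]]    # value a full sort would place at position kth (3.0 here)
smallest = idx[:kth]   # indices of the kth smallest values, in unspecified order

# If only the k smallest (or largest) entries are needed, one lax.top_k call suffices.
# top_k returns the k largest values in descending order, so negate x for the smallest.
neg_vals, small_idx = lax.top_k(-x, kth)   # small_idx -> [1, 3], -neg_vals -> [1., 2.]
big_vals, big_idx = lax.top_k(x, kth)      # big_idx -> [0, 2], big_vals -> [5., 4.]
```

One design difference worth noting: `lax.top_k` returns its k results in sorted order, whereas `argpartition` only guarantees the partition property, so the index sets agree as sets but not necessarily in order.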

The JAX version differs from the NumPy version in its treatment of NaN entries: NaNs that have the sign bit set are sorted to the beginning of the array.
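To see this NaN behavior for yourself, the sketch below builds a float32 array containing one NaN with the sign bit set (via NumPy's `copysign`) and one ordinary NaN, then prints where each lands after partitioning. It deliberately does not hard-code the expected placement: the note above describes the documented behavior, but it is worth checking against the JAX version you have installed.

```python
import numpy as np
import jax.numpy as jnp

# copysign(nan, -1.0) yields a NaN whose sign bit is set ("negative NaN").
neg_nan = np.copysign(np.nan, -1.0)
x = jnp.asarray(np.array([2.0, neg_nan, 1.0, np.nan, 3.0], dtype=np.float32))
print(jnp.signbit(x))   # only position 1 has the sign bit set

idx = jnp.argpartition(x, 2)

# Inspect the partitioned values and which of them carry the sign bit.
print(x[idx])
print(jnp.signbit(x[idx]))
```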

Original docstring below.

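Perform an indirect partition along the given axis using the algorithm specified by the kind keyword (a NumPy parameter; the JAX signature above does not accept kind). It returns an array of indices of the same shape as a that index data along the given axis in partitioned order.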

Added in version 1.8.0.

Parameters:
  • a (array_like) – Array to sort.

  • kth (int or sequence of ints) –

    Element index to partition by. The k-th element will be in its final sorted position, all smaller elements will be moved before it, and all larger elements behind it. The order of all elements in the partitions is undefined. If provided with a sequence of k-th values, it will partition all of them into their sorted positions at once.

    Deprecated since version 1.22.0: Passing booleans as index is deprecated.

  • axis (int or None, optional) – Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used.
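The `kth` semantics above can be checked on a small 1-D integer array. This is a minimal sketch; it passes a single integer `kth`, since the JAX notes earlier on this page say JAX requires a static integer rather than a general array.

```python
import jax.numpy as jnp

x = jnp.array([7, 3, 9, 1, 5, 8, 2])
kth = 3  # a single static int

idx = jnp.argpartition(x, kth)
part = x[idx]         # partitioned view of x
pivot = part[kth]     # 5, the value at index 3 of sorted(x) == [1, 2, 3, 5, 7, 8, 9]

# Partition property: nothing before kth is larger than the pivot and nothing
# after kth is smaller; the order inside each side is unspecified.
assert bool(jnp.all(part[:kth] <= pivot))
assert bool(jnp.all(part[kth + 1:] >= pivot))
```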

Returns:

index_array – Array of indices that partition a along the specified axis. If a is one-dimensional, a[index_array] yields a partitioned a. More generally, np.take_along_axis(a, index_array, axis=axis) always yields the partitioned a, irrespective of dimensionality.
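For arrays with more than one dimension, the `take_along_axis` pattern described above applies the index array along the chosen axis. The sketch below uses `jnp.take_along_axis`, the JAX counterpart of `np.take_along_axis`, and partitions first along rows and then along columns; the sample data is arbitrary.

```python
import jax.numpy as jnp

a = jnp.array([[9.0, 2.0, 7.0, 4.0],
               [3.0, 8.0, 1.0, 6.0]])
kth = 1

# Partition each row (axis=1); the index array has the same shape as a.
idx = jnp.argpartition(a, kth, axis=1)
part = jnp.take_along_axis(a, idx, axis=1)

# Column kth now holds each row's second-smallest value: 4.0 for row 0, 3.0 for row 1.
print(part[:, kth])

# The same pattern works along axis=0; with kth=0, row 0 holds each column's minimum.
idx0 = jnp.argpartition(a, 0, axis=0)
col_min = jnp.take_along_axis(a, idx0, axis=0)[0]
```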

Return type:

ndarray of int