jax.numpy.partition

jax.numpy.partition(a, kth, axis=-1)

Return a partitioned copy of an array.

LAX-backend implementation of numpy.partition().

The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you’re only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
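The two-call strategy can be sketched as follows. `partition_via_top_k` is a hypothetical helper written for illustration, not JAX's own source. It assumes a 1-D floating-point input (so negation is safe) and a static Python int `kth` with `0 <= kth < n - 1`. The last lines show the cheaper route the note recommends when only the largest or smallest values are needed.

```python
import jax.numpy as jnp
from jax import lax


def partition_via_top_k(x, kth):
    """Hypothetical sketch: partition a 1-D float array with two lax.top_k calls.

    Assumes x is 1-D and floating point (so negation is safe) and that kth is
    a static Python int with 0 <= kth < x.shape[-1] - 1.
    """
    n = x.shape[-1]
    # top_k on -x yields the kth + 1 smallest values; negating restores them
    # in ascending order, so the last one is the k-th order statistic.
    bottom = -lax.top_k(-x, kth + 1)[0]
    # The remaining n - kth - 1 largest values, in descending order.
    top = lax.top_k(x, n - kth - 1)[0]
    return jnp.concatenate([bottom, top])


x = jnp.array([3.0, 1.0, 4.0, 1.5, 5.0, 9.0, 2.0, 6.0])
kth = 3

p = partition_via_top_k(x, kth)
ref = jnp.partition(x, kth)  # library version; kth is a plain Python int

# If only the extremes are needed, lax.top_k alone suffices.
top3, top3_idx = lax.top_k(x, 3)  # 3 largest values (descending) and their indices
bottom3 = -lax.top_k(-x, 3)[0]    # 3 smallest values, ascending
```

Since the order within each side of the partition is undefined, only the partition invariants (the value at index `kth` and the comparisons on either side of it) should be compared between `p` and `ref`.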

The JAX version differs from the NumPy version in the treatment of NaN entries: NaNs with the sign bit set are sorted to the beginning of the array.

Original docstring below.

Creates a copy of the array with its elements rearranged so that the element at the k-th position has the value it would have in a sorted array. In the partitioned array, all elements before the k-th element are less than or equal to it, and all elements after the k-th element are greater than or equal to it. The ordering of the elements within the two partitions is undefined.
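A minimal sketch of checking that contract on a 1-D result. The helper `is_valid_partition` is hypothetical; it tests the invariants described above rather than a specific output order, because the order inside each partition is undefined.

```python
import jax.numpy as jnp


def is_valid_partition(x, out, kth):
    """Hypothetical helper: check the partition contract for a 1-D result."""
    pivot = out[kth]
    # The output must be a rearrangement of the input.
    same_values = jnp.array_equal(jnp.sort(out), jnp.sort(x))
    # The k-th slot holds the value a full sort would put there.
    in_place = pivot == jnp.sort(x)[kth]
    # Everything before is <= pivot, everything after is >= pivot.
    left_ok = jnp.all(out[:kth] <= pivot)
    right_ok = jnp.all(out[kth + 1:] >= pivot)
    return bool(same_values & in_place & left_ok & right_ok)


x = jnp.array([7, 2, 9, 4, 1, 8, 3])
kth = 2
out = jnp.partition(x, kth)
print(out)  # order within each side is not guaranteed
print(is_valid_partition(x, out, kth))
```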

Added in version 1.8.0.

Parameters:
  • a (array_like) – Array to be sorted.

  • kth (int or sequence of ints) –

    Element index to partition by. The k-th element will be in its final sorted position, all smaller elements will be moved before it, and all equal or greater elements will be moved after it. The order of the elements within each partition is undefined. If a sequence of k-th indices is provided, all of the elements they index are partitioned into their sorted positions at once.

    Deprecated since version 1.22.0: Passing booleans as index is deprecated.

  • axis (int or None, optional) – Axis along which to sort. If None, the array is flattened before sorting. The default is -1, which sorts along the last axis.
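A short sketch of the `axis` parameter on a 2-D array, passing a single static integer `kth` as the JAX note above requires (the sample matrix is arbitrary). Partitioning along the default `axis=-1` treats each row independently, while `axis=0` treats each column independently; since only the k-th slice is fully determined, the sketch compares that slice against a full sort.

```python
import jax.numpy as jnp

m = jnp.array([[5.0, 1.0, 4.0],
               [2.0, 8.0, 0.0],
               [9.0, 3.0, 6.0],
               [7.0, 2.0, 1.0]])
kth = 1  # a single static int, as the JAX version requires

rows = jnp.partition(m, kth)          # default axis=-1: each row separately
cols = jnp.partition(m, kth, axis=0)  # each column separately

# Only the k-th slice is pinned down; compare it against a full sort.
print(rows[:, kth], jnp.sort(m, axis=-1)[:, kth])
print(cols[kth, :], jnp.sort(m, axis=0)[kth, :])
```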

Returns:

partitioned_array – Array of the same type and shape as a.

Return type:

ndarray