jax.numpy.argpartition
- jax.numpy.argpartition(a, kth, axis=-1)
Perform an indirect partition along the given axis using the LAX-backend implementation of numpy.argpartition().
The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you're only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
The JAX version differs from the NumPy version in the treatment of NaN entries: NaNs with the sign bit set are sorted to the beginning of the array.
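A minimal sketch of the two routes described above, assuming a 1-D floating-point input (the negation trick does not suit unsigned integers). The helper names smallest_k_indices and smallest_k_indices_top_k are hypothetical. Because kth must be static, the jitted wrapper marks it with static_argnames; the second helper calls jax.lax.top_k() on the negated array, which works along the last axis and returns the indices of the k smallest values ordered from smallest to largest.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.array([7.0, 2.0, 9.0, 4.0, 1.0, 8.0])


# kth must be a static Python int, so mark it static when wrapping in jit;
# passing a traced kth would raise an error.
@partial(jax.jit, static_argnames="kth")
def smallest_k_indices(a, kth):
    idx = jnp.argpartition(a, kth)
    # The first kth + 1 entries index the kth + 1 smallest values,
    # in no guaranteed order.
    return idx[: kth + 1]


# Hypothetical alternative that calls jax.lax.top_k directly. Negating the
# input turns "k largest" into "k smallest"; the indices come back ordered
# from smallest to largest value. Operates along the last axis only.
def smallest_k_indices_top_k(a, k):
    _, idx = jax.lax.top_k(-a, k)
    return idx


print(smallest_k_indices(x, 2))
print(smallest_k_indices_top_k(x, 3))
```

Both helpers select the same three positions here; the top_k route skips the second top_k call that argpartition needs to fill in the rest of the output.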
Original docstring below.
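Perform an indirect partition along the given axis using the algorithm specified by the kind keyword (the JAX signature above has no kind parameter). It returns an array of indices of the same shape as a that index data along the given axis in partitioned order.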
New in version 1.8.0.
- Parameters:
a (array_like) – Array to sort.
kth (int or sequence of ints) –
Element index to partition by. The k-th element will be in its final sorted position, all smaller elements will be moved before it, and all larger elements behind it. The order of the elements within each partition is undefined. If given a sequence of k-th indices, it partitions all of them into their sorted positions at once.
Deprecated since version 1.22.0: Passing booleans as index is deprecated.
axis (int or None, optional) – Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used.
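The kth contract can be checked on a small 1-D integer array. This is a sketch for illustration: is_partitioned is a hypothetical helper, and, per the JAX note above, kth is passed as a single static integer rather than the sequence form the NumPy docstring allows.

```python
import jax.numpy as jnp


def is_partitioned(values, kth):
    """Hypothetical helper: check the argpartition contract on 1-D values."""
    pivot = values[kth]
    # Everything before the pivot is <= it, everything after is >= it.
    before_ok = bool(jnp.all(values[:kth] <= pivot))
    after_ok = bool(jnp.all(values[kth + 1:] >= pivot))
    return before_ok and after_ok


a = jnp.array([3, 4, 2, 1, 5, 0])
kth = 2
idx = jnp.argpartition(a, kth)  # a single static int in JAX
part = a[idx]                   # 1-D, so plain indexing works

# The kth entry matches the fully sorted array at the same position.
print(int(part[kth]), int(jnp.sort(a)[kth]))
print(is_partitioned(part, kth))
```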
- Returns:
index_array – Array of indices that partition a along the specified axis. If a is one-dimensional, a[index_array] yields a partitioned a. More generally, np.take_along_axis(a, index_array, axis=axis) always yields the partitioned a, irrespective of dimensionality.
- Return type:
ndarray of ints
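For arrays with more than one dimension, a[index_array] indexes the first axis rather than partitioning, so take_along_axis is the way to apply the result. A minimal sketch with the JAX counterpart jnp.take_along_axis, partitioning the columns of a small 2-D array (the array values are arbitrary):

```python
import jax.numpy as jnp

a = jnp.array([[5, 0],
               [1, 9],
               [4, 3]])
kth = 1

# Partition each column (axis=0); the index array has the same shape as a.
idx = jnp.argpartition(a, kth, axis=0)

# Gather along the same axis to obtain the partitioned array.
part = jnp.take_along_axis(a, idx, axis=0)

print(idx.shape)
print(part[kth])  # the kth-smallest value of each column
```

With only three rows and kth=1, each partitioned column happens to be fully sorted, since the single smaller element goes first and the single larger element goes last.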