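jax.Array.argpartition

abstract Array.argpartition(kth, axis=-1)

Return the indices that partially sort the array.

Refer to jax.numpy.argpartition() for the full documentation.

Parameters:
  self (Array): the array to partition.
  kth (int): the position whose element ends up where it would be in a fully sorted array. Elements less than or equal to it come before it, and elements greater than or equal to it come after it.
  axis (int): the axis along which to partition. Defaults to -1, the last axis.

Return type:
  Array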
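The following is a minimal sketch of how the method might be used. It assumes a JAX version in which jax.Array exposes argpartition as a method, as documented above. It gathers values with the returned indices, partitions each column with axis=0, and wraps the call in jax.jit. Because jax.numpy.argpartition expects kth to be a concrete Python int, the sketch marks kth as static. The helper kth_smallest is hypothetical and exists only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

# 1-D: partition around the k-th smallest value.
x = jnp.array([7, 2, 9, 4, 1, 8])
k = 2
idx = x.argpartition(k)   # method form of jnp.argpartition(x, k)
part = x[idx]             # values rearranged by the returned indices

# part[k] is the value a full sort would place at position k; values
# before it are <= part[k] and values after it are >= part[k].
print(part[k])  # 4

# 2-D: partition each column independently by passing axis=0.
m = jnp.array([[3, 9], [1, 5], [2, 7]])
col_idx = m.argpartition(0, axis=0)  # kth=0 puts each column's minimum first
col_min = jnp.take_along_axis(m, col_idx, axis=0)[0]

# Hypothetical helper: kth must be a static Python int under jit.
@partial(jax.jit, static_argnames="kth")
def kth_smallest(a, kth):
    return a[a.argpartition(kth)][kth]
```

The order of elements within each side of the partition is not guaranteed to match NumPy's, so only the element at position kth and the <= / >= relationship should be relied on.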