jax.Array.argpartition


abstract Array.argpartition(kth, axis=-1)

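Return the indices that partially sort the array along axis. Indexing the array with the result places the element at position kth where a full sort would put it, with no larger elements before it and no smaller elements after it.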

Refer to jax.numpy.argpartition() for the full documentation.
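The following is a minimal sketch of the partial-sort guarantee on a small 1-D array. The array values and the choice kth = 2 are illustrative assumptions. The order of elements within each side of the partition is not specified, so the example checks only the guaranteed properties.

```python
import jax.numpy as jnp

x = jnp.array([7, 2, 9, 1, 5, 3])
kth = 2

# Indices that put the kth-smallest value at position kth.
idx = x.argpartition(kth)
part = x[idx]

# Position kth holds the same value a full sort would put there ...
assert part[kth] == jnp.sort(x)[kth]
# ... entries before it are <= that value, entries after it are >=.
assert jnp.all(part[:kth] <= part[kth])
assert jnp.all(part[kth:] >= part[kth])

print(int(part[kth]))  # prints 3
```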

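Parameters:
kth (int) – position of the partition element; the element at this position ends up where a full sort would place it.
axis (int) – axis along which to partition. Defaults to -1, the last axis.
Return type:
Array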
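A common use of the axis parameter is selecting the k largest entries of each row without a full sort. The sketch below is an illustration, not the library's prescribed recipe. The names scores and k are hypothetical. Partitioning at n - k along the last axis places the k largest entries of each row in the last k slots, in unspecified order. jax.lax.top_k is an alternative that returns the values already sorted.

```python
import jax.numpy as jnp

# Hypothetical 2 x 4 array of scores.
scores = jnp.array([[0.1, 0.9, 0.4, 0.7],
                    [0.8, 0.2, 0.6, 0.3]])
k = 2
n = scores.shape[-1]

# Partition each row so that the k largest entries occupy the last k slots.
idx = scores.argpartition(n - k, axis=-1)
top_idx = idx[:, n - k:]  # indices of the k largest per row (unordered)

# Gather the corresponding values row by row.
top_vals = jnp.take_along_axis(scores, top_idx, axis=-1)
```

Because the order within the selected slots is unspecified, sort top_vals afterwards if ordered results are required.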