jax.numpy.searchsorted

jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')

Find indices where elements should be inserted to maintain order.

LAX-backend implementation of numpy.searchsorted().
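Because it is built from LAX operations, the function can be traced like any other jax.numpy call. The sketch below assumes you want one independent sorted table per row, and uses jax.vmap to search each row with its own query under jax.jit. The names tables, queries and batched_search are illustrative only.

```python
import jax
import jax.numpy as jnp

# Two independent sorted tables, one per row.
tables = jnp.array([[0.0, 1.0, 2.0, 3.0],
                    [10.0, 20.0, 30.0, 40.0]])
queries = jnp.array([1.5, 35.0])  # one query value per table

# vmap maps over axis 0 of both arguments: row k of `tables` pairs with queries[k].
batched_search = jax.jit(jax.vmap(jnp.searchsorted))
idx = batched_search(tables, queries)
# idx holds [2, 3]: 1.5 falls between 1.0 and 2.0, and 35.0 between 30.0 and 40.0.
```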

Original docstring below.

Find the indices into a sorted array a such that, if the corresponding elements in v were inserted before the indices, the order of a would be preserved.
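A minimal illustration of that definition, with small arrays picked arbitrarily: the returned indices, used as "insert before" positions, keep the combined array sorted. NumPy's np.insert is used here only to show the merge on the host.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 4, 5])  # already sorted in ascending order
v = jnp.array([0, 3, 6])        # values to place into a

idx = jnp.searchsorted(a, v)    # default side='left'
# idx holds [0, 2, 5]: 0 goes before everything, 3 goes before the existing 3,
# and 6 goes at the end (index N = 5).

# Inserting each v[k] before a[idx[k]] keeps the merged array sorted.
merged = np.insert(np.asarray(a), np.asarray(idx), np.asarray(v))
print(merged)  # [0 1 2 3 3 4 5 6]
```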

Assuming that a is sorted:

side     returned index i satisfies
-----    --------------------------
left     a[i-1] < v <= a[i]
right    a[i-1] <= v < a[i]
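The two side conditions differ only when v equals one or more elements of a. The sketch below uses an array with a run of duplicates and checks both conditions directly. To make a[i-1] and a[i] defined at the boundaries i = 0 and i = N, it pads a with -inf and +inf; that padding is a checking device for this example, not something searchsorted does.

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 2.0, 2.0, 3.0])
v = jnp.array([0.0, 2.0, 2.5, 4.0])

left = jnp.searchsorted(a, v, side='left')    # [0, 1, 4, 5]
right = jnp.searchsorted(a, v, side='right')  # [0, 4, 4, 5]

# padded[i] plays the role of a[i-1] and padded[i + 1] the role of a[i].
padded = jnp.concatenate([jnp.array([-jnp.inf]), a, jnp.array([jnp.inf])])

# side='left':  a[i-1] <  v <= a[i]
assert bool(jnp.all((padded[left] < v) & (v <= padded[left + 1])))
# side='right': a[i-1] <= v <  a[i]
assert bool(jnp.all((padded[right] <= v) & (v < padded[right + 1])))
```

For the value 2.0, side='left' points before the first 2.0 (index 1) and side='right' points after the last one (index 4).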

Parameters:
  • a (1-D array_like) – Input array. If sorter is None, a must be sorted in ascending order; otherwise, sorter must be an array of indices that sort it.

  • v (array_like) – Values to insert into a.

  • side ({'left', 'right'}, optional) – If ‘left’, return the index of the first suitable location. If ‘right’, return the last such index. If v is smaller than every element of a, the result is 0; if it is larger than every element, the result is N (the length of a).

  • method (str) – One of ‘scan’ (default), ‘scan_unrolled’, ‘sort’ or ‘compare_all’. Controls the method used by the implementation: ‘scan’ tends to be more performant on CPU (particularly when a is very large), ‘scan_unrolled’ is more performant on GPU at the expense of additional compile time, ‘sort’ is often more performant on accelerator backends like GPU and TPU (particularly when v is very large), and ‘compare_all’ can be most performant when a is very small.

  • sorter (None)
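The method values are different algorithms for the same computation, so they should return identical indices and differ only in speed and compile time. A quick consistency check, assuming arbitrary random data (the sizes 1000 and 64 and the seed are illustrative) and using NumPy's searchsorted as the reference; which method is fastest depends on backend and array sizes, so benchmark on your own workload.

```python
import numpy as np
import jax
import jax.numpy as jnp

key_a, key_v = jax.random.split(jax.random.PRNGKey(0))
a = jnp.sort(jax.random.uniform(key_a, (1000,)))  # sorted table
v = jax.random.uniform(key_v, (64,))              # query values

methods = ('scan', 'scan_unrolled', 'sort', 'compare_all')
results = {m: jnp.searchsorted(a, v, method=m) for m in methods}

# All methods should agree with each other and with NumPy.
reference = np.searchsorted(np.asarray(a), np.asarray(v))
for m, idx in results.items():
    assert np.array_equal(np.asarray(idx), reference), m
```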

Returns:

indices – Array of insertion points with the same shape as v, or an integer if v is a scalar.

Return type:

int or array of ints
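A short sketch of the output shape, with arrays chosen only for illustration. The indices follow the shape of v, including multi-dimensional v. Note that in JAX a scalar v yields a zero-dimensional integer array rather than a Python int, so call int() if a host integer is needed.

```python
import jax.numpy as jnp

a = jnp.array([10, 20, 30, 40])
v = jnp.array([[5, 25],
               [40, 45]])

idx = jnp.searchsorted(a, v)
print(idx.shape)  # (2, 2): same shape as v
# idx holds [[0, 2], [3, 4]]; 40 maps to 3 because side='left'.

s = jnp.searchsorted(a, 25)
print(s.shape)    # (): a zero-dimensional array
print(int(s))     # 2
```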