jax.numpy.searchsorted


jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')

Perform a binary search within a sorted array.

JAX implementation of numpy.searchsorted().

Returns the indices into the sorted array a at which the values in v could be inserted so that a remains in sorted order.
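One way to read these semantics: for an ascending a without NaNs, the 'left' insertion index of a query value is the number of elements of a strictly less than it, and the 'right' index is the number of elements less than or equal to it. The sketch below checks that reading with a brute-force comparison; it illustrates the definition only and is not how the library computes the result.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 2, 3, 4, 5, 5])   # assumed sorted ascending, no NaNs
v = jnp.array([0, 2, 3.5, 6])

# side='left': how many elements of a are strictly less than each query
left_counts = jnp.sum(a[None, :] < v[:, None], axis=1)
# side='right': how many elements of a are less than or equal to each query
right_counts = jnp.sum(a[None, :] <= v[:, None], axis=1)

print(left_counts.tolist())   # [0, 1, 4, 7]
print(right_counts.tolist())  # [0, 3, 4, 7]
print(jnp.searchsorted(a, v, side='left').tolist())   # matches left_counts
print(jnp.searchsorted(a, v, side='right').tolist())  # matches right_counts
```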

Parameters:
  • a (ArrayLike) – one-dimensional array, assumed to be in sorted order unless sorter is specified.

  • v (ArrayLike) – N-dimensional array of query values.

  • side (str) – 'left' (default) or 'right'; specifies whether insertion indices will be to the left or the right in case of ties.

  • sorter (ArrayLike | None) – optional array of indices specifying the sort order of a. If specified, then the algorithm assumes that a[sorter] is in sorted order.

  • method (str) – one of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. See Note below.

Returns:

Array of insertion indices of shape v.shape.

Return type:

Array
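Because the output has shape v.shape, a multi-dimensional batch of queries yields an index array of the same shape, with each entry searched independently against the one-dimensional a. A small sketch with a 2x2 query array:

```python
import jax.numpy as jnp

a = jnp.array([10, 20, 30, 40])        # one-dimensional, sorted
v = jnp.array([[5, 25], [40, 45]])     # 2-D batch of query values

idx = jnp.searchsorted(a, v)           # side='left' by default
print(idx.shape)      # (2, 2), same as v.shape
print(idx.tolist())   # [[0, 2], [3, 4]]
```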

Note

The method argument controls the algorithm used to compute the insertion indices.

  • 'scan' (the default) tends to be more performant on CPU, particularly when a is very large.

  • 'scan_unrolled' is more performant on GPU at the expense of additional compile time.

  • 'sort' is often more performant on accelerator backends like GPU and TPU, particularly when v is very large.

  • 'compare_all' tends to be the most performant when a is very small.
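The method argument is meant to change only how the indices are computed, not which indices are returned. A quick sketch that runs the same query through each of the four documented methods and compares the results; the array sizes here are arbitrary and far too small to show any performance difference.

```python
import jax.numpy as jnp

a = jnp.arange(0, 100, 3)                  # sorted: 0, 3, 6, ..., 99
v = jnp.array([-1, 0, 1, 50, 99, 100])     # includes out-of-range queries

# Same inputs, different search algorithms
results = {
    m: jnp.searchsorted(a, v, method=m)
    for m in ('scan', 'scan_unrolled', 'sort', 'compare_all')
}
for m, r in results.items():
    print(m, r.tolist())   # each prints [0, 0, 1, 17, 33, 34]
```

To pick a method for a real workload, timing each option on the target backend with representative sizes of a and v is more reliable than the general guidance above.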

Examples

Searching for a single value:

>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)
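Since 'left' and 'right' bracket the run of entries equal to the query, their difference gives the number of times each value occurs in a sorted array (0 when it is absent). A short sketch:

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 2, 3, 4, 5, 5])
vals = jnp.array([2, 5, 6])

# right index minus left index = length of the run of equal elements
count = jnp.searchsorted(a, vals, side='right') - jnp.searchsorted(a, vals, side='left')
print(count.tolist())  # [2, 2, 0]
```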

Searching for a batch of values:

>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)
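A common use of batched searches is assigning values to intervals. Assuming sorted bin edges and half-open bins [edges[i], edges[i+1]), taking side='right' and subtracting 1 gives each value's bin number; values below edges[0] map to -1 and values at or above edges[-1] map to len(edges) - 1, so those cases need separate handling.

```python
import jax.numpy as jnp

edges = jnp.array([0., 10., 20., 30.])   # bins: [0,10), [10,20), [20,30)
x = jnp.array([0., 5., 10., 29.9])

# side='right' places a value equal to an edge into the bin that starts there
bins = jnp.searchsorted(edges, x, side='right') - 1
print(bins.tolist())  # [0, 0, 1, 2]
```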

Optionally, the sorter argument can be used to find insertion indices into an array sorted via jax.numpy.argsort():

>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)

The result is equivalent to passing the sorted array:

>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)
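With sorter, the returned indices refer to positions in a[sorter], not in a itself. When every query value is known to occur in a, indexing sorter with those results recovers the positions in the original, unsorted array. This sketch assumes all queries are present; otherwise an index may equal len(a), and JAX clamps out-of-bounds gathers by default, which would silently return a wrong position.

```python
import jax.numpy as jnp

a = jnp.array([4, 3, 5, 1, 2])
sorter = jnp.argsort(a)              # [3, 4, 1, 0, 2]
queries = jnp.array([3, 5, 1])       # assumed to all occur in a

pos_sorted = jnp.searchsorted(a, queries, sorter=sorter)  # positions in a[sorter]
pos_orig = sorter[pos_sorted]                             # positions in a
print(pos_orig.tolist())  # [1, 2, 3]
```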