jax.numpy.digitize

jax.numpy.digitize(x, bins, right=False)

Return the indices of the bins to which each value in the input array belongs.

LAX-backend implementation of numpy.digitize().

Original docstring below.
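Because this is a LAX-backend implementation, it can be traced by `jax.jit`. The minimal sketch below assumes a fixed set of increasing bin edges and uses a hypothetical helper, `count_per_bin`, to bin a batch of values and count how many fall into each bin with `jnp.bincount`. The returned indices run from 0 to `len(bins)` inclusive, so the count vector needs `len(bins) + 1` slots.

```python
import jax
import jax.numpy as jnp

# Increasing bin edges (illustrative values, not from the original docs).
bins = jnp.array([1.0, 2.0, 3.0])

@jax.jit
def count_per_bin(x):
    # Hypothetical helper. With the default right=False, index i means
    # bins[i-1] <= x < bins[i].
    idx = jnp.digitize(x, bins)
    # Indices run from 0 to len(bins) inclusive, so len(bins) + 1 counters
    # are needed; a static `length` keeps bincount compatible with jit.
    return jnp.bincount(idx, length=bins.shape[0] + 1)

x = jnp.array([0.2, 1.5, 1.7, 2.0, 3.5, 9.0])
counts = count_per_bin(x)
print(counts.tolist())  # [1, 2, 1, 2]
```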

right    order of bins    returned index i satisfies
-----    -------------    --------------------------
False    increasing       bins[i-1] <= x < bins[i]
True     increasing       bins[i-1] < x <= bins[i]
False    decreasing       bins[i-1] > x >= bins[i]
True     decreasing       bins[i-1] >= x > bins[i]

If values in x are beyond the bounds of bins, 0 or len(bins) is returned as appropriate.
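The four rows of the table can be checked directly. This sketch uses illustrative inputs (not from the original docstring): two values that sit exactly on bin edges, where the choice of `right` matters, and one value outside the bins to show the `0` / `len(bins)` results.

```python
import jax.numpy as jnp

# Illustrative inputs: 1.0 and 2.0 sit exactly on bin edges; 0.5 lies outside the bins.
x = jnp.array([0.5, 1.0, 2.0])
increasing = jnp.array([1.0, 2.0, 3.0])
decreasing = jnp.array([3.0, 2.0, 1.0])

# right=False, increasing: bins[i-1] <= x < bins[i]
inc_left = jnp.digitize(x, increasing, right=False)
# right=True, increasing: bins[i-1] < x <= bins[i]
inc_right = jnp.digitize(x, increasing, right=True)
# right=False, decreasing: bins[i-1] > x >= bins[i]
dec_left = jnp.digitize(x, decreasing, right=False)
# right=True, decreasing: bins[i-1] >= x > bins[i]
dec_right = jnp.digitize(x, decreasing, right=True)

print(inc_left.tolist())   # [0, 1, 2]
print(inc_right.tolist())  # [0, 0, 1]
print(dec_left.tolist())   # [3, 2, 1]
print(dec_right.tolist())  # [3, 3, 2]
```

With increasing bins, 0.5 falls below the first edge and maps to 0; with decreasing bins, the same value lies past the last edge and maps to len(bins) = 3.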

Parameters:
  • x (array_like) – Input array to be binned. Prior to NumPy 1.10.0, this array had to be 1-dimensional, but can now have any shape.

  • bins (array_like) – Array of bins. It has to be 1-dimensional and monotonic.

  • right (bool, optional) – Whether the intervals include the right or the left bin edge. The default, right=False, means the interval does not include the right edge: the left bin edge is closed and the right edge is open, i.e., bins[i-1] <= x < bins[i] for monotonically increasing bins.
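NumPy's notes for numpy.digitize state that, for monotonically increasing bins, digitize(x, bins, right=True) is equivalent to searchsorted(bins, x, side='left'); by the same reasoning, the default right=False corresponds to side='right'. The sketch below checks that correspondence with `jnp.searchsorted`; the bin edges and inputs are arbitrary illustrative choices, several of them placed exactly on an edge.

```python
import jax.numpy as jnp

# Arbitrary increasing edges and inputs; several inputs sit exactly on an edge.
bins = jnp.array([0.25, 0.5, 1.0])
x = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0, 1.25])

results = {}
for right, side in [(False, "right"), (True, "left")]:
    via_digitize = jnp.digitize(x, bins, right=right)
    via_search = jnp.searchsorted(bins, x, side=side)
    # Both should place every value in the same bin, including the edge values.
    assert jnp.array_equal(via_digitize, via_search), (right, side)
    results[right] = via_digitize.tolist()

print(results[False])  # [0, 1, 2, 2, 3, 3]
print(results[True])   # [0, 0, 1, 2, 2, 3]
```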

Returns:

indices – Output array of indices, of the same shape as x.

Return type:

ndarray of ints
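Since x may have any shape, the returned indices keep the shape of x. A minimal sketch with an illustrative 2 × 3 input; the exact integer dtype depends on JAX's configuration (for example, whether 64-bit mode is enabled), so the check below only asserts that it is an integer type.

```python
import jax.numpy as jnp

bins = jnp.array([1.0, 2.0, 3.0])  # must be 1-dimensional and monotonic
x = jnp.array([[0.5, 1.5, 2.5],
               [3.5, 1.0, 3.0]])   # illustrative 2-D input

idx = jnp.digitize(x, bins)

print(idx.shape)     # (2, 3), same as x.shape
print(idx.tolist())  # [[0, 1, 2], [3, 1, 3]]
print(jnp.issubdtype(idx.dtype, jnp.integer))  # True
```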