jax.numpy.digitize
- jax.numpy.digitize(x, bins, right=False, *, method=None)
Convert an array to bin indices.
JAX implementation of numpy.digitize().
- Parameters:
x (ArrayLike) – array of values to digitize.
bins (ArrayLike) – 1D array of bin edges. Must be monotonically increasing or decreasing.
right (bool) – if True, each interval includes its right edge (for increasing bins, bins[i-1] < x <= bins[i]). If False (the default), each interval includes its left edge (bins[i-1] <= x < bins[i]).
method (str | None) – optional method argument passed through to searchsorted(). See that function for the available options.
- Returns:
An integer array of the same shape as x giving, for each value, the index of the bin it falls into. Indices range from 0 to len(bins) inclusive.
- Return type:
Array
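For monotonically increasing bins, the parameters above map directly onto jax.numpy.searchsorted(): right=False corresponds to side="right" and right=True to side="left". The following is a minimal sketch under that assumption (increasing bins only); digitize_increasing is a hypothetical helper for illustration, not part of JAX, and jnp.digitize's internal implementation may differ in detail.

```python
import jax.numpy as jnp

def digitize_increasing(x, bins, right=False):
    # Hypothetical helper (not part of JAX); assumes `bins` is monotonically increasing.
    # right=False -> index i with bins[i-1] <= x < bins[i]  (searchsorted side="right")
    # right=True  -> index i with bins[i-1] <  x <= bins[i] (searchsorted side="left")
    side = "left" if right else "right"
    return jnp.searchsorted(bins, x, side=side)

x = jnp.array([0.5, 1.0, 2.0, 2.5, 3.0, 3.5])
bins = jnp.array([1.0, 2.0, 3.0])

# Compare the sketch against jnp.digitize for both settings of `right`.
for right in (False, True):
    ours = digitize_increasing(x, bins, right=right)
    ref = jnp.digitize(x, bins, right=right)
    print(right, ours)  # right=False -> [0 1 2 2 3 3]; right=True -> [0 0 1 2 2 3]
    assert jnp.array_equal(ours, ref)
```

Values below the first edge land in bin 0 and values at or above the last edge (with right=False) land in bin len(bins), which is why the result spans 0 to len(bins).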
See also
jax.numpy.searchsorted(): find insertion indices for values in a sorted array.
jax.numpy.histogram(): compute the frequency of array values within specified bins.
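Because digitize assigns each value a bin index, counting those indices with jnp.bincount gives the same per-interval counts as jnp.histogram. This is a minimal sketch assuming increasing edges and no value exactly equal to the last edge: histogram (following NumPy's convention) closes its last bin on the right, whereas digitize with right=False sends such a value to the overflow index len(edges). The variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([0.5, 1.5, 1.2, 2.7, 0.1, 2.2])
edges = jnp.array([0.0, 1.0, 2.0, 3.0])

# Index i means edges[i-1] <= value < edges[i]; 0 and len(edges) are the out-of-range slots.
idx = jnp.digitize(x, edges)

# Count occurrences of every possible index 0..len(edges).
counts = jnp.bincount(idx, length=edges.shape[0] + 1)

# Drop the underflow and overflow slots, leaving one count per interval.
interior = counts[1:-1]

hist, _ = jnp.histogram(x, bins=edges)
print(interior)  # [2 2 2]
assert jnp.array_equal(interior, hist)
```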
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5])
>>> bins = jnp.array([1, 2, 3])
>>> jnp.digitize(x, bins)
Array([1, 2, 2, 1, 3, 3], dtype=int32)
>>> jnp.digitize(x, bins, right=True)
Array([0, 1, 2, 1, 2, 3], dtype=int32)
digitize also supports reverse-ordered (monotonically decreasing) bins:
>>> bins = jnp.array([3, 2, 1])
>>> jnp.digitize(x, bins)
Array([2, 1, 1, 2, 0, 0], dtype=int32)
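For decreasing bins with right=False, index i means bins[i-1] > x >= bins[i], following NumPy's convention. One way to relate this to the increasing case: reversing the edges and subtracting the resulting index from len(bins) reproduces the same output, for either value of right. A short check, with the example values redefined so the block runs on its own (bins_desc, bins_asc, and n are illustrative names):

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5])
bins_desc = jnp.array([3, 2, 1])
bins_asc = bins_desc[::-1]  # same edges, increasing order
n = bins_desc.shape[0]

for right in (False, True):
    direct = jnp.digitize(x, bins_desc, right=right)
    # Mirror: count bins from the other end of the reversed (increasing) edges.
    mirrored = n - jnp.digitize(x, bins_asc, right=right)
    assert jnp.array_equal(direct, mirrored)

print(jnp.digitize(x, bins_desc))  # [2 1 1 2 0 0]
```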