jax.numpy.bincount

jax.numpy.bincount(x, weights=None, minlength=0, *, length=None)

Count the number of occurrences of each value in an integer array.

JAX implementation of `numpy.bincount()`.

For an array of non-negative integers `x`, this function returns an array `counts` of size `x.max() + 1`, such that `counts[i]` contains the number of occurrences of the value `i` in `x`.
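The definition above can be sketched as a scatter-add into an array of zeros. This is an illustration only, not necessarily how the library implements it; the helper name `bincount_sketch` is hypothetical, and it assumes a concrete (untraced) 1-D input of non-negative integers.

```python
import jax.numpy as jnp

def bincount_sketch(x):
    # Hypothetical helper for illustration; assumes x is a concrete 1-D
    # array of non-negative integers (not traced under jax.jit).
    n = int(x.max()) + 1
    # counts[i] accumulates 1 for every element of x equal to i.
    return jnp.zeros(n, dtype=jnp.int32).at[x].add(1)

x = jnp.array([1, 1, 2, 3, 3, 3])
counts = bincount_sketch(x)
print(counts)           # [0 2 1 3]
print(jnp.bincount(x))  # same counts
```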

The JAX version has a few differences from the NumPy version:

• In NumPy, passing an array `x` with negative entries will result in an error. In JAX, negative values are clipped to zero.

• JAX adds an optional `length` parameter which can be used to statically specify the length of the output array so that this function can be used with transformations like `jax.jit()`. In this case, values greater than or equal to `length` are dropped.
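Both differences can be seen in one small case. The snippet below is a sketch with made-up input values; it shows a negative entry being counted in bin 0 and entries that do not fit within `length` being left out.

```python
import jax.numpy as jnp

# With length=3 the output has bins 0, 1 and 2 only.
x = jnp.array([-5, 0, 2, 3, 4])
counts = jnp.bincount(x, length=3)
# -5 is clipped into bin 0 (joining the 0), 2 lands in bin 2,
# and 3 and 4 fall outside the output and are dropped.
print(counts)  # [2 0 1]
```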

Parameters:
• x (ArrayLike) – 1-dimensional array of non-negative integers

• weights (ArrayLike | None) – optional array of weights associated with `x`. If not specified, the weight for each entry will be `1`.

• minlength (int) – the minimum length of the output counts array.

• length (int | None) – the length of the output counts array. Must be specified statically for `bincount` to be used with `jax.jit()` and other JAX transformations.

Returns:

An array of counts or summed weights reflecting the number of occurrences of values in `x`.

Return type:

Array

Examples

Basic bincount:

```
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 1, 2, 3, 3, 3])
>>> jnp.bincount(x)
Array([0, 2, 1, 3], dtype=int32)
```

Weighted bincount:

```
>>> weights = jnp.array([1, 2, 3, 4, 5, 6])
>>> jnp.bincount(x, weights)
Array([ 0,  3,  3, 15], dtype=int32)
```
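Weights do not have to be integers. The following sketch uses made-up floating-point weights and assumes default JAX settings (64-bit types disabled), under which the summed weights come back as `float32`.

```python
import jax.numpy as jnp

x = jnp.array([1, 1, 2, 3, 3, 3])
weights = jnp.array([0.5, 0.5, 1.0, 2.0, 2.0, 2.0])
# Each bin holds the sum of the weights whose x value maps to it.
weighted = jnp.bincount(x, weights)
print(weighted)        # [0. 1. 1. 6.]
print(weighted.dtype)  # float32 under default settings
```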

Specifying a static `length` makes this jit-compatible:

```
>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
>>> jit_bincount(x, length=5)
Array([0, 2, 1, 3, 0], dtype=int32)
```
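A static `length` should also allow use under other transformations. As a sketch, assuming a 2-D batch of made-up values where each row is counted independently, `jax.vmap` can map a fixed-length `bincount` over the rows:

```python
import jax
import jax.numpy as jnp

# Each row is one 1-D array of non-negative integers.
batch = jnp.array([[0, 1, 1, 3],
                   [2, 2, 2, 0]])

# length is fixed in the closure, so the output shape is known statically.
row_counts = jax.vmap(lambda row: jnp.bincount(row, length=4))(batch)
print(row_counts)
# [[1 2 0 1]
#  [1 0 3 0]]
```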

Negative values are clipped to zero and counted in the first bin, and values that do not fit within the specified `length` are dropped:

```
>>> x = jnp.array([-1, -1, 1, 3, 10])
>>> jnp.bincount(x, length=5)
Array([2, 1, 0, 1, 0], dtype=int32)
```
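The `minlength` parameter sets a lower bound on the output size rather than fixing it. A short sketch with made-up inputs:

```python
import jax.numpy as jnp

x = jnp.array([1, 1, 2])
# x.max() + 1 is 3, so minlength=5 pads the result with empty bins.
padded = jnp.bincount(x, minlength=5)
print(padded)  # [0 2 1 0 0]

# minlength is only a lower bound: larger values in x still extend the output.
extended = jnp.bincount(jnp.array([0, 6]), minlength=5)
print(extended)  # [1 0 0 0 0 0 1]
```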