jax.numpy.unique_counts
- jax.numpy.unique_counts(x, /, *, size=None, fill_value=None)
Return unique values from x, along with counts.
JAX implementation of numpy.unique_counts(); this is equivalent to calling jax.numpy.unique() with return_counts and equal_nan set to True.
Because the size of the output of unique_counts is data-dependent, the function semantics are not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique_counts to be used in such contexts.
- Parameters:
x (ArrayLike) – N-dimensional array from which unique values will be extracted.
size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
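As a minimal sketch of how size and fill_value let the function run under jit(), the example below fixes size as a Python integer inside a jitted function. The helper name padded_counts and the choices size=5 and fill_value=0 are illustrative assumptions, not part of the documented API.

```python
import jax
import jax.numpy as jnp


@jax.jit
def padded_counts(arr):
    # size is a plain Python int, so the output shape is known at trace time.
    return jnp.unique_counts(arr, size=5, fill_value=0)


x = jnp.array([3, 4, 1, 3, 1])
result = padded_counts(x)

# Only three unique values exist, so the trailing entries of
# result.values are padded with fill_value (0 here).
values = result.values
counts = result.counts
print(values)
print(counts)
```

Choosing size larger than the expected number of unique values avoids truncation, at the cost of padded entries that the caller has to ignore.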
- Returns:
A tuple (values, counts), with the following properties:
values: an array of shape (n_unique,) containing the unique values from x.
counts: an array of shape (n_unique,) containing the number of occurrences of each unique value in x.
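Because the return value is a tuple, it can be unpacked directly. The sketch below does this and, for a NaN-free integer input, checks the result against jax.numpy.unique() called with return_counts=True. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])

# Unpack the (values, counts) tuple directly.
values, counts = jnp.unique_counts(x)

# Reference result from the general-purpose function.
ref_values, ref_counts = jnp.unique(x, return_counts=True)

# For this integer input, both calls agree.
same_values = bool(jnp.array_equal(values, ref_values))
same_counts = bool(jnp.array_equal(counts, ref_counts))
print(same_values, same_counts)
```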
See also
jax.numpy.unique(): general function for computing unique values.
jax.numpy.unique_values(): compute only values.
jax.numpy.unique_inverse(): compute only values and inverse.
jax.numpy.unique_all(): compute values, indices, inverse_indices, and counts.
Examples
Here we compute the unique values in a 1D array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_counts(x)
The result is a NamedTuple with two named attributes. The values attribute contains the unique values from the array:
>>> result.values
Array([1, 3, 4], dtype=int32)
The counts attribute contains the counts of each unique value in the input:
>>> result.counts
Array([2, 2, 1], dtype=int32)
For examples of the size and fill_value arguments, see jax.numpy.unique().
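The parameter description allows N-dimensional input while both outputs have shape (n_unique,). The following sketch, using an arbitrary 2D array chosen for illustration, shows that the input is treated as a flat collection of elements.

```python
import jax.numpy as jnp

# A 2D input; the value 2 appears twice across different rows.
x2d = jnp.array([[1, 2],
                 [2, 3]])

result = jnp.unique_counts(x2d)

# Both outputs are one-dimensional, regardless of the input's shape.
values = result.values
counts = result.counts
print(values.shape, counts.shape)
print(values, counts)
```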