jax.numpy.unique_all
- jax.numpy.unique_all(x, /, *, size=None, fill_value=None)
Return unique values from x, along with indices, inverse indices, and counts.
JAX implementation of numpy.unique_all(); this is equivalent to calling jax.numpy.unique() with return_index, return_inverse, return_counts, and equal_nan set to True.
Because the size of the output of unique_all is data-dependent, the function semantics are not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique_all to be used in such contexts.
- Parameters:
x (ArrayLike) – N-dimensional array from which unique values will be extracted.
size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
- Returns:
A NamedTuple (values, indices, inverse_indices, counts) with the following properties:
values: An array of shape (n_unique,) containing the unique values from x.
indices: An array of shape (n_unique,). Contains the indices of the first occurrence of each unique value in x. For 1D inputs, x[indices] is equivalent to values.
inverse_indices: An array of shape x.shape. Contains the indices within values of each value in x. For 1D inputs, values[inverse_indices] is equivalent to x.
counts: An array of shape (n_unique,). Contains the number of occurrences of each unique value in x.
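The shapes listed above are easiest to see with a 2D input. The following is a minimal sketch; the array x is purely illustrative, and it assumes, as with numpy.unique_all(), that indices refer to positions in the flattened input.

```python
import jax.numpy as jnp

# Illustrative 2D input; unique_all operates on the flattened elements.
x = jnp.array([[3, 4],
               [1, 3]])
result = jnp.unique_all(x)

# values is 1D with shape (n_unique,): the sorted unique elements 1, 3, 4.
print(result.values.shape)

# inverse_indices has shape x.shape, so indexing values with it rebuilds x.
print(result.inverse_indices.shape)
print(jnp.array_equal(result.values[result.inverse_indices], x))

# Assumption: indices point into the flattened input, so x.ravel()[indices]
# recovers values even though x itself is 2D.
print(jnp.array_equal(x.ravel()[result.indices], result.values))
```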
See also
jax.numpy.unique(): general function for computing unique values.
jax.numpy.unique_values(): compute only values.
jax.numpy.unique_counts(): compute only values and counts.
jax.numpy.unique_inverse(): compute only values and inverse_indices.
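To make the relationship between these functions concrete, here is a small consistency check. It is a sketch that assumes integer input (so the equal_nan setting has no effect) and compares jnp.unique and each narrower variant against the corresponding fields of unique_all.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
full = jnp.unique_all(x)

# jnp.unique with all three return flags yields the same four arrays, in order.
values, indices, inverse, counts = jnp.unique(
    x, return_index=True, return_inverse=True, return_counts=True)
same_as_unique = all(
    bool(jnp.array_equal(a, b))
    for a, b in [(values, full.values), (indices, full.indices),
                 (inverse, full.inverse_indices), (counts, full.counts)])

# The narrower variants each return a subset of the same information.
same_values = bool(jnp.array_equal(jnp.unique_values(x), full.values))
same_counts = bool(jnp.array_equal(jnp.unique_counts(x).counts, full.counts))
same_inverse = bool(jnp.array_equal(jnp.unique_inverse(x).inverse_indices,
                                    full.inverse_indices))
print(same_as_unique, same_values, same_counts, same_inverse)
```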
Examples
Here we compute the unique values in a 1D array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_all(x)
The result is a NamedTuple with four named attributes. The values attribute contains the unique values from the array:
>>> result.values
Array([1, 3, 4], dtype=int32)
The indices attribute contains the indices of the unique values within the input array:
>>> result.indices
Array([2, 0, 1], dtype=int32)
>>> jnp.all(result.values == x[result.indices])
Array(True, dtype=bool)
The inverse_indices attribute contains the indices of the input within values:
>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)
The counts attribute contains the counts of each unique value in the input:
>>> result.counts
Array([2, 2, 1], dtype=int32)
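Because each entry of inverse_indices names the slot in values that the corresponding element of x maps to, counts can be cross-checked by counting how often each slot occurs. The sketch below uses jnp.bincount for this; it only illustrates how the attributes relate and is not a claim about how unique_all computes counts internally.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
result = jnp.unique_all(x)

# Count occurrences of each slot index 0..n_unique-1 in inverse_indices.
# ravel() keeps this valid for N-dimensional x, where inverse_indices has x.shape.
recounted = jnp.bincount(result.inverse_indices.ravel(),
                         length=result.values.size)
print(jnp.array_equal(recounted, result.counts))  # True
```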
For examples of the size and fill_value arguments, see jax.numpy.unique().
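As a minimal sketch of why size matters, the function below wraps unique_all in jax.jit(). The helper name padded_unique_all and the choice fill_value=-1 are hypothetical, picked only for illustration; size is marked static so that the output shapes are fixed at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp

# `size` must be static: it determines the output shapes when tracing.
@partial(jax.jit, static_argnames=["size"])
def padded_unique_all(x, size):
    # Hypothetical helper: fill_value=-1 marks padding slots in `values`.
    return jnp.unique_all(x, size=size, fill_value=-1)

x = jnp.array([3, 4, 1, 3, 1])
result = padded_unique_all(x, size=5)

# Only three unique values exist, so the last two slots of `values` hold fill_value.
print(result.values)
# Only the first three slots correspond to real unique values.
print(result.counts[:3])
```

Because size is a static argument, calling padded_unique_all with a different size triggers a fresh trace, while different input arrays of the same shape reuse the compiled function.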