jax.numpy.unique
jax.numpy.unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None)
Return the unique values from an array.
JAX implementation of numpy.unique().
Because the size of the output of unique is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique to be used in such contexts.
- Parameters:
ar (ArrayLike): N-dimensional array from which unique values will be extracted.
return_index (bool): if True, also return the indices in ar where each unique value first occurs.
return_inverse (bool): if True, also return the indices that can be used to reconstruct ar from the unique values.
return_counts (bool): if True, also return the number of occurrences of each unique value.
axis (int | None): if specified, compute unique values along the specified axis. If None (default), then flatten ar before computing the unique values.
equal_nan (bool): if True, consider NaN values equivalent when determining uniqueness.
size (int | None): if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
fill_value (ArrayLike | None): when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
- Returns:
An array or tuple of arrays, depending on the values of return_index, return_inverse, and return_counts. Returned values are:
unique_values: if axis is None, a 1D array of length n_unique. If axis is specified, shape is (*ar.shape[:axis], n_unique, *ar.shape[axis + 1:]).
unique_index: (returned only if return_index is True) An array of shape (n_unique,). Contains the indices of the first occurrence of each unique value in ar. For 1D inputs, ar[unique_index] is equivalent to unique_values.
unique_inverse: (returned only if return_inverse is True) An array of shape (ar.size,) if axis is None, or of shape (ar.shape[axis],) if axis is specified. Contains the indices within unique_values of each value in ar. For 1D inputs, unique_values[unique_inverse] is equivalent to ar.
unique_counts: (returned only if return_counts is True) An array of shape (n_unique,). Contains the number of occurrences of each unique value in ar.
See also
jax.numpy.unique_counts(): shortcut to unique(arr, return_counts=True).
jax.numpy.unique_inverse(): shortcut to unique(arr, return_inverse=True).
jax.numpy.unique_all(): shortcut to unique with all return values.
jax.numpy.unique_values(): like unique, but no optional return values.
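As a rough illustration of how these shortcuts relate to unique, the sketch below compares some of them against the equivalent unique call. It assumes a JAX version that provides these functions and that their results are named tuples with the fields values, counts, and inverse_indices; those field names are an assumption to check against your installed version.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])

# Reference results from the general-purpose function.
values, inverse, counts = jnp.unique(x, return_inverse=True, return_counts=True)

# unique_values: only the sorted unique values.
only_values = jnp.unique_values(x)

# unique_counts / unique_inverse return named tuples (field names assumed).
uc = jnp.unique_counts(x)
ui = jnp.unique_inverse(x)

print(only_values)
print(uc.counts)
print(ui.inverse_indices)
```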
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> jnp.unique(x)
Array([1, 3, 4], dtype=int32)
JIT compilation & the size argument
If you try this under jit() or another transformation, you will get an error because the output shape is dynamic:
>>> jax.jit(jnp.unique)(x)
Traceback (most recent call last):
   ...
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5].
The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.
The issue is that the output of transformed functions must have static shapes. In order to make this work, you can pass a static size parameter:
>>> jit_unique = jax.jit(jnp.unique, static_argnames=['size'])
>>> jit_unique(x, size=3)
Array([1, 3, 4], dtype=int32)
If your static size is smaller than the true number of unique values, the result will be truncated to the first size sorted unique values:
>>> jit_unique(x, size=2)
Array([1, 3], dtype=int32)
If the static size is larger than the true number of unique values, they will be padded with fill_value, which defaults to the minimum unique value:
>>> jit_unique(x, size=5)
Array([1, 3, 4, 1, 1], dtype=int32)
>>> jit_unique(x, size=5, fill_value=0)
Array([1, 3, 4, 0, 0], dtype=int32)
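One way to use the size argument inside a jitted function is to take it from the input's static shape, since a 1D array can never have more unique values than elements. The following is a minimal sketch under that assumption; count_distinct is a hypothetical helper name, and it relies on padded entries being reported with a count of zero, which is worth verifying on your JAX version.

```python
import jax
import jax.numpy as jnp

@jax.jit
def count_distinct(x):
    # x.shape is static under jit, so x.shape[0] is a valid static size
    # and an upper bound on the number of unique values of a 1D array.
    values, counts = jnp.unique(
        x, return_counts=True, size=x.shape[0], fill_value=0
    )
    # Assumption: padded slots carry a count of zero.
    n_distinct = jnp.sum(counts > 0)
    return values, counts, n_distinct

x = jnp.array([3, 4, 1, 3, 1])
values, counts, n_distinct = count_distinct(x)
print(values, counts, n_distinct)
```

Because the output length always equals the input length, the function compiles once per input shape regardless of how many distinct values the data contains.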
Multi-dimensional unique values
If you pass a multi-dimensional array to unique, it will be flattened by default:
>>> M = jnp.array([[1, 2],
...                [2, 3],
...                [1, 2]])
>>> jnp.unique(M)
Array([1, 2, 3], dtype=int32)
If you pass an axis keyword, you can find unique slices of the array along that axis:
>>> jnp.unique(M, axis=0)
Array([[1, 2],
       [2, 3]], dtype=int32)
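The same idea applies to other axes. As a small sketch, the array A below (an illustrative input, not from the original examples) has a duplicated column, and axis=1 returns the unique columns with shape (*ar.shape[:axis], n_unique, *ar.shape[axis + 1:]):

```python
import jax.numpy as jnp

A = jnp.array([[1, 1, 2],
               [3, 3, 4]])

# Columns are (1, 3), (1, 3), (2, 4); the first two are duplicates.
unique_cols = jnp.unique(A, axis=1)
print(unique_cols)
print(unique_cols.shape)
```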
Returning indices
If you set return_index=True, then unique returns the indices of the first occurrence of each unique value:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, indices = jnp.unique(x, return_index=True)
>>> print(values)
[1 3 4]
>>> print(indices)
[2 0 1]
>>> jnp.all(values == x[indices])
Array(True, dtype=bool)
In multiple dimensions, the unique values can be extracted with jax.numpy.take() evaluated along the specified axis:
>>> values, indices = jnp.unique(M, axis=0, return_index=True)
>>> jnp.all(values == jnp.take(M, indices, axis=0))
Array(True, dtype=bool)
Returning inverse
If you set return_inverse=True, then unique returns the indices within the unique values for every entry in the input array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, inverse = jnp.unique(x, return_inverse=True)
>>> print(values)
[1 3 4]
>>> print(inverse)
[1 2 0 1 0]
>>> jnp.all(values[inverse] == x)
Array(True, dtype=bool)
In multiple dimensions, the input can be reconstructed using jax.numpy.take():
>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
>>> jnp.all(jnp.take(values, inverse, axis=0) == M)
Array(True, dtype=bool)
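The inverse indices can also be used to carry per-value information back to every input element. As an illustrative sketch (the variable name freq is hypothetical), combining return_inverse with the return_counts parameter gives the frequency of each element's value:

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])
values, inverse, counts = jnp.unique(x, return_inverse=True, return_counts=True)

# Look up, for every element of x, how often its value occurs in x.
freq = counts[inverse]
print(freq)
```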
Returning counts
If you set return_counts=True, then unique returns the number of occurrences within the input for every unique value:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, counts = jnp.unique(x, return_counts=True)
>>> print(values)
[1 3 4]
>>> print(counts)
[2 2 1]
For multi-dimensional arrays, this also returns a 1D array of counts, indicating the number of occurrences of each unique slice along the specified axis:
>>> values, counts = jnp.unique(M, axis=0, return_counts=True)
>>> print(values)
[[1 2]
 [2 3]]
>>> print(counts)
[2 1]
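The equal_nan parameter listed under Parameters controls whether NaN values are merged. A minimal sketch, using an illustrative input y that is not part of the original examples:

```python
import jax.numpy as jnp

y = jnp.array([1.0, jnp.nan, 2.0, jnp.nan, 1.0])

# Default (equal_nan=True): all NaN values count as a single unique value.
merged = jnp.unique(y)

# equal_nan=False: each NaN is treated as distinct from every other NaN.
separate = jnp.unique(y, equal_nan=False)

print(merged)
print(separate)
```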