jax.numpy.unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None)

Find the unique elements of an array.

LAX-backend implementation of unique(). Original docstring below.

Returns the sorted unique elements of an array. There are three optional outputs in addition to the unique elements:

  • the indices of the input array that give the unique values

  • the indices of the unique array that reconstruct the input array

  • the number of times each unique value comes up in the input array

Parameters

  • ar (array_like) – Input array. Unless axis is specified, this will be flattened if it is not already 1-D.

  • return_index (bool, optional) – If True, also return the indices of ar (along the specified axis, if provided, or in the flattened array) that result in the unique array.

  • return_inverse (bool, optional) – If True, also return the indices of the unique array (for the specified axis, if provided) that can be used to reconstruct ar.

  • return_counts (bool, optional) – If True, also return the number of times each unique item appears in ar.

  • axis (int or None, optional) – The axis to operate on. If None, ar will be flattened. If an integer, the subarrays indexed by the given axis will be flattened and treated as the elements of a 1-D array with the dimension of the given axis, see the notes for more details. Object arrays or structured arrays that contain objects are not supported if the axis kwarg is used. The default is None.


Returns

  • unique (ndarray) – The sorted unique values.

  • unique_indices (ndarray, optional) – The indices of the first occurrences of the unique values in the original array. Only provided if return_index is True.

  • unique_inverse (ndarray, optional) – The indices to reconstruct the original array from the unique array. Only provided if return_inverse is True.

  • unique_counts (ndarray, optional) – The number of times each of the unique values comes up in the original array. Only provided if return_counts is True.

    New in NumPy version 1.9.0.

See also


Module with a number of other functions for performing set operations on arrays.


Notes

When an axis is specified, the subarrays indexed by that axis are sorted. This is done by making the specified axis the first dimension of the array (moving it to the front keeps the order of the other axes) and then flattening the subarrays in C order. The flattened subarrays are then viewed as a structured type with each element given a label, so that the result is a 1-D array of structured types that can be treated in the same way as any other 1-D array. As a consequence, the flattened subarrays are sorted in lexicographic order, starting with the first element.
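
The lexicographic ordering described above can be illustrated with a small sketch. It uses NumPy (imported as np, as in the examples of this docstring). The helper name unique_along_axis is hypothetical and is not part of NumPy or JAX, and it swaps in np.lexsort for the structured-type view, so it approximates the described procedure rather than reproducing the library's implementation.

```python
import numpy as np


def unique_along_axis(ar, axis):
    """Hypothetical helper: unique subarrays along `axis`, sorted lexicographically."""
    # Move the chosen axis to the front so each subarray is one "row".
    moved = np.moveaxis(ar, axis, 0)
    # Flatten every subarray in C order.
    flat = moved.reshape(moved.shape[0], -1)
    # np.lexsort treats the LAST key as the primary one, so reverse the
    # columns to make the first element of each row the primary sort key.
    order = np.lexsort(flat.T[::-1])
    sorted_flat = flat[order]
    # Keep a row only if it differs from the row just before it.
    keep = np.ones(len(sorted_flat), dtype=bool)
    keep[1:] = np.any(sorted_flat[1:] != sorted_flat[:-1], axis=1)
    # Restore the subarray shape and move the axis back to its place.
    uniq = sorted_flat[keep].reshape((-1,) + moved.shape[1:])
    return np.moveaxis(uniq, 0, axis)


a = np.array([[2, 3, 4],
              [1, 5, 0],
              [1, 0, 9],
              [2, 3, 4]])

rows = np.unique(a, axis=0)
manual = unique_along_axis(a, axis=0)
print(rows)
# [[1 0 9]
#  [1 5 0]
#  [2 3 4]]
print(np.array_equal(rows, manual))  # True
```

Rows that share a first element, here [1, 0, 9] and [1, 5, 0], are ordered by their second element, which is what "lexicographic order starting with the first element" means in practice.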


Examples

>>> np.unique([1, 1, 2, 2, 3, 3])
array([1, 2, 3])
>>> a = np.array([[1, 1], [2, 3]])
>>> np.unique(a)
array([1, 2, 3])

Return the unique rows of a 2D array:

>>> a = np.array([[1, 0, 0], [1, 0, 0], [2, 3, 4]])
>>> np.unique(a, axis=0)
array([[1, 0, 0], [2, 3, 4]])

Return the indices of the original array that give the unique values:

>>> a = np.array(['a', 'b', 'b', 'c', 'a'])
>>> u, indices = np.unique(a, return_index=True)
>>> u
array(['a', 'b', 'c'], dtype='<U1')
>>> indices
array([0, 1, 3])
>>> a[indices]
array(['a', 'b', 'c'], dtype='<U1')

Reconstruct the input array from the unique values:

>>> a = np.array([1, 2, 6, 4, 2, 3, 2])
>>> u, indices = np.unique(a, return_inverse=True)
>>> u
array([1, 2, 3, 4, 6])
>>> indices
array([0, 1, 4, ..., 1, 2, 1])
>>> u[indices]
array([1, 2, 6, ..., 2, 3, 2])
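
None of the examples above uses return_counts, so here is a minimal sketch of it in the same NumPy style. The variable names are illustrative only. It assumes jax.numpy.unique returns the same values for the same arguments, which this docstring implies but the snippet does not check.

```python
import numpy as np

a = np.array([1, 2, 6, 4, 2, 3, 2])

# Ask for the number of occurrences of each unique value.
values, counts = np.unique(a, return_counts=True)
print(values)  # [1 2 3 4 6]
print(counts)  # [1 3 1 1 1]

# The most frequent value is the one with the largest count.
most_common = values[np.argmax(counts)]
print(most_common)  # 2

# Repeating each unique value by its count rebuilds the sorted input.
rebuilt = np.repeat(values, counts)
print(np.array_equal(rebuilt, np.sort(a)))  # True
```

When several of the return_* flags are set at once, the extra arrays follow the unique values in the order listed under the return values: indices, then inverse, then counts.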