# jax.numpy.intersect1d

jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False)

Find the intersection of two arrays.

LAX-backend implementation of numpy.intersect1d(). Original docstring below.

Return the sorted, unique values that are in both of the input arrays.
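The examples further down come from the original NumPy docstring and use `np`. As a minimal sketch of the same call through the JAX namespace, assuming `jax` is installed and imported under the conventional alias `jnp` (an alias chosen here, not one the page prescribes):

```python
import jax.numpy as jnp

# Inputs are built as JAX arrays explicitly rather than passed as Python lists.
ar1 = jnp.array([1, 3, 4, 3])
ar2 = jnp.array([3, 1, 2, 1])

# Sorted, unique values present in both inputs.
common = jnp.intersect1d(ar1, ar2)
print(common)  # [1 3]
```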

Parameters

• ar1, ar2 (array_like) – Input arrays. Will be flattened if not already 1D.

• assume_unique (bool) – If True, the input arrays are both assumed to be unique, which can speed up the calculation. Default is False.

• return_indices (bool) – If True, the indices which correspond to the intersection of the two arrays are returned. The first instance of a value is used if there are multiple. Default is False.
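The flattening and `assume_unique` behavior described above can be sketched as follows. This is an illustration using NumPy's `np.intersect1d`, matching the `np` convention of the examples on this page; the array values are made up for the demonstration.

```python
import numpy as np

# 2-D inputs are flattened before the intersection is computed.
a = np.array([[5, 1], [3, 1]])
b = np.array([[3, 7], [5, 5]])
common = np.intersect1d(a, b)
print(common)  # [3 5]

# When both inputs are already duplicate-free, assume_unique=True
# may be passed; the result is the same as with the default.
u1 = np.array([4, 9, 2])
u2 = np.array([2, 4, 8])
common_unique = np.intersect1d(u1, u2, assume_unique=True)
print(common_unique)  # [2 4]
```

Passing `assume_unique=True` for inputs that actually contain duplicates is not covered by this sketch; the parameter description only promises correct results when both arrays are unique.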

Returns

• intersect1d (ndarray) – Sorted 1D array of common and unique elements.

• comm1 (ndarray) – The indices of the first occurrences of the common values in ar1. Only provided if return_indices is True.

• comm2 (ndarray) – The indices of the first occurrences of the common values in ar2. Only provided if return_indices is True.

See also

• numpy.lib.arraysetops – Module with a number of other functions for performing set operations on arrays.

Examples

>>> np.intersect1d([1, 3, 4, 3], [3, 1, 2, 1])
array([1, 3])


To intersect more than two arrays, use functools.reduce:

>>> from functools import reduce
>>> reduce(np.intersect1d, ([1, 3, 4, 3], [3, 1, 2, 1], [6, 3, 4, 2]))
array([3])


To return the indices of the values common to the input arrays along with the intersected values:

>>> x = np.array([1, 1, 2, 3, 4])
>>> y = np.array([2, 1, 4, 6])
>>> xy, x_ind, y_ind = np.intersect1d(x, y, return_indices=True)
>>> x_ind, y_ind
(array([0, 2, 4]), array([1, 0, 2]))
>>> xy, x[x_ind], y[y_ind]
(array([1, 2, 4]), array([1, 2, 4]), array([1, 2, 4]))