jax.numpy.intersect1d
- jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False)
Find the intersection of two arrays.
LAX-backend implementation of numpy.intersect1d(). Original docstring below.
Return the sorted, unique values that are in both of the input arrays.
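A minimal usage sketch follows. The input values are illustrative and were chosen to mirror the NumPy documentation example. The inputs are wrapped in `jnp.array` because many `jax.numpy` functions reject plain Python lists. Duplicates in either input appear only once in the result, and the result is sorted.

```python
import jax.numpy as jnp

# Illustrative inputs; each contains duplicates.
a = jnp.array([1, 3, 4, 3])
b = jnp.array([3, 1, 2, 1])

# Sorted, de-duplicated values present in both a and b.
common = jnp.intersect1d(a, b)
print(common)  # [1 3]
```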
- Parameters:
ar1 (array_like) – Input arrays. Will be flattened if not already 1D.
ar2 (array_like) – Input arrays. Will be flattened if not already 1D.
assume_unique (bool) – If True, the input arrays are both assumed to be unique, which can speed up the calculation. If True but ar1 or ar2 are not unique, incorrect results and out-of-bounds indices could result. Default is False.
return_indices (bool) – If True, the indices which correspond to the intersection of the two arrays are returned. The first instance of a value is used if there are multiple. Default is False.
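The sketch below illustrates two of the parameters described above, using illustrative inputs. A 2D `ar1` is flattened before the intersection is computed. Passing `assume_unique=True` is only safe here because neither input contains repeated values; with duplicated inputs the result could be incorrect, as the parameter description warns.

```python
import jax.numpy as jnp

x = jnp.array([[5, 1], [3, 7]])  # 2D input; treated as [5, 1, 3, 7]
y = jnp.array([7, 2, 5])

# Default path: inputs are flattened and de-duplicated before intersecting.
r1 = jnp.intersect1d(x, y)
print(r1)  # [5 7]

# Neither x nor y has repeated values, so skipping de-duplication is valid here.
r2 = jnp.intersect1d(x, y, assume_unique=True)
print(r2)  # [5 7]
```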
- Returns:
intersect1d (ndarray) – Sorted 1D array of common and unique elements.
comm1 (ndarray) – The indices of the first occurrences of the common values in ar1. Only provided if return_indices is True.
comm2 (ndarray) – The indices of the first occurrences of the common values in ar2. Only provided if return_indices is True.
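With `return_indices=True` the function returns three arrays, as listed above. The following sketch uses illustrative inputs (mirroring the NumPy documentation example) to show that indexing each input with its returned indices reproduces the common values, and that the first occurrence is reported when a value repeats (here, `1` appears at positions 0 and 1 of `a`).

```python
import jax.numpy as jnp

a = jnp.array([1, 1, 2, 3, 4])
b = jnp.array([2, 1, 4, 6])

common, idx_a, idx_b = jnp.intersect1d(a, b, return_indices=True)
print(common)  # [1 2 4]
print(idx_a)   # [0 2 4]  first occurrence of each common value in a
print(idx_b)   # [1 0 2]  first occurrence of each common value in b

# Gathering with the returned indices recovers the intersection.
print(bool((a[idx_a] == common).all()), bool((b[idx_b] == common).all()))  # True True
```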