jax.numpy.intersect1d

jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)

Compute the set intersection of two 1D arrays.

JAX implementation of numpy.intersect1d().

Because the size of the output of intersect1d is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.intersect1d to be used in such contexts.

Parameters:
  • ar1 (ArrayLike) – first array of values to intersect.

  • ar2 (ArrayLike) – second array of values to intersect.

  • assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. default: False.

  • return_indices (bool) – if True, return arrays of indices specifying where the intersected values first appear in the input arrays. default: False.

  • size (int | None) – if specified, return only the first size sorted elements of the intersection. If the intersection has fewer than size elements, the returned values are padded with fill_value, and the returned indices are padded with an out-of-bound index.

  • fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the smallest value in the intersection.

Returns:

An array intersection, or if return_indices=True, a tuple of arrays (intersection, ar1_indices, ar2_indices). Returned values are:

  • intersection: A 1D array containing each value that appears in both ar1 and ar2.

  • ar1_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar1 of values in intersection. For 1D inputs, intersection is equivalent to ar1[ar1_indices].

  • ar2_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar2 of values in intersection. For 1D inputs, intersection is equivalent to ar2[ar2_indices].
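When an input contains repeated values, the returned indices point to the first appearance of each intersected value. A small check against numpy.intersect1d, which follows the same convention, might look like this (the inputs are arbitrary illustrative values):

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([5, 1, 5, 3, 1])  # 5 first appears at index 0, 1 at index 1
b = jnp.array([1, 5, 7, 5])     # 1 first appears at index 0, 5 at index 1

inter, ia, ib = jnp.intersect1d(a, b, return_indices=True)
np_inter, np_ia, np_ib = np.intersect1d(np.asarray(a), np.asarray(b), return_indices=True)

# Indexing the inputs with the returned indices reproduces the intersection.
print(inter, a[ia], b[ib])
```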

Return type:

Array | tuple[Array, Array, Array]


Examples

>>> import jax.numpy as jnp
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)

Computing intersection with indices:

>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)

ar1_indices gives the indices of the intersected values within ar1:

>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)

ar2_indices gives the indices of the intersected values within ar2:

>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)
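If both inputs are already known to contain no duplicates, assume_unique=True lets the implementation take its more efficient path. A quick sketch, with illustrative inputs, checking that it agrees with the default path on such data:

```python
import jax.numpy as jnp

x = jnp.array([7, 2, 9, 4])  # no repeated values
y = jnp.array([4, 8, 7])     # no repeated values

fast = jnp.intersect1d(x, y, assume_unique=True)
default = jnp.intersect1d(x, y)

# Both paths return the sorted intersection when the inputs are truly unique.
print(fast, default)
```

Passing assume_unique=True for inputs that do contain duplicates gives undefined results, so it is only safe when uniqueness is guaranteed by construction.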