jax.numpy.intersect1d

jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False)

Compute the set intersection of two 1D arrays.

JAX implementation of numpy.intersect1d().
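Since the function mirrors numpy.intersect1d, one way to check its behavior is to compare it with NumPy on the same inputs. The following is a minimal sketch, assuming both numpy and jax are installed; the input values are arbitrary. The JAX inputs are converted with jnp.array first, because many jax.numpy functions reject plain Python lists.

```python
import numpy as np
import jax.numpy as jnp

ar1 = [1, 3, 4, 3]
ar2 = [3, 1, 2, 1]

# NumPy reference result (a NumPy ndarray of sorted, unique common values).
expected = np.intersect1d(ar1, ar2)

# JAX result (a jax.Array); inputs are converted to JAX arrays explicitly.
result = jnp.intersect1d(jnp.array(ar1), jnp.array(ar2))

print(expected)
print(result)

# The values should agree even though the array types differ.
assert np.array_equal(np.asarray(result), expected)
```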

Because the size of the output of intersect1d is data-dependent, the function is not compatible with JIT or other JAX transformations.
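To illustrate the limitation, the sketch below calls the function eagerly, then under jax.jit, and finally shows one possible workaround (a different technique, not part of intersect1d itself): compute a fixed-shape boolean membership mask with jnp.isin inside the jitted function, and apply the data-dependent boolean indexing outside it. The helper name membership_mask is hypothetical, and the error is caught generically because the exact exception class may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

# Eager (non-jitted) call: works because concrete values are available.
print(jnp.intersect1d(ar1, ar2))

# Under jax.jit the inputs are abstract tracers, so the output length cannot
# be determined at trace time and tracing fails.
try:
    jax.jit(jnp.intersect1d)(ar1, ar2)
    jit_failed = False
except Exception as err:  # exact error class may vary between JAX versions
    jit_failed = True
    print(type(err).__name__)

# Hypothetical helper: a fixed-shape alternative that does work under jit.
# The boolean mask marks which elements of `a` also occur in `b` and always
# has the same shape as `a`.
@jax.jit
def membership_mask(a, b):
    return jnp.isin(a, b)

mask = membership_mask(ar1, ar2)

# Boolean indexing has a data-dependent size, so it is applied outside jit.
common = ar1[mask]
print(common)
```

Unlike intersect1d, ar1[mask] keeps the order of ar1 and any duplicates it contains, so pass the result through jnp.unique (outside jit) if a sorted, deduplicated set is required.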

Parameters:
  • ar1 (jax.typing.ArrayLike) – first array of values to intersect.

  • ar2 (jax.typing.ArrayLike) – second array of values to intersect.

  • assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. default: False.

  • return_indices (bool) – if True, return arrays of indices specifying where the intersected values first appear in the input arrays. default: False.
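With assume_unique=True the internal deduplication of each input is skipped, so the flag is only safe when each array is already known to contain no repeats. A minimal sketch, assuming data that may contain duplicates is deduplicated with jnp.unique before the fast path is used. Deduplicating up front does roughly the same work the default path would do, so the saving comes from inputs that are already unique or that are reused across several calls.

```python
import jax.numpy as jnp

raw1 = jnp.array([5, 1, 5, 3])  # contains a duplicate 5
raw2 = jnp.array([3, 3, 5, 7])  # contains a duplicate 3

# Default path: duplicates are removed internally before intersecting.
default = jnp.intersect1d(raw1, raw2)

# To use assume_unique=True safely on data that may contain duplicates,
# deduplicate first; jnp.unique returns the sorted unique values.
u1, u2 = jnp.unique(raw1), jnp.unique(raw2)
fast = jnp.intersect1d(u1, u2, assume_unique=True)

print(default)
print(fast)
```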

Returns:

The array intersection or, if return_indices=True, a tuple of arrays (intersection, ar1_indices, ar2_indices). The returned values are:

  • intersection: A sorted 1D array containing each unique value that appears in both ar1 and ar2.

  • ar1_indices: (returned if return_indices=True) an array of shape intersection.shape containing, for each value in intersection, the index of its first occurrence in the flattened ar1. For 1D inputs, intersection is equivalent to ar1[ar1_indices]; in general it is equivalent to ar1.ravel()[ar1_indices].

  • ar2_indices: (returned if return_indices=True) an array of shape intersection.shape containing, for each value in intersection, the index of its first occurrence in the flattened ar2. For 1D inputs, intersection is equivalent to ar2[ar2_indices]; in general it is equivalent to ar2.ravel()[ar2_indices].

Return type:

Array | tuple[Array, Array, Array]
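Two details of return_indices are easy to miss: when an input contains repeated values, each returned index points at the first occurrence of the common value, and for inputs that are not 1D the indices refer to the flattened array. The sketch below exercises both cases with small, arbitrary inputs; the variable names are illustrative only.

```python
import jax.numpy as jnp

# 1D inputs with repeated values: the indices point at the first occurrence
# of each common value in each input.
ar1 = jnp.array([2, 1, 2, 3])
ar2 = jnp.array([2, 2, 5, 1])
inter, idx1, idx2 = jnp.intersect1d(ar1, ar2, return_indices=True)
print(inter, idx1, idx2)
assert bool(jnp.all(ar1[idx1] == inter))
assert bool(jnp.all(ar2[idx2] == inter))

# Multi-dimensional input: the indices refer to the flattened array, so index
# ar1_2d.ravel() rather than ar1_2d itself.
ar1_2d = jnp.array([[1, 2], [3, 4]])
ar2_1d = jnp.array([4, 1])
inter2, idx1_2d, idx2_1d = jnp.intersect1d(ar1_2d, ar2_1d, return_indices=True)
print(inter2, idx1_2d, idx2_1d)
assert bool(jnp.all(ar1_2d.ravel()[idx1_2d] == inter2))
assert bool(jnp.all(ar2_1d[idx2_1d] == inter2))
```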

Examples

>>> import jax.numpy as jnp
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)

Computing the intersection with indices:

>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)

ar1_indices gives the indices of the intersected values within ar1:

>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)

ar2_indices gives the indices of the intersected values within ar2:

>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)
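Because intersect1d takes exactly two arrays, the intersection of several arrays can be computed by folding it pairwise with functools.reduce. A minimal sketch with arbitrary example inputs:

```python
from functools import reduce

import jax.numpy as jnp

arrays = [
    jnp.array([1, 3, 4, 3]),
    jnp.array([3, 1, 2, 1]),
    jnp.array([6, 3, 4, 2]),
]

# Fold the pairwise intersection across the list:
# intersect1d(intersect1d(arrays[0], arrays[1]), arrays[2]).
common = reduce(jnp.intersect1d, arrays)
print(common)
```

Each intermediate result is already sorted and unique, so the fold could also pass assume_unique=True for every step after the first; the plain form above keeps the sketch simple.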