jax.numpy.intersect1d
- jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)
Compute the set intersection of two 1D arrays.
JAX implementation of numpy.intersect1d().
Because the size of the output of intersect1d is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.intersect1d to be used in such contexts.
- Parameters:
ar1 (ArrayLike) – first array of values to intersect.
ar2 (ArrayLike) – second array of values to intersect.
assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. default: False.
return_indices (bool) – if True, return arrays of indices specifying where the intersected values first appear in the input arrays. default: False.
size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value, and returned indices will be padded with an out-of-bound index.
fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the smallest value in the intersection.
- Returns:
The array intersection, or if return_indices=True, a tuple of arrays (intersection, ar1_indices, ar2_indices). Returned values are:
intersection: a sorted 1D array containing each value that appears in both ar1 and ar2.
ar1_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar1 of values in intersection. For 1D inputs, intersection is equivalent to ar1[ar1_indices].
ar2_indices: (returned if return_indices=True) an array of shape intersection.shape containing the indices in flattened ar2 of values in intersection. For 1D inputs, intersection is equivalent to ar2[ar2_indices].
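The note that indices refer to the flattened inputs can be sketched with a 2D ar1. This is an illustration assuming JAX flattens N-D inputs the same way numpy.intersect1d() does; jnp.unravel_index() is used here only to turn the flat positions back into row/column coordinates and is not part of intersect1d itself.

```python
import jax.numpy as jnp

ar1 = jnp.array([[1, 2],
                 [3, 4]])  # 2D input: indices refer to ar1.ravel() == [1, 2, 3, 4]
ar2 = jnp.array([4, 5, 1])

intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)

# Flat indices select the intersected values from the flattened array.
flat_ok = bool((ar1.ravel()[ar1_indices] == intersection).all())

# unravel_index converts flat positions into (row, column) coordinates of ar1.
rows, cols = jnp.unravel_index(ar1_indices, ar1.shape)
nd_ok = bool((ar1[rows, cols] == intersection).all())

print(intersection.tolist())  # [1, 4]
print(ar1_indices.tolist())   # [0, 3]
print(ar2_indices.tolist())   # [2, 0]
print(flat_ok, nd_ok)         # True True
```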
- Return type:
Array | tuple[Array, Array, Array]
See also
jax.numpy.union1d(): the set union of two 1D arrays.
jax.numpy.setxor1d(): the set XOR of two 1D arrays.
jax.numpy.setdiff1d(): the set difference of two 1D arrays.
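For orientation, this sketch applies intersect1d and the three related functions listed in See also to the same pair of arrays. The expected values in the comments follow from ordinary set semantics on sorted, duplicate-free results.

```python
import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

both = jnp.intersect1d(ar1, ar2)      # values in ar1 and in ar2
either = jnp.union1d(ar1, ar2)        # values in ar1 or in ar2
exactly_one = jnp.setxor1d(ar1, ar2)  # values in exactly one of them
only_first = jnp.setdiff1d(ar1, ar2)  # values in ar1 but not in ar2

print(both.tolist())         # [3, 4]
print(either.tolist())       # [1, 2, 3, 4, 5, 6]
print(exactly_one.tolist())  # [1, 2, 5, 6]
print(only_first.tolist())   # [1, 2]
```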
Examples
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)
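A short sketch, with illustrative inputs, of two behaviors implied by the parameter descriptions: repeated or unsorted values in the inputs still produce a sorted result in which each common value appears once, and assume_unique=True may be passed when both inputs are already known to be duplicate-free.

```python
import jax.numpy as jnp

# Repeated, unsorted inputs: each common value appears once, in sorted order.
dup1 = jnp.array([4, 1, 3, 3, 4])
dup2 = jnp.array([3, 3, 6, 4])
common = jnp.intersect1d(dup1, dup2)
print(common.tolist())  # [3, 4]

# Duplicate-free inputs may use assume_unique=True, which permits the more
# efficient implementation; passing inputs with repeats here would be undefined.
u1 = jnp.array([1, 2, 3, 4])
u2 = jnp.array([3, 4, 5, 6])
fast = jnp.intersect1d(u1, u2, assume_unique=True)
print(fast.tolist())    # [3, 4]
```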
Computing intersection with indices:
>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)
ar1_indices gives the indices of the intersected values within ar1:
>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)
ar2_indices gives the indices of the intersected values within ar2:
>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)