jax.numpy.setxor1d

jax.numpy.setxor1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)

Compute the set-wise xor of elements in two arrays.

JAX implementation of numpy.setxor1d().
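The set-wise xor is the symmetric difference of the two inputs: the values that appear in one array but not the other. As a quick cross-check, this snippet compares jnp.setxor1d against numpy.setxor1d and Python's built-in set symmetric difference (^) on inputs that contain duplicates. It is a sketch that assumes NumPy is installed alongside JAX; the variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# Inputs deliberately contain duplicates and overlapping values.
a = [1, 1, 2, 3, 7]
b = [3, 3, 4, 7, 8]

jax_result = jnp.setxor1d(jnp.array(a), jnp.array(b))
numpy_result = np.setxor1d(np.array(a), np.array(b))

# Python's set symmetric difference expresses the same "set-wise xor".
set_result = sorted(set(a) ^ set(b))

print(jax_result)    # [1 2 4 8]
print(numpy_result)  # [1 2 4 8]
print(set_result)    # [1, 2, 4, 8]
```

Duplicates within one input do not matter: 1 appears twice in a but is still reported once, because it is absent from b.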

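Because the size of the output of setxor1d is data-dependent, the function is not typically compatible with JIT or other JAX transformations. To use it in such contexts, pass the optional size argument as a static value, which fixes the output shape.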

Parameters:
  • ar1 (ArrayLike) – first array of values.

  • ar2 (ArrayLike) – second array of values.

  • assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. default: False.

  • size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value. Must be specified statically (as a Python int) to use setxor1d under JIT.

  • fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the smallest value in the xor result.

Returns:

An array of values that are found in exactly one of the input arrays.

Return type:

Array
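To show how "found in exactly one of the input arrays" can be computed, here is a sketch of the sort-based approach used by NumPy's setxor1d, written with jax.numpy. It is not JAX's own implementation, and setxor1d_sketch is a hypothetical name. After each input is deduplicated, a value appears twice in the sorted concatenation only if it is in both inputs, so keeping the values that differ from both neighbors yields the xor. The deduplication step is what assume_unique=True lets the caller skip. The final boolean-mask indexing has a data-dependent shape, so this sketch is not JIT-compatible.

```python
import jax.numpy as jnp


def setxor1d_sketch(ar1, ar2, assume_unique=False):
    """Sort-based set xor, mirroring the approach of numpy.setxor1d."""
    # Treat inputs as flattened 1-D arrays, as NumPy's setxor1d effectively does.
    a = jnp.ravel(jnp.asarray(ar1))
    b = jnp.ravel(jnp.asarray(ar2))
    if not assume_unique:
        # Deduplicate each input; this is the step assume_unique=True skips.
        a, b = jnp.unique(a), jnp.unique(b)
    aux = jnp.sort(jnp.concatenate([a, b]))
    if aux.size == 0:
        return aux
    # flag[i] is True where a new run of equal values starts (plus both ends).
    flag = jnp.concatenate(
        [jnp.array([True]), aux[1:] != aux[:-1], jnp.array([True])]
    )
    # A value is kept only if it differs from both of its neighbors.
    return aux[flag[1:] & flag[:-1]]


print(setxor1d_sketch([1, 2, 3, 4], [3, 4, 5, 6]))  # [1 2 5 6]
```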

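See also

  • jax.numpy.intersect1d() – the set intersection of two 1D arrays.

  • jax.numpy.union1d() – the set union of two 1D arrays.

  • jax.numpy.setdiff1d() – the set difference of two 1D arrays.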

Examples

>>> import jax.numpy as jnp
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.setxor1d(ar1, ar2)
Array([1, 2, 5, 6], dtype=int32)