jax.numpy.setxor1d

jax.numpy.setxor1d(ar1, ar2, assume_unique=False)

Compute the set-wise XOR (symmetric difference) of elements in two arrays.

JAX implementation of numpy.setxor1d().
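Because the function mirrors numpy.setxor1d(), the two should agree on ordinary inputs. A minimal sketch comparing them, assuming both NumPy and JAX are installed (the input values are illustrative):

```python
import numpy as np
import jax.numpy as jnp

ar1 = [1, 2, 3, 4]
ar2 = [3, 4, 5, 6]

# NumPy reference result (NumPy accepts plain Python lists).
np_result = np.setxor1d(ar1, ar2)

# JAX result; jax.numpy functions generally expect arrays rather than lists,
# so the inputs are converted with jnp.array first.
jnp_result = jnp.setxor1d(jnp.array(ar1), jnp.array(ar2))

# Convert the JAX array back to NumPy to compare element-wise.
print(np_result, np.asarray(jnp_result))
print(np.array_equal(np_result, np.asarray(jnp_result)))
```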

Because the size of the output of setxor1d is data-dependent, the function is not compatible with JIT or other JAX transformations.
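To illustrate the restriction, the sketch below calls the function eagerly (which works) and under jax.jit (which fails while tracing, because the output length cannot be known from abstract inputs). The helper xor_masks is hypothetical and not part of JAX: it shows one JIT-compatible workaround that returns fixed-shape boolean masks instead of a variable-length array, at the cost of not sorting or deduplicating the result.

```python
import jax
import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

# Outside of any transformation the inputs are concrete, so this works.
eager = jnp.setxor1d(ar1, ar2)

# Under jax.jit the inputs are abstract tracers; the output size is
# data-dependent, so tracing raises a concretization error.
try:
    jax.jit(jnp.setxor1d)(ar1, ar2)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# Hypothetical JIT-compatible alternative: the mask shapes match the input
# shapes, so they are known at trace time. Each mask marks the elements of
# one input that do not appear in the other input.
@jax.jit
def xor_masks(a, b):
    return ~jnp.isin(a, b), ~jnp.isin(b, a)

mask1, mask2 = xor_masks(ar1, ar2)

# Applying the masks happens eagerly, outside of jit, where boolean
# indexing with a concrete mask is allowed.
print(eager)
print(jit_failed)
print(ar1[mask1], ar2[mask2])
```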

Parameters:
  • ar1 (jax.typing.ArrayLike) – first input array of values.

  • ar2 (jax.typing.ArrayLike) – second input array of values.

  • assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. default: False.
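Setting assume_unique=True is only safe once each input is known to contain no duplicates. A short sketch, with illustrative inputs, that deduplicates explicitly with jnp.unique before passing the flag and checks that the result matches the default path:

```python
import jax.numpy as jnp

# Inputs that contain duplicate values.
ar1 = jnp.array([1, 1, 2, 3])
ar2 = jnp.array([3, 4, 4])

# Default path: setxor1d removes duplicates from each input itself.
default = jnp.setxor1d(ar1, ar2)

# Deduplicate explicitly; afterwards assume_unique=True honors its contract.
deduped = jnp.setxor1d(jnp.unique(ar1), jnp.unique(ar2), assume_unique=True)

print(default, deduped)
```

Passing the raw ar1 and ar2 with assume_unique=True would violate the contract described above, so its output should not be relied on.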

Returns:

A sorted 1D array of the unique values that are found in exactly one of the input arrays.

Return type:

Array
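The returned values are, by definition, those in the union of the inputs but not in their intersection. A small sketch, using illustrative inputs and the related set routines jnp.union1d, jnp.intersect1d, and jnp.setdiff1d, that checks this identity:

```python
import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

# Symmetric difference built from its definition: union minus intersection.
union = jnp.union1d(ar1, ar2)          # all distinct values, sorted
common = jnp.intersect1d(ar1, ar2)     # values present in both inputs
by_definition = jnp.setdiff1d(union, common)

xor = jnp.setxor1d(ar1, ar2)
print(xor, by_definition)
print(bool(jnp.array_equal(xor, by_definition)))
```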

Examples

>>> import jax.numpy as jnp
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.setxor1d(ar1, ar2)
Array([1, 2, 5, 6], dtype=int32)