jax.numpy.setxor1d
- jax.numpy.setxor1d(ar1, ar2, assume_unique=False)
Compute the set-wise xor of elements in two arrays.
JAX implementation of numpy.setxor1d(). Because the size of the output of setxor1d is data-dependent, the function is not compatible with JIT or other JAX transformations.
- Parameters:
ar1 (jax.typing.ArrayLike) – first array of values.
ar2 (jax.typing.ArrayLike) – second array of values.
assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. default: False.
- Returns:
An array of values that are found in exactly one of the input arrays.
- Return type:
Array
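The parameter and return descriptions above can be checked against plain Python sets. The sketch below uses arbitrary example inputs: it compares the result with Python's set symmetric-difference operator `^`, and shows that `assume_unique=True` gives the same answer when the inputs really are unique (here ensured by calling `jnp.unique` first).

```python
import jax.numpy as jnp

a = [1, 1, 2, 4]
b = [2, 3, 3, 8]

# Default assume_unique=False: duplicates in either input are ignored,
# so each distinct value counts once.
result = jnp.setxor1d(jnp.array(a), jnp.array(b))

# Python's set `^` operator computes the same symmetric difference;
# setxor1d returns it as a sorted 1D array.
expected = sorted(set(a) ^ set(b))
print(result.tolist(), expected)  # [1, 3, 4, 8] [1, 3, 4, 8]

# With inputs that are already unique, assume_unique=True skips the
# deduplication step. Passing duplicates with assume_unique=True is
# undefined behavior, so deduplicate first if in doubt.
a_u = jnp.unique(jnp.array(a))
b_u = jnp.unique(jnp.array(b))
fast = jnp.setxor1d(a_u, b_u, assume_unique=True)
print(fast.tolist())  # [1, 3, 4, 8]
```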
See also
- jax.numpy.intersect1d(): the set intersection of two 1D arrays.
- jax.numpy.union1d(): the set union of two 1D arrays.
- jax.numpy.setdiff1d(): the set difference of two 1D arrays.
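These functions are related: the symmetric difference computed by setxor1d equals the union of the two one-sided set differences, (A \ B) ∪ (B \ A). A small sketch checking that identity on arbitrary example inputs:

```python
import jax.numpy as jnp

a = jnp.array([5, 1, 3, 3])
b = jnp.array([3, 7, 1])

xor = jnp.setxor1d(a, b)
# Values found only in a, united with values found only in b.
alt = jnp.union1d(jnp.setdiff1d(a, b), jnp.setdiff1d(b, a))

print(xor.tolist(), alt.tolist())  # [5, 7] [5, 7]
```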
Examples
>>> import jax.numpy as jnp
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.setxor1d(ar1, ar2)
Array([1, 2, 5, 6], dtype=int32)
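As the description notes, the output length of setxor1d depends on the input values, so it cannot be traced by jax.jit. The sketch below shows that failure and one possible workaround, assuming a fixed-shape result is acceptable inside the jitted code: a hypothetical helper, `xor_masks`, returns boolean masks (computed with `jnp.isin`) that have the same shapes as the inputs, and the variable-length step runs eagerly outside jit. This is a substitute technique, not part of the setxor1d API, and the exact exception raised under jit may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 4])
b = jnp.array([3, 4, 5, 6])

# Under jit the inputs are abstract tracers, so the output length is unknown.
try:
    jax.jit(jnp.setxor1d)(a, b)
    raised = False
except Exception as err:  # exception class may differ between JAX versions
    raised = True
    print("jit failed:", type(err).__name__)

@jax.jit
def xor_masks(x, y):
    """Hypothetical helper: fixed-shape masks instead of a variable-length result.

    mask_x[i] is True when x[i] does not occur in y, and vice versa.
    """
    return ~jnp.isin(x, y), ~jnp.isin(y, x)

mask_a, mask_b = xor_masks(a, b)

# Boolean indexing produces a data-dependent shape, so it runs eagerly,
# outside jit; jnp.unique then sorts and removes repeated values.
values = jnp.unique(jnp.concatenate([a[mask_a], b[mask_b]]))
print(values.tolist())  # [1, 2, 5, 6]
```

Because the masks have static shapes, they can feed into other jitted computations; only the final compaction step needs concrete values.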