jax.numpy.setdiff1d
- jax.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)
Compute the set difference of two 1D arrays.
JAX implementation of numpy.setdiff1d().
Because the size of the output of setdiff1d is data-dependent, the function semantics are not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.setdiff1d to be used in such contexts.
- Parameters:
ar1 (ArrayLike) – first array of elements to be differenced.
ar2 (ArrayLike) – second array of elements to be differenced.
assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if assume_unique is True and the input arrays contain duplicates, the behavior is undefined. Default: False.
size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value.
fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum value.
- Returns:
an array containing the set difference of elements in the input array: i.e. the elements in ar1 that are not contained in ar2.
- Return type:
Array
See also
- jax.numpy.intersect1d(): the set intersection of two 1D arrays.
- jax.numpy.setxor1d(): the set XOR of two 1D arrays.
- jax.numpy.union1d(): the set union of two 1D arrays.
Examples
Computing the set difference of two arrays:
>>> import jax
>>> import jax.numpy as jnp
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.setdiff1d(ar1, ar2)
Array([1, 2], dtype=int32)
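As a further illustration, the sketch below assumes the default assume_unique=False and checks that the result comes back sorted and free of duplicates even when ar1 is neither, in line with numpy.setdiff1d(). The names ar1_dup and ar2_small and their values are illustrative only and do not appear elsewhere on this page.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative inputs: ar1_dup is unsorted and contains duplicates.
ar1_dup = jnp.array([4, 1, 2, 2, 3, 1])
ar2_small = jnp.array([3, 5])

# With the default assume_unique=False, duplicates in the first array are
# removed and the remaining values are returned in sorted order.
diff = jnp.setdiff1d(ar1_dup, ar2_small)
print(diff)

# The same computation through NumPy, for comparison.
expected = np.setdiff1d(np.asarray(ar1_dup), np.asarray(ar2_small))
print(expected)
```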
Because the output shape is dynamic, this will fail under jit() and other transformations:
>>> jax.jit(jnp.setdiff1d)(ar1, ar2)
Traceback (most recent call last):
   ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
To ensure statically-known output shapes, you can pass a static size argument:
>>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size'])
>>> jit_setdiff1d(ar1, ar2, size=2)
Array([1, 2], dtype=int32)
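An alternative to static_argnames is to fix size as an ordinary Python value before tracing, so that jit() only ever sees the two arrays. The following is a minimal sketch of that approach; make_setdiff is a hypothetical helper name introduced here for illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp

def make_setdiff(size, fill_value=0):
    """Hypothetical helper: build a jitted setdiff1d with `size` baked in."""
    @jax.jit
    def _setdiff(ar1, ar2):
        # `size` and `fill_value` are closed-over Python values, so they are
        # static at trace time; only ar1 and ar2 are traced.
        return jnp.setdiff1d(ar1, ar2, size=size, fill_value=fill_value)
    return _setdiff

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

setdiff_2 = make_setdiff(2)
print(setdiff_2(ar1, ar2))
```

Each distinct size yields a separately compiled function, which mirrors the recompilation that occurs when a static argument changes value.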
If size is too small, the difference is truncated:
>>> jit_setdiff1d(ar1, ar2, size=1)
Array([1], dtype=int32)
If size is too large, then the output is padded with fill_value:
>>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0)
Array([1, 2, 0, 0], dtype=int32)
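Padded output usually needs a way to tell real elements from fill entries. One possible pattern, sketched below, is to set size to the static length of ar1 (the difference can never be longer than that) and pad with a sentinel. This sketch assumes the inputs are non-negative integers, so that -1 cannot collide with real data; setdiff_with_count is a hypothetical name used only for this example.

```python
import jax
import jax.numpy as jnp

SENTINEL = -1  # assumption: valid inputs are non-negative, so -1 marks padding

@jax.jit
def setdiff_with_count(ar1, ar2):
    # ar1.shape[0] is a Python int even under jit, so it is a valid static size.
    diff = jnp.setdiff1d(ar1, ar2, size=ar1.shape[0], fill_value=SENTINEL)
    # Count the entries that are real elements rather than padding.
    count = jnp.sum(diff != SENTINEL)
    return diff, count

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])

diff, count = setdiff_with_count(ar1, ar2)
print(diff, count)
```

Downstream code can then mask on diff != SENTINEL, or use count, without leaving the jitted computation.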