# jax.numpy.setdiff1d

jax.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)

Compute the set difference of two 1D arrays.

JAX implementation of `numpy.setdiff1d()`.

Because the size of the output of `setdiff1d` is data-dependent, the function is not typically compatible with `jit()` and other JAX transformations. The JAX version adds the optional `size` argument, which must be specified statically for `jnp.setdiff1d` to be used in such contexts.

Parameters:
• ar1 (ArrayLike) – first array of elements to be differenced.

• ar2 (ArrayLike) – second array of elements to be differenced.

• assume_unique (bool) – if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if `assume_unique` is True and the input arrays contain duplicates, the behavior is undefined. Default: False.

• size (int | None) – if specified, return only the first `size` sorted elements. If there are fewer elements than `size` indicates, the return value will be padded with `fill_value`.

• fill_value (ArrayLike | None) – when `size` is specified and there are fewer than the indicated number of elements, fill the remaining entries with `fill_value`. Defaults to the minimum value.
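A small sketch of how `assume_unique` relates to duplicates, using arbitrary illustrative arrays: by default, repeated values in `ar1` are collapsed and the result comes back sorted. Passing `assume_unique=True` is only appropriate once the inputs are known to be duplicate-free, so in this sketch `ar1` is first deduplicated with `jnp.unique`.

```python
import jax.numpy as jnp

# ar1 contains a duplicate (1) and is not sorted.
ar1 = jnp.array([3, 1, 1, 2])
ar2 = jnp.array([2, 5])

# Default mode: duplicates in ar1 are collapsed and the result is sorted.
diff = jnp.setdiff1d(ar1, ar2)
print(diff)  # [1 3]

# assume_unique=True is only safe when the inputs contain no duplicates,
# so deduplicate ar1 first (jnp.unique also returns its values sorted).
u1 = jnp.unique(ar1)
diff_unique = jnp.setdiff1d(u1, ar2, assume_unique=True)
print(diff_unique)  # [1 3]
```

Calling `assume_unique=True` directly on the original `ar1` would violate the stated precondition, and the parameter description above leaves the result undefined in that case.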

Returns:

an array containing the set difference of elements in the input arrays, i.e. the elements in `ar1` that are not contained in `ar2`.

Return type:

Array
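To make the definition concrete, a minimal sketch (not the library's actual implementation) computes the same result eagerly from `jnp.unique` and `jnp.isin`. The helper name `setdiff1d_reference` is hypothetical. Unlike `jnp.setdiff1d` with a static `size`, it cannot run under `jit()`, because boolean-mask indexing produces a data-dependent shape.

```python
import jax.numpy as jnp

def setdiff1d_reference(ar1, ar2):
    """Hypothetical reference: sorted unique values of ar1 that are absent from ar2."""
    # Flatten and deduplicate ar1; jnp.unique returns its values in sorted order.
    candidates = jnp.unique(jnp.ravel(jnp.asarray(ar1)))
    # Mark the candidates that do NOT appear anywhere in ar2.
    keep = ~jnp.isin(candidates, jnp.ravel(jnp.asarray(ar2)))
    # Boolean-mask indexing yields a data-dependent shape, so this runs eagerly only.
    return candidates[keep]

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])
print(setdiff1d_reference(ar1, ar2))  # [1 2]
print(jnp.setdiff1d(ar1, ar2))        # [1 2]
```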

Examples

Computing the set difference of two arrays:

```
>>> import jax
>>> import jax.numpy as jnp
>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.setdiff1d(ar1, ar2)
Array([1, 2], dtype=int32)
```

Because the output shape is dynamic, this will fail under `jit()` and other transformations:

```
>>> jax.jit(jnp.setdiff1d)(ar1, ar2)
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
```

In order to ensure statically-known output shapes, you can pass a static `size` argument:

```
>>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size'])
>>> jit_setdiff1d(ar1, ar2, size=2)
Array([1, 2], dtype=int32)
```

If `size` is too small, the difference is truncated:

```
>>> jit_setdiff1d(ar1, ar2, size=1)
Array([1], dtype=int32)
```

If `size` is too large, then the output is padded with `fill_value`:

```
>>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0)
Array([1, 2, 0, 0], dtype=int32)
```
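When `jnp.setdiff1d` is called inside a larger jitted function, one possible pattern is to use the input length as an upper bound for `size` (array shapes are static under `jit()`, and the difference can never be longer than `ar1`), together with a sentinel `fill_value` that separates real entries from padding. The function name `count_difference` and the sentinel `-1` are illustrative assumptions; the approach only works if the sentinel can never occur in `ar1`.

```python
import jax
import jax.numpy as jnp

SENTINEL = -1  # assumption: -1 never occurs in ar1

@jax.jit
def count_difference(ar1, ar2):
    # ar1.shape[0] is a plain Python int at trace time, so it is a valid static size.
    diff = jnp.setdiff1d(ar1, ar2, size=ar1.shape[0], fill_value=SENTINEL)
    # Padding entries equal the sentinel; every other entry is a real element.
    n_valid = jnp.sum(diff != SENTINEL)
    return diff, n_valid

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])
diff, n_valid = count_difference(ar1, ar2)
print(diff)     # [ 1  2 -1 -1]
print(n_valid)  # 2
```

Returning the count alongside the padded array lets downstream jitted code mask out the padding without ever needing a data-dependent shape.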