jax.numpy.setdiff1d
- jax.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)
Find the set difference of two arrays.
LAX-backend implementation of numpy.setdiff1d().
Because the size of the output of setdiff1d is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which must be specified statically for jnp.setdiff1d to be used within some of JAX's transformations.
Original docstring below.
Return the unique values in ar1 that are not in ar2.
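The sketch below illustrates the data-dependent output size described above. It assumes a standard JAX installation. The input values and the function name padded_diff are hypothetical and were chosen only for this example. The first call runs eagerly, so the output can take whatever length the data produces. The second call runs under jax.jit, so it fixes the output length with a static size.

```python
import jax
import jax.numpy as jnp

ar1 = jnp.array([5, 1, 3, 3, 7, 1])
ar2 = jnp.array([3, 4])

# Eager call: the output length (3 here) depends on the input values.
diff = jnp.setdiff1d(ar1, ar2)  # values: [1, 5, 7]

# Under jax.jit, output shapes must be known at trace time, so a static
# `size` is passed; slots beyond the real results are filled with `fill_value`.
@jax.jit
def padded_diff(a, b):  # hypothetical helper name
    return jnp.setdiff1d(a, b, size=5, fill_value=-1)

padded = padded_diff(ar1, ar2)  # values: [1, 5, 7, -1, -1]
```

The example passes fill_value explicitly rather than relying on its default. This is useful whenever the default padding value could also be a legitimate element of the result.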
- Parameters
ar1 (array_like) – Input array.
ar2 (array_like) – Input comparison array.
assume_unique (bool) – If True, the input arrays are both assumed to be unique, which can speed up the calculation. Default is False.
size (int, optional) – If specified, the first size elements of the result will be returned. If there are fewer elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
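A short sketch can check how size and fill_value behave in both directions. The arrays below are made up for illustration. fill_value is passed explicitly instead of relying on its default.

```python
import jax.numpy as jnp

ar1 = jnp.array([4, 2, 9, 2, 6])
ar2 = jnp.array([2, 6])
# The full set difference here is [4, 9].

# Fewer results (2) than `size` (4): the tail is padded with fill_value.
padded = jnp.setdiff1d(ar1, ar2, size=4, fill_value=0)  # values: [4, 9, 0, 0]

# More results than `size` (1): only the first `size` sorted values are kept.
truncated = jnp.setdiff1d(ar1, ar2, size=1, fill_value=0)  # values: [4]
```

In both cases the output shape equals size, regardless of how many values actually differ.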
- Returns
setdiff1d – 1D array of values in ar1 that are not in ar2. The result is sorted when assume_unique=False, but otherwise only sorted if the input is sorted.
- Return type
ndarray
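The sorting note under Returns can be seen by comparing the default call with assume_unique=True. The input used here has no duplicates but is not sorted. This is a sketch with illustrative values. With assume_unique=True, the caller is responsible for the inputs actually being unique.

```python
import jax.numpy as jnp

a = jnp.array([8, 3, 5, 1])  # already unique, but not sorted
b = jnp.array([5])

# Default (assume_unique=False): the result comes back sorted.
sorted_diff = jnp.setdiff1d(a, b)  # values: [1, 3, 8]

# assume_unique=True: no de-duplication or sorting, so the input order is kept.
ordered_diff = jnp.setdiff1d(a, b, assume_unique=True)  # values: [8, 3, 1]
```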