jax.numpy.setdiff1d

jax.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)

Find the set difference of two arrays.

LAX-backend implementation of numpy.setdiff1d().

Because the size of the output of setdiff1d is data-dependent, the function is not typically compatible with jax.jit. The JAX version adds the optional size argument, which specifies the length of the output array. It must be specified statically for jnp.setdiff1d to be compiled with non-static operands. If size is given, the first size elements of the set difference are returned; if there are fewer elements than size indicates, the result is padded with fill_value, which defaults to zero.
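A minimal sketch of the static-size pattern under jax.jit. The helper name bounded_setdiff, the input values, and the choice fill_value=-1 are illustrative assumptions, not part of the API. size is declared static via static_argnames so that it is a plain Python int at trace time; each distinct size triggers a separate compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: size must be static because it fixes the output shape.
@partial(jax.jit, static_argnames="size")
def bounded_setdiff(ar1, ar2, size):
    # fill_value=-1 (an assumption for this example) keeps padding
    # distinguishable from real values; the default of 0 could collide.
    return jnp.setdiff1d(ar1, ar2, size=size, fill_value=-1)


a = jnp.array([5, 1, 3, 3, 7])
b = jnp.array([3, 7])

# The set difference is {1, 5}; with size=4 the two missing slots are padded.
print(bounded_setdiff(a, b, size=4).tolist())  # [1, 5, -1, -1]
```

Choosing a fill_value outside the range of the real data makes it easy to mask the padding out afterwards; with the default of zero, a genuine 0 in the difference would be indistinguishable from padding.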

Original docstring below.

Return the unique values in ar1 that are not in ar2.
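A minimal eager (not jitted) illustration with arbitrary input values. The comparison against numpy.setdiff1d assumes, in line with the LAX-backend note above, that the two agree on concrete inputs.

```python
import numpy as np
import jax.numpy as jnp

ar1 = jnp.array([4, 2, 2, 9, 1, 7])
ar2 = jnp.array([2, 7, 8])

# Duplicates in ar1 are removed and the result is sorted.
diff = jnp.setdiff1d(ar1, ar2)
print(diff.tolist())  # [1, 4, 9]

# Outside of jit, the result should match NumPy's implementation.
expected = np.setdiff1d(np.asarray(ar1), np.asarray(ar2))
assert np.array_equal(np.asarray(diff), expected)
```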

Parameters
  • ar1 (array_like) – Input array.

  • ar2 (array_like) – Input comparison array.

  • assume_unique (bool) – If True, the input arrays are both assumed to be unique, which can speed up the calculation. Default is False.

Returns

setdiff1d – 1D array of values in ar1 that are not in ar2. The result is sorted when assume_unique=False, but otherwise only sorted if the input is sorted.

Return type

ndarray
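A short sketch of how assume_unique affects ordering, as described under Returns. The input values are arbitrary; ar1 is deliberately unique but unsorted, which is the only situation in which assume_unique=True is appropriate here.

```python
import jax.numpy as jnp

ar1 = jnp.array([9, 3, 5, 1])  # no duplicates, but not sorted
ar2 = jnp.array([3])

# Default (assume_unique=False): the result is deduplicated and sorted.
sorted_diff = jnp.setdiff1d(ar1, ar2)
print(sorted_diff.tolist())  # [1, 5, 9]

# assume_unique=True: no deduplication step, so ar1's original order is kept.
ordered_diff = jnp.setdiff1d(ar1, ar2, assume_unique=True)
print(ordered_diff.tolist())  # [9, 5, 1]
```

If ar1 did contain duplicates, assume_unique=True would not remove them, so the flag should only be set when uniqueness is actually guaranteed.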