jax.numpy.setdiff1d

jax.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)

Find the set difference of two arrays.

LAX-backend implementation of numpy.setdiff1d().

Because the size of the output of setdiff1d is data-dependent, the function is generally not compatible with jax.jit(). The JAX version adds the optional size argument, which must be specified statically for jnp.setdiff1d to be used within some of JAX's transformations.
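A minimal sketch of calling setdiff1d inside jax.jit, assuming the caller knows an upper bound on the result length. Marking size as a static argument keeps the output shape fixed at trace time. The helper name padded_setdiff and the fill value -1 are illustrative choices, not part of the JAX API.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `size` is static, so the traced function always returns shape (size,).
@partial(jax.jit, static_argnames="size")
def padded_setdiff(a, b, size):  # hypothetical helper for illustration
    # Keep at most `size` results and pad any unused slots with -1.
    return jnp.setdiff1d(a, b, size=size, fill_value=-1)


a = jnp.array([5, 1, 3, 3, 7, 1])
b = jnp.array([3, 4, 5])
out = padded_setdiff(a, b, size=4)
print(out)  # [ 1  7 -1 -1]
```

Choosing -1 as the fill value only works as a sentinel when -1 cannot occur in the real data; pick a value outside the valid range of your inputs.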

Original docstring below.

Return the unique values in ar1 that are not in ar2.
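A small eager-mode example of this behaviour, compared against NumPy's reference implementation. The input values are arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp
import numpy as np

a = jnp.array([5, 1, 3, 3, 7, 1])
b = jnp.array([3, 4, 5])

# Unique values of `a` that do not appear in `b`, sorted ascending.
diff = jnp.setdiff1d(a, b)
print(diff)  # [1 7]

# NumPy's setdiff1d gives the same values for the same inputs.
expected = np.setdiff1d(np.asarray(a), np.asarray(b))
print(expected)  # [1 7]
```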

Parameters:
  • ar1 (array_like) – Input array.

  • ar2 (array_like) – Input comparison array.

  • assume_unique (bool) – If True, the input arrays are both assumed to be unique, which can speed up the calculation. Default is False.

  • size (int, optional) – If specified, the first size elements of the result will be returned. If there are fewer elements than size indicates, the return value will be padded with fill_value.

  • fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
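The following sketch illustrates how size truncates or pads the result. It passes fill_value explicitly (-1, an arbitrary choice) rather than relying on the default, so the padded slots are easy to spot.

```python
import jax.numpy as jnp

a = jnp.array([9, 2, 4, 6, 8])
b = jnp.array([4])
# Without `size`, the full difference would be [2 6 8 9].

# size smaller than the result: only the first `size` sorted values are kept.
truncated = jnp.setdiff1d(a, b, size=2, fill_value=-1)
print(truncated)  # [2 6]

# size larger than the result: trailing slots are filled with fill_value.
padded = jnp.setdiff1d(a, b, size=6, fill_value=-1)
print(padded)  # [ 2  6  8  9 -1 -1]
```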

Returns:

setdiff1d – 1D array of values in ar1 that are not in ar2. The result is sorted when assume_unique=False; otherwise it is sorted only if ar1 is sorted.
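A short example of the ordering rule, assuming the inputs really contain no duplicates (that is the precondition for assume_unique=True). With the default, the result comes back sorted; with assume_unique=True, the surviving elements keep their order from ar1.

```python
import jax.numpy as jnp

a = jnp.array([7, 3, 9, 1])  # unique but unsorted
b = jnp.array([9])

# Default: values are deduplicated and sorted before filtering.
sorted_diff = jnp.setdiff1d(a, b)
print(sorted_diff)  # [1 3 7]

# assume_unique=True: no sort step, so the order of `a` is preserved.
ordered_diff = jnp.setdiff1d(a, b, assume_unique=True)
print(ordered_diff)  # [7 3 1]
```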

Return type:

ndarray