jax.numpy.union1d


jax.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)

Find the union of two arrays.

LAX-backend implementation of numpy.union1d().

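Because the size of the output of union1d depends on the values in its inputs, the function cannot be traced by jit() and other JAX transformations that require static shapes unless the output size is fixed in advance. The JAX version therefore adds an optional size argument, which must be specified statically (as a concrete Python integer, not a traced value) for jnp.union1d to be used within those transformations.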
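As a minimal sketch of the static size requirement, the example below wraps jnp.union1d in a jitted helper. The helper name padded_union and the input values are hypothetical and chosen only for illustration; marking size with static_argnames is one way, assumed here, to make it a compile-time constant.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `padded_union` is a hypothetical helper, not part of JAX.
# `size` is declared static so the output shape is known when tracing.
@partial(jax.jit, static_argnames=("size",))
def padded_union(a, b, size):
    return jnp.union1d(a, b, size=size)


a = jnp.array([1, 2, 3, 4])
b = jnp.array([3, 4, 5, 6])

# The union of these inputs has exactly 6 unique values, so size=6 fits it.
result = padded_union(a, b, size=6)
print(result.shape)
print(result)
```

Calling the helper with a different size triggers a recompilation, since each static value produces a separately traced function.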

Original docstring below.

Return the unique, sorted array of values that are in either of the two input arrays.

Parameters:
  • ar1 (array_like) – First input array. It is flattened if it is not already 1D.

  • ar2 (array_like) – Second input array. It is flattened if it is not already 1D.

  • size (int, optional) – If specified, the first size elements of the result will be returned. If there are fewer elements than size indicates, the return value will be padded with fill_value.

  • fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to the minimum value of the union.
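The effect of size and fill_value described above can be sketched as follows. The input arrays are illustrative values chosen so that the full union has six elements; the variable names are not part of the API.

```python
import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([3, 4, 5, 6])  # the full union is [1, 2, 3, 4, 5, 6]

# size smaller than the union: only the first `size` elements are kept.
truncated = jnp.union1d(ar1, ar2, size=4)

# size larger than the union: padded with the default fill value,
# which is the minimum value of the union (1 for these inputs).
padded_default = jnp.union1d(ar1, ar2, size=8)

# size larger than the union with an explicit fill value.
padded_zero = jnp.union1d(ar1, ar2, size=8, fill_value=0)

print(truncated)       # [1 2 3 4]
print(padded_default)  # [1 2 3 4 5 6 1 1]
print(padded_zero)     # [1 2 3 4 5 6 0 0]
```

Because the default padding repeats a value that already appears in the union, passing an explicit fill_value that cannot occur in the data may make padded entries easier to tell apart.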

Returns:

union1d – Unique, sorted union of the input arrays.

Return type:

ndarray
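A short usage sketch, with illustrative inputs, showing that the result is a sorted 1D array of unique values even when an input is multi-dimensional or contains duplicates:

```python
import jax.numpy as jnp

ar1 = jnp.array([3, 1, 2, 1])      # unsorted, with a repeated value
ar2 = jnp.array([[5, 3], [6, 4]])  # 2D input, flattened before the union

union = jnp.union1d(ar1, ar2)
print(union.ndim)
print(union)  # sorted unique values 1 through 6
```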