jax.numpy.union1d

jax.numpy.union1d(ar1, ar2, *, size=None)

Find the union of two arrays.

LAX-backend implementation of numpy.union1d().

Because the length of the output of union1d depends on the input values, the function cannot be traced by jax.jit or other JAX transformations by default. The JAX version adds an optional keyword-only size argument that fixes the length of the output array; size must be a static (compile-time) value for jnp.union1d to be traced. If size is given, the first size unique elements of the sorted union (that is, the smallest ones) are returned; if the union has fewer than size unique elements, the result is padded with the minimum value of the union.
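A minimal sketch of the size argument under jax.jit, assuming a recent JAX release in which jax.jit accepts static_argnames. The wrapper name padded_union is hypothetical; size is marked static because it determines the output shape.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical wrapper: size fixes the output shape, so it must be static.
@partial(jax.jit, static_argnames="size")
def padded_union(a, b, size):
    return jnp.union1d(a, b, size=size)


a = jnp.array([3, 1, 3])
b = jnp.array([2, 3])

# The union {1, 2, 3} has 3 unique values; size=5 pads with the minimum (1).
print(padded_union(a, b, size=5))  # [1 2 3 1 1]

# size=2 keeps only the first two values of the sorted union.
print(padded_union(a, b, size=2))  # [1 2]
```

Each distinct size value leads to a separate compilation, since it changes the output shape. Calling jnp.union1d on traced arrays without size fails at trace time.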

Original docstring below.

Return the unique, sorted array of values that are in either of the two input arrays.
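A short illustration of these semantics (the array values are arbitrary). The comparison with jnp.unique over the concatenated inputs only checks that the results agree; it is not a claim about how union1d is implemented internally.

```python
import jax.numpy as jnp

x = jnp.array([4, 1, 4, 2])
y = jnp.array([2, 7])

# Values present in either input, duplicates removed, sorted ascending.
u = jnp.union1d(x, y)
print(u)  # [1 2 4 7]

# Same values as taking the unique elements of both inputs combined.
same = jnp.unique(jnp.concatenate([x, y]))
print(bool(jnp.array_equal(u, same)))  # True
```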

Parameters
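  • ar1 (array_like) – First input array. It is flattened if it is not already 1D.

  • ar2 (array_like) – Second input array. It is flattened if it is not already 1D.

  • size (int, optional) – JAX-specific, keyword-only. Static length of the output array, as described above.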
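Inputs that are not 1D are flattened before the union is taken, so the result is always 1D. A small example with a 2D ar1 (values chosen arbitrarily):

```python
import jax.numpy as jnp

m = jnp.array([[5, 1],
               [3, 5]])  # 2D input; flattened to [5, 1, 3, 5]
v = jnp.array([0, 3])

r = jnp.union1d(m, v)
print(r)        # [0 1 3 5]
print(r.shape)  # (4,)
```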

Returns

union1d – Unique, sorted union of the input arrays.

Return type

ndarray