jax.numpy.union1d

jax.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)

Find the union of two arrays.

LAX-backend implementation of numpy.union1d().

Because the size of the output of union1d depends on the values of its inputs, the function is not typically compatible with jax.jit. The JAX version adds the optional size and fill_value arguments. size specifies the length of the output array and must be specified statically for jnp.union1d to be compiled with non-static operands. If size is specified, only the first size unique elements (in ascending sorted order) are returned; if there are fewer unique elements than size, the output is padded with fill_value, which defaults to the minimum value of the union.
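The sketch below illustrates the size and fill_value behavior described above. It assumes a recent JAX release; the helpers padded_union and padded_union_fill are hypothetical names used only for illustration. The commented values follow from the padding rule (pad with the minimum of the union unless fill_value is given).

```python
import jax
import jax.numpy as jnp

ar1 = jnp.array([3, 1, 2, 3])
ar2 = jnp.array([5, 1])

# Eager call with concrete inputs: no size is needed.
eager = jnp.union1d(ar1, ar2)
print(eager.tolist())  # [1, 2, 3, 5]

# Hypothetical helper: size is a Python int, so it is static under jit.
@jax.jit
def padded_union(a, b):
    # 4 unique values, padded to length 6 with the union's minimum (1).
    return jnp.union1d(a, b, size=6)

# Hypothetical helper with an explicit padding value.
@jax.jit
def padded_union_fill(a, b):
    return jnp.union1d(a, b, size=6, fill_value=-1)

padded = padded_union(ar1, ar2)
print(padded.tolist())  # [1, 2, 3, 5, 1, 1]

filled = padded_union_fill(ar1, ar2)
print(filled.tolist())  # [1, 2, 3, 5, -1, -1]

# If size is smaller than the number of unique values, only the
# first (smallest) size values are kept.
truncated = jnp.union1d(ar1, ar2, size=2)
print(truncated.tolist())  # [1, 2]

# Without size, tracing under jit fails because the output shape
# would depend on the input values.
try:
    jax.jit(jnp.union1d)(ar1, ar2)
except jax.errors.ConcretizationTypeError:
    print("size is required under jit")
```

Passing size as a Python constant inside the jitted function keeps it static, which is what lets the output shape be known at trace time.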

Original NumPy docstring below.

Return the unique, sorted array of values that are in either of the two input arrays.

Parameters
  • ar1 (array_like) – First input array. It is flattened if it is not already 1-D.

  • ar2 (array_like) – Second input array. It is flattened if it is not already 1-D.
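As a quick check of the flattening rule, the sketch below passes a 2-D array and compares the result with jnp.unique applied to the flattened, concatenated inputs. The comparison only illustrates the documented behavior; it is not a statement about how jnp.union1d is implemented internally.

```python
import jax.numpy as jnp

# A 2-D and a 1-D input; both are flattened before the union is taken.
ar1 = jnp.array([[4, 2], [2, 7]])
ar2 = jnp.array([7, 0, 4])

result = jnp.union1d(ar1, ar2)
print(result.tolist())  # [0, 2, 4, 7]
print(result.ndim)      # 1

# Same values as the unique elements of the flattened concatenation.
reference = jnp.unique(jnp.concatenate([ar1.ravel(), ar2.ravel()]))
print(bool(jnp.array_equal(result, reference)))  # True
```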

Returns

union1d – Unique, sorted union of the input arrays.

Return type

ndarray