jax.numpy.union1d
- jax.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)
Find the union of two arrays.
LAX-backend implementation of numpy.union1d(). Because the size of the output of union1d is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which must be specified statically for jnp.union1d to be used within transformations such as jax.jit. Original docstring below.
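A minimal sketch of the static-size pattern, assuming a JAX version whose jnp.union1d accepts size and fill_value as documented on this page. Marking size as static via static_argnames fixes the output shape at trace time, so jax.jit can compile the function. The helper name padded_union is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `size` is static, so the output shape is known when tracing.
@partial(jax.jit, static_argnames="size")
def padded_union(a, b, size):
    # Without `size`, the output length would depend on the input values,
    # which jax.jit cannot trace.
    return jnp.union1d(a, b, size=size, fill_value=0)


a = jnp.array([3, 1, 2])
b = jnp.array([2, 5])
result = padded_union(a, b, size=6)
print(result)  # union 1, 2, 3, 5 padded with fill_value=0 to length 6
```

Each distinct size value triggers a separate compilation, so it is usually best to choose one upper bound and reuse it.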
Return the unique, sorted array of values that are in either of the two input arrays.
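For concrete (non-traced) inputs, the result can be compared with NumPy's reference implementation. A short sketch, assuming both numpy and jax are installed:

```python
import numpy as np
import jax.numpy as jnp

a = [-1, 0, 1]
b = [-2, 0, 2]

# Outside of jit, no size is needed: the output length is computed eagerly.
jax_result = jnp.union1d(jnp.array(a), jnp.array(b))
np_result = np.union1d(a, b)

print(jax_result)  # unique, sorted values that appear in either input
print(np_result)
```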
- Parameters:
ar1 (array_like) – Input array. It is flattened if it is not already 1D.
ar2 (array_like) – Input array. It is flattened if it is not already 1D.
size (int, optional) – If specified, only the first size elements of the result are returned. If the union has fewer than size elements, the result is padded with fill_value.
fill_value (array_like, optional) – When size is specified and the union has fewer than size elements, the remaining elements are filled with fill_value, which defaults to the minimum value of the union.
- Returns:
union1d – Unique, sorted union of the input arrays.
- Return type:
ndarray
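The parameter semantics above can be exercised directly in eager mode. This is an illustrative sketch; the variable names are arbitrary, and the value used for padding when fill_value is omitted is left unchecked here because it is whatever default the installed JAX version applies (documented above as the union's minimum).

```python
import jax.numpy as jnp

# 2-D inputs are flattened before the union is taken.
x = jnp.array([[4, 1], [1, 3]])
y = jnp.array([3, 7])
full = jnp.union1d(x, y)  # 1-D result with values 1, 3, 4, 7

# size smaller than the number of unique values: only the first `size` are kept.
truncated = jnp.union1d(x, y, size=2)  # values 1, 3

# size larger than the number of unique values: extra slots take fill_value.
padded = jnp.union1d(x, y, size=6, fill_value=-1)  # values 1, 3, 4, 7, -1, -1

# fill_value omitted: padding uses the default fill value.
padded_default = jnp.union1d(x, y, size=6)

print(full, truncated, padded, padded_default)
```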