jax.numpy.union1d

Contents

jax.numpy.union1d#

jax.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)[source]#

Compute the set union of two 1D arrays.

JAX implementation of numpy.union1d().

Because the size of the output of union1d is data-dependent, the function semantics are not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument which must be specified statically for jnp.union1d to be used in such contexts. transformations.

Parameters:
  • ar1 (ArrayLike) – first array of elements to be unioned.

  • ar2 (ArrayLike) – second array of elements to be unioned

  • size (int | None) – if specified, return only the first size sorted elements. If there are fewer elements than size indicates, the return value will be padded with fill_value.

  • fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries fill_value. Defaults to the minimum value.

Returns:

an array containing the union of elements in the input array.

Return type:

Array

See also

Examples

Computing the union of two arrays:

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.union1d(ar1, ar2)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

Because the output shape is dynamic, this will fail under jit() and other transformations:

>>> jax.jit(jnp.union1d)(ar1, ar2)  
Traceback (most recent call last):
   ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.

In order to ensure statically-known output shapes, you can pass a static size argument:

>>> jit_union1d = jax.jit(jnp.union1d, static_argnames=['size'])
>>> jit_union1d(ar1, ar2, size=6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

If size is too small, the union is truncated:

>>> jit_union1d(ar1, ar2, size=4)
Array([1, 2, 3, 4], dtype=int32)

If size is too large, then the output is padded with fill_value:

>>> jit_union1d(ar1, ar2, size=8, fill_value=0)
Array([1, 2, 3, 4, 5, 6, 0, 0], dtype=int32)