jax.numpy.isin

jax.numpy.isin(element, test_elements, assume_unique=False, invert=False)

Determine whether each value in element appears in test_elements.

JAX implementation of numpy.isin().
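Because jnp.isin is the JAX counterpart of numpy.isin, a quick way to check its behavior is to run both on the same inputs. The snippet below is a minimal sketch, assuming NumPy and JAX are both installed; the input values are arbitrary illustrations.

```python
import numpy as np
import jax.numpy as jnp

element = np.array([0, 2, 4, 6, 8])
test_elements = np.array([2, 3, 8])

# NumPy reference result
expected = np.isin(element, test_elements)

# JAX result; converted back to a NumPy array for comparison
result = np.asarray(jnp.isin(element, test_elements))

# Both libraries should agree element-wise
assert np.array_equal(result, expected)
```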

Parameters:
  • element (jax.typing.ArrayLike) – input array of elements for which membership will be checked.

  • test_elements (jax.typing.ArrayLike) – N-dimensional array of test values to check for the presence of each element.

  • assume_unique (bool) – unused by JAX

  • invert (bool) – If True, return ~isin(element, test_elements). Default is False.
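The effect of invert, and the fact that test_elements may have any shape, can be seen in a short sketch like the following (the input values are illustrative assumptions):

```python
import jax.numpy as jnp

element = jnp.array([1, 2, 3, 4])
# test_elements may be multi-dimensional; every value in it is a candidate match
test_elements = jnp.array([[1, 5], [6, 3]])

present = jnp.isin(element, test_elements)
absent = jnp.isin(element, test_elements, invert=True)

# invert=True gives the element-wise logical negation of the default result
assert bool(jnp.all(absent == ~present))
```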

Returns:

A boolean array of shape element.shape that specifies whether each element appears in test_elements.

Return type:

Array
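The output shape follows element, not test_elements. A minimal sketch with a two-dimensional element (values chosen only for illustration):

```python
import jax.numpy as jnp

element = jnp.array([[1, 2, 3],
                     [4, 5, 6]])
test_elements = jnp.array([2, 4, 6, 8])

mask = jnp.isin(element, test_elements)

# The result has element's shape and a boolean dtype
assert mask.shape == element.shape
assert mask.dtype == bool
```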

Examples

>>> import jax.numpy as jnp
>>> elements = jnp.array([1, 2, 3, 4])
>>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])
>>> jnp.isin(elements, test_elements)
Array([ True, False,  True, False], dtype=bool)
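One way to reason about the result above is as a pairwise equality table reduced over the test values. The sketch below uses a hypothetical helper, isin_sketch, that is not JAX's own implementation; it flattens test_elements, compares every value of element against every test value via broadcasting, and reduces with any. It uses O(element.size * test_elements.size) memory, so it is only meant as an illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def isin_sketch(element, test_elements):
    # Hypothetical helper for illustration only, not JAX's implementation.
    element = jnp.asarray(element)
    flat_tests = jnp.ravel(test_elements)
    # Pairwise equality table of shape element.shape + (flat_tests.size,)
    matches = element[..., None] == flat_tests
    # True wherever at least one test value matched
    return jnp.any(matches, axis=-1)

elements = jnp.array([1, 2, 3, 4])
test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])

sketch = isin_sketch(elements, test_elements)
reference = jnp.isin(elements, test_elements)
```

For these inputs the sketch should agree with jnp.isin element by element.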