jax.numpy.isin
- jax.numpy.isin(element, test_elements, assume_unique=False, invert=False)
Determine whether each value in element appears in test_elements.
JAX implementation of numpy.isin().
- Parameters:
element (jax.typing.ArrayLike) – input array of elements for which membership will be checked.
test_elements (jax.typing.ArrayLike) – N-dimensional array of test values to check for the presence of each element.
invert (bool) – If True, return ~isin(element, test_elements), the elementwise negation of the membership mask. Default is False.
assume_unique (bool) – unused by JAX. Default is False.
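The roles of element, test_elements, and invert can be illustrated with a broadcasted equality check. The sketch below is a hypothetical helper (isin_reference), not JAX's own implementation: it flattens test_elements, compares every element against every test value, and negates the result when invert=True. Because it materializes a comparison array of size element.size times test_elements.size, its memory use grows with the product of the two sizes.

```python
import jax.numpy as jnp


def isin_reference(element, test_elements, invert=False):
    # Hypothetical helper for illustration; not JAX's own implementation.
    element = jnp.asarray(element)
    # test_elements may have any shape: flatten it to a 1-D list of test values.
    flat_tests = jnp.ravel(jnp.asarray(test_elements))
    # Compare every element with every test value; shape is element.shape + (n,).
    matches = element[..., None] == flat_tests
    # An element is a member if at least one comparison matched.
    result = jnp.any(matches, axis=-1)
    # invert=True returns the elementwise negation, like ~isin(...).
    return ~result if invert else result


elements = jnp.array([1, 2, 3, 4])
test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])

# Check the sketch against jnp.isin for both values of invert.
for invert in (False, True):
    expected = jnp.isin(elements, test_elements, invert=invert)
    sketch = isin_reference(elements, test_elements, invert=invert)
    print(invert, sketch, bool(jnp.array_equal(sketch, expected)))
```

Flattening test_elements first is what lets it be N-dimensional: only the set of test values matters, not their layout.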
- Returns:
A boolean array of shape element.shape that specifies whether each element appears in test_elements.
- Return type:
Array
Examples
>>> elements = jnp.array([1, 2, 3, 4])
>>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])
>>> jnp.isin(elements, test_elements)
Array([ True, False,  True, False], dtype=bool)
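Because the result always has shape element.shape, the mask can be used inside jax.jit without introducing data-dependent shapes. The sketch below uses a hypothetical helper, zero_out_members, that replaces members of a blocked set with 0 in a 2-D array. It uses jnp.where rather than boolean indexing (x[mask]), since the latter produces an output whose size depends on the data and cannot be traced under jit.

```python
import jax
import jax.numpy as jnp


@jax.jit
def zero_out_members(x, blocked):
    # Hypothetical helper: replace every value of x that appears in blocked with 0.
    mask = jnp.isin(x, blocked)  # boolean mask with shape x.shape
    # Boolean indexing such as x[mask] has a data-dependent shape and cannot be
    # traced under jit, so keep the shape fixed with jnp.where instead.
    return jnp.where(mask, 0, x)


x = jnp.array([[1, 2, 3], [4, 5, 6]])
blocked = jnp.array([1, 5, 6, 3, 7, 1])

out = zero_out_members(x, blocked)
print(out.shape)  # (2, 3): same shape as x
print(out)
```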