jax.numpy.isin
- jax.numpy.isin(element, test_elements, assume_unique=False, invert=False, *, method='auto')
Determine whether elements in element appear in test_elements.

JAX implementation of numpy.isin().

- Parameters:
element (ArrayLike) – input array of elements for which membership will be checked.
test_elements (ArrayLike) – N-dimensional array of test values to check for the presence of each element.
invert (bool) – If True, return ~isin(element, test_elements). Default is False.
assume_unique (bool) – If True, each input array is assumed to contain only unique values, which can lead to more efficient computation. If the inputs are not unique and assume_unique is set to True, the results are undefined. Default is False.
method – string specifying the method used to compute the result. Supported options are 'compare_all', 'binary_search', 'sort', and 'auto' (default).
- Returns:
A boolean array of shape element.shape that specifies whether each element appears in test_elements.
- Return type:
Array
Examples
>>> elements = jnp.array([1, 2, 3, 4])
>>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])
>>> jnp.isin(elements, test_elements)
Array([ True, False,  True, False], dtype=bool)
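To make the membership test concrete, here is a minimal sketch of two strategies that the names 'compare_all' and 'binary_search' suggest. The first broadcasts every element against every test value. The second sorts the test values and looks each element up with jnp.searchsorted. The helpers isin_compare_all and isin_binary_search are hypothetical illustrations written for this page, not JAX's internal implementation. Both assume test_elements is non-empty, and both flatten test_elements, as the example above does with a 2-D test array.

```python
import jax.numpy as jnp

def isin_compare_all(element, test_elements):
    """Hypothetical helper: O(N*M) broadcasted equality test."""
    element = jnp.asarray(element)
    flat_test = jnp.ravel(jnp.asarray(test_elements))
    # Add a trailing axis so each entry is compared with every test value,
    # then reduce over that axis: True if any comparison matched.
    return (element[..., None] == flat_test).any(axis=-1)

def isin_binary_search(element, test_elements):
    """Hypothetical helper: sort the test values, then look each element up."""
    element = jnp.asarray(element)
    sorted_test = jnp.sort(jnp.ravel(jnp.asarray(test_elements)))
    # Position of the first test value >= each element; this can equal
    # len(sorted_test) when the element is larger than every test value.
    idx = jnp.searchsorted(sorted_test, element)
    # Clip so the lookup stays in bounds; a mismatch then means "absent".
    # Assumes sorted_test is non-empty.
    idx = jnp.clip(idx, 0, sorted_test.size - 1)
    return sorted_test[idx] == element

elements = jnp.array([1, 2, 3, 4])
test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])
a = isin_compare_all(elements, test_elements)    # [True, False, True, False]
b = isin_binary_search(elements, test_elements)  # [True, False, True, False]
```

The broadcasted version is simple but materializes an array of size element.size * test_elements.size. The sorted lookup trades that memory for an O(M log M) sort plus an O(N log M) search, which is why a sort-based strategy tends to pay off for large inputs.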