jax.numpy.isin

jax.numpy.isin(element, test_elements, assume_unique=False, invert=False)

Calculates element in test_elements, broadcasting over element only.

LAX-backend implementation of numpy.isin().

In the JAX version, the assume_unique argument is not referenced: it is accepted for compatibility with NumPy but has no effect on the result.
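A minimal sketch of what "not referenced" means in practice, assuming the inputs below (chosen for illustration, with duplicates on both sides): passing assume_unique=True does not change the result, because the flag is simply ignored.

```python
import jax.numpy as jnp

# Both inputs contain duplicates, so they are NOT unique.
x = jnp.array([3, 3, 1, 5])
y = jnp.array([5, 3, 3])

default = jnp.isin(x, y)
# JAX ignores assume_unique, so this call returns the same values.
flagged = jnp.isin(x, y, assume_unique=True)

print(default)  # values: True, True, False, True
```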

Original docstring below.

Returns a boolean array of the same shape as element that is True where an element of element is in test_elements and False otherwise.
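The sketch below illustrates the basic call with illustrative inputs (the arrays are assumptions, not from the original docs). The result takes the shape of element, not of test_elements. Because the output shape depends only on element's shape, the call can also be wrapped in jax.jit.

```python
import jax
import jax.numpy as jnp

element = jnp.array([[0, 2],
                     [4, 6]])
test_elements = jnp.array([1, 2, 4, 8])

# Elementwise membership test; the result has the shape of `element`.
mask = jnp.isin(element, test_elements)
print(mask.shape)  # (2, 2)
print(mask)        # values: [[False, True], [True, False]]

# The output shape is static (it equals element.shape), so jit works.
jitted_isin = jax.jit(lambda e, t: jnp.isin(e, t))
assert bool(jnp.array_equal(jitted_isin(element, test_elements), mask))
```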

Parameters:
  • element (array_like) – Input array.

  • test_elements (array_like) – The values against which to test each value of element. This argument is flattened if it is an array or array_like. See the Notes section of the NumPy documentation for numpy.isin for behavior with non-array-like parameters.

  • assume_unique (bool, optional) – If True, the input arrays are both assumed to be unique, which can speed up the calculation. Default is False.

  • invert (bool, optional) – If True, the values in the returned array are inverted, as if calculating element not in test_elements. Default is False. np.isin(a, b, invert=True) is equivalent to (but faster than) np.invert(np.isin(a, b)).
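A short sketch of two parameter behaviors described above, using illustrative inputs: test_elements is flattened, so reshaping it does not change the result, and invert=True yields the elementwise negation of the default result. It checks only the equivalence; the speed claim in the NumPy docstring is not measured here.

```python
import jax.numpy as jnp

element = jnp.array([0, 2, 4, 6])
test_1d = jnp.array([1, 2, 4, 8])
test_2d = test_1d.reshape(2, 2)  # same values, different shape

# test_elements is flattened, so its shape does not affect the result.
same = jnp.array_equal(jnp.isin(element, test_2d),
                       jnp.isin(element, test_1d))

# invert=True matches jnp.invert applied to the default boolean result.
inverted = jnp.isin(element, test_1d, invert=True)
negated = jnp.invert(jnp.isin(element, test_1d))
print(inverted)  # values: True, False, False, True
```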

Returns:

isin – Has the same shape as element. The values element[isin] are in test_elements.

Return type:

ndarray, bool
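The sketch below shows the returned mask in use, with illustrative inputs. Outside jax.jit, element[isin] selects the matching values, and its length depends on the mask's contents. Inside jax.jit, output shapes must be known at trace time, so a common alternative is jnp.where, which keeps the full shape and substitutes a fill value for non-members. The helper name keep_members and the fill value 0 are hypothetical choices for this example.

```python
import jax
import jax.numpy as jnp

element = jnp.array([[0, 2],
                     [4, 6]])
test_elements = jnp.array([1, 2, 4, 8])
mask = jnp.isin(element, test_elements)

# Eager boolean indexing: returns only the members, in row-major order.
selected = element[mask]  # values: [2, 4]

# Hypothetical jit-friendly helper: keeps element's shape and fills
# non-members with 0 instead of dropping them.
@jax.jit
def keep_members(e, t):
    return jnp.where(jnp.isin(e, t), e, 0)

kept = keep_members(element, test_elements)  # values: [[0, 2], [4, 0]]
```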