jax.numpy.clip

jax.numpy.clip(a, a_min=None, a_max=None, out=None)

Clip (limit) the values in an array.

LAX-backend implementation of numpy.clip().

Original docstring below.

Given an interval, values outside the interval are clipped to the interval edges. For example, if an interval of [0, 1] is specified, values smaller than 0 become 0, and values larger than 1 become 1.
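The following minimal sketch shows the [0, 1] example in practice. The input values are illustrative, and the bounds are passed positionally rather than by keyword.

```python
import jax.numpy as jnp

# Illustrative input with values below, inside, and above the interval [0, 1].
x = jnp.array([-1.5, 0.0, 0.25, 0.75, 1.0, 2.0])

# Values < 0 become 0 and values > 1 become 1; values inside are unchanged.
clipped = jnp.clip(x, 0.0, 1.0)

print(clipped.tolist())  # [0.0, 0.0, 0.25, 0.75, 1.0, 1.0]
```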

Equivalent to, but faster than, np.minimum(a_max, np.maximum(a, a_min)).
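A quick way to check the stated equivalence is to compare both forms on random data. This sketch uses jax.numpy's minimum and maximum in place of NumPy's; the bounds and the random seed are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Arbitrary random input; the seed and size are illustrative only.
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (1000,))
a_min, a_max = -0.5, 0.5

# The two formulations described above.
via_clip = jnp.clip(a, a_min, a_max)
via_min_max = jnp.minimum(a_max, jnp.maximum(a, a_min))

print(bool(jnp.array_equal(via_clip, via_min_max)))  # True
```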

No check is performed to ensure a_min < a_max.
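Because no check is performed, inverted bounds are accepted silently. Assuming jnp.clip follows the minimum/maximum formulation above, every element then ends up equal to a_max, as this sketch illustrates.

```python
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 3.0])

# a_min (1.0) is deliberately larger than a_max (0.0); no error is raised.
result = jnp.clip(x, 1.0, 0.0)

# maximum(x, 1.0) is >= 1.0 everywhere, so the outer minimum with 0.0
# returns 0.0 (the a_max value) for every element.
print(result.tolist())  # [0.0, 0.0, 0.0]
```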

Parameters
  • a (array_like) – Array containing elements to clip.

  • a_min (array_like or None) – Minimum value. If None, clipping is not performed on the lower edge. Only one of a_min and a_max may be None. Broadcast against a.

  • a_max (array_like or None) – Maximum value. If None, clipping is not performed on the upper edge. Only one of a_min and a_max may be None. Broadcast against a.
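The sketch below illustrates two points from the parameter descriptions: array-valued bounds are broadcast against a, and passing None for one bound leaves that edge unclipped. The array shapes and bound values are illustrative assumptions, and the bounds are passed positionally.

```python
import jax.numpy as jnp

# a has shape (3, 4): [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
a = jnp.arange(12.0).reshape(3, 4)

# Per-column bounds of shape (4,) broadcast against a's shape (3, 4).
lo = jnp.array([1.0, 2.0, 3.0, 4.0])
hi = jnp.array([5.0, 6.0, 7.0, 8.0])
per_column = jnp.clip(a, lo, hi)

# Upper edge only: with a_min=None, small values are left untouched.
upper_only = jnp.clip(a, None, 5.0)

print(per_column.tolist())  # [[1.0, 2.0, 3.0, 4.0], [4.0, 5.0, 6.0, 7.0], [5.0, 6.0, 7.0, 8.0]]
print(upper_only.tolist())  # [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 5.0, 5.0], [5.0, 5.0, 5.0, 5.0]]
```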

Returns

clipped_array – An array with the elements of a, but where values < a_min are replaced with a_min, and those > a_max with a_max.

Return type

ndarray