jax.numpy.clip

jax.numpy.clip(arr=None, /, min=None, max=None, *, a=Deprecated, a_min=Deprecated, a_max=Deprecated)

Clip array values to a specified range.

JAX implementation of numpy.clip().

Parameters:
  • arr (ArrayLike | None) – N-dimensional array to be clipped.

  • min (ArrayLike | None) – optional minimum value of the clipped range; if None (default), the result is not clipped to any minimum value. If specified, it should be broadcast-compatible with arr and max.

  • max (ArrayLike | None) – optional maximum value of the clipped range; if None (default), the result is not clipped to any maximum value. If specified, it should be broadcast-compatible with arr and min.

  • a (ArrayLike | DeprecatedArg) – deprecated alias of the arr argument. Will result in a DeprecationWarning if used.

  • a_min (ArrayLike | None | DeprecatedArg) – deprecated alias of the min argument. Will result in a DeprecationWarning if used.

  • a_max (ArrayLike | None | DeprecatedArg) – deprecated alias of the max argument. Will result in a DeprecationWarning if used.
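Because min and max each default to None, either bound can be omitted to clip on one side only. The sketch below assumes a JAX version whose signature matches the one documented above (keyword names min and max rather than the deprecated a_min and a_max); the array x is purely illustrative.

```python
import jax.numpy as jnp

x = jnp.array([-3.0, -1.0, 0.0, 2.0, 5.0])

# Only a lower bound: values below 0 are raised to 0, nothing is capped.
lower_only = jnp.clip(x, min=0.0)   # values: [0., 0., 0., 2., 5.]

# Only an upper bound: values above 2 are lowered to 2, nothing is raised.
upper_only = jnp.clip(x, max=2.0)   # values: [-3., -1., 0., 2., 2.]

# Both bounds by keyword; same result as jnp.clip(x, -1.0, 2.0).
both = jnp.clip(x, min=-1.0, max=2.0)  # values: [-1., -1., 0., 2., 2.]
```

Passing the bounds by keyword makes one-sided clipping explicit and avoids the deprecated a_min/a_max names.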

Returns:

An array containing values from arr, with values smaller than min set to min, and values larger than max set to max.

Return type:

Array
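The return description amounts to raising values below min and then lowering values above max. The hypothetical helper clip_sketch below expresses that with jnp.maximum and jnp.minimum and checks it against jnp.clip and numpy.clip on the example data; it illustrates the documented semantics and is not JAX's actual source.

```python
import numpy as np
import jax.numpy as jnp


def clip_sketch(arr, min=None, max=None):
    """Hypothetical reference for the documented behavior of jnp.clip.

    Values below `min` are raised first, then values above `max` are lowered.
    A bound left as None means that side is not clipped.
    """
    arr = jnp.asarray(arr)
    if min is not None:
        arr = jnp.maximum(arr, min)  # values smaller than min become min
    if max is not None:
        arr = jnp.minimum(arr, max)  # values larger than max become max
    return arr


x = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])

# The sketch should agree with jnp.clip for two-sided and one-sided clipping.
assert jnp.array_equal(clip_sketch(x, 2, 5), jnp.clip(x, 2, 5))
assert jnp.array_equal(clip_sketch(x, max=4), jnp.clip(x, max=4))

# It should also agree with NumPy's clip on the same values.
assert np.array_equal(np.asarray(clip_sketch(x, 2, 5)), np.clip(np.arange(8), 2, 5))
```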

Examples

>>> import jax.numpy as jnp
>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
>>> jnp.clip(arr, 2, 5)
Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)
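Bounds do not have to be scalars: as noted for min and max, they only need to be broadcast-compatible with arr. A minimal sketch, assuming per-feature bounds on a small 2-D batch (the names x, lo, hi and clip_features are illustrative), that also wraps the call in jax.jit:

```python
import jax
import jax.numpy as jnp

# A 2x3 batch: one row per sample, one column per feature.
x = jnp.array([[-5.0, 0.5, 10.0],
               [3.0, -2.0, 0.0]])

# Per-column bounds of shape (3,), broadcast against x's shape (2, 3).
lo = jnp.array([-1.0, 0.0, 0.0])
hi = jnp.array([1.0, 1.0, 5.0])


@jax.jit
def clip_features(x, lo, hi):
    # jnp.clip is an ordinary jax.numpy function, so it can be traced by jit.
    return jnp.clip(x, lo, hi)


out = clip_features(x, lo, hi)
# values: [[-1.0, 0.5, 5.0],
#          [ 1.0, 0.0, 0.0]]
```

Each column is clipped to its own range, so the same call handles features with different valid intervals without a Python loop.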