jax.numpy.clip
- jax.numpy.clip(arr=None, /, min=None, max=None, *, a=Deprecated, a_min=Deprecated, a_max=Deprecated)
Clip array values to a specified range.
JAX implementation of numpy.clip().
- Parameters:
  - arr (ArrayLike | None) – N-dimensional array to be clipped.
  - min (ArrayLike | None) – optional minimum value of the clipped range; if None (default), the result is not clipped to any minimum value. If specified, it should be broadcast-compatible with arr and max.
  - max (ArrayLike | None) – optional maximum value of the clipped range; if None (default), the result is not clipped to any maximum value. If specified, it should be broadcast-compatible with arr and min.
  - a (ArrayLike | DeprecatedArg) – deprecated alias of the arr argument. Will result in a DeprecationWarning if used.
  - a_min (ArrayLike | None | DeprecatedArg) – deprecated alias of the min argument. Will result in a DeprecationWarning if used.
  - a_max (ArrayLike | None | DeprecatedArg) – deprecated alias of the max argument. Will result in a DeprecationWarning if used.
- Returns:
  An array containing values from arr, with values smaller than min set to min and values larger than max set to max.
- Return type:
  Array
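Because min and max are each optional and only need to be broadcast-compatible with arr, the bounds can be one-sided or can vary along an axis. The sketch below assumes a JAX release whose jnp.clip accepts the min and max keywords shown in the signature above; the arrays and variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[-3.0, 0.5, 4.0],
               [1.0, 7.0, -2.0]])

# Lower bound only: max stays None, so large values pass through unchanged.
lower_only = jnp.clip(x, min=0.0)  # -> [[0., 0.5, 4.], [1., 7., 0.]]

# Upper bound only: min stays None, so small values pass through unchanged.
upper_only = jnp.clip(x, max=1.0)  # -> [[-3., 0.5, 1.], [1., 1., -2.]]

# Per-column bounds of shape (3,) broadcast against x of shape (2, 3).
lo = jnp.array([0.0, 1.0, -1.0])
hi = jnp.array([1.0, 2.0, 3.0])
per_column = jnp.clip(x, lo, hi)   # -> [[0., 1., 3.], [1., 2., -1.]]
```

The result keeps the broadcast shape of arr, min, and max, which here is simply the shape (2, 3) of x.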
See also
- jax.numpy.minimum(): Compute the element-wise minimum value of two arrays.
- jax.numpy.maximum(): Compute the element-wise maximum value of two arrays.
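The two functions listed above give an equivalent way to express clipping when min <= max: take the element-wise maximum with min, then the element-wise minimum with max. A minimal sketch checking this on a small integer array; the bounds are arbitrary illustrative values, and this is not a claim about how jnp.clip is implemented internally.

```python
import jax.numpy as jnp

x = jnp.arange(-4, 5)  # integers -4, -3, ..., 4
lo, hi = -2, 3

clipped = jnp.clip(x, lo, hi)  # -> -2, -2, -2, -1, 0, 1, 2, 3, 3

# Same bounds expressed with the "See also" functions:
# raise everything below lo to lo, then lower everything above hi to hi.
composed = jnp.minimum(jnp.maximum(x, lo), hi)

same = bool(jnp.array_equal(clipped, composed))  # -> True
```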
Examples
>>> import jax.numpy as jnp
>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
>>> jnp.clip(arr, 2, 5)
Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)
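Since min and max are ordinary array arguments, the bounds do not have to be compile-time constants; they can be passed as traced values into a jax.jit-compiled function. A short sketch, where clip_to is a hypothetical helper name used only for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_to(x, lo, hi):
    # lo and hi are traced (not static) arguments, so calling again with
    # different scalar bounds reuses the same compiled function.
    return jnp.clip(x, lo, hi)

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])
narrow = clip_to(x, -0.5, 0.25)  # -> [-0.5, -0.5, 0.0, 0.25, 0.25]
wide = clip_to(x, -0.75, 0.75)   # -> [-0.75, -0.5, 0.0, 0.5, 0.75]
```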