jax.lax.round

jax.lax.round(x, rounding_method=RoundingMethod.AWAY_FROM_ZERO)

Elementwise round.

Rounds values to the nearest integer.
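A minimal sketch of the basic behavior, assuming a recent JAX install. The inputs are chosen so that none is a halfway value, so the result does not depend on `rounding_method`. The rounded values stay in the input's floating-point dtype rather than being converted to an integer type.

```python
import jax.numpy as jnp
from jax import lax

# None of these inputs is a halfway value, so the default rounding method does not matter here.
x = jnp.array([-2.7, 0.4, 1.6, 3.49], dtype=jnp.float32)
y = lax.round(x)

print(y.tolist())  # [-3.0, 0.0, 2.0, 3.0]
print(y.dtype)     # float32: the result is still a float array, not an integer array
```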

Parameters:
  • x (ArrayLike) – an array or scalar value to round. Must have a floating-point dtype.

  • rounding_method (RoundingMethod) – the method to use when rounding halfway values (e.g., 0.5). See jax.lax.RoundingMethod for possible values.
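The two rounding methods differ only on exact halfway values. The sketch below, which assumes a recent JAX version exposing `lax.RoundingMethod`, compares the default `AWAY_FROM_ZERO` with `TO_NEAREST_EVEN` ("banker's rounding") on a few ties.

```python
import jax.numpy as jnp
from jax import lax

halves = jnp.array([-2.5, -1.5, 0.5, 1.5, 2.5], dtype=jnp.float32)

# Default method: ties are rounded away from zero.
away = lax.round(halves)  # equivalent to rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO

# Ties are rounded to the nearest even integer.
even = lax.round(halves, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)

print(away.tolist())  # [-3.0, -2.0, 1.0, 2.0, 3.0]
print(even.tolist())  # [-2.0, -2.0, 0.0, 2.0, 2.0]
```

NumPy's `np.round` also rounds halves to even, so `TO_NEAREST_EVEN` is the method that reproduces NumPy-style results.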

Returns:

An array with the same shape and dtype as x, containing the elementwise rounding of x.

Return type:

Array