jax.lax.round

jax.lax.round(x, rounding_method=RoundingMethod.AWAY_FROM_ZERO)

Elementwise round.

Rounds each element of x to the nearest integer. The result keeps the floating-point dtype of x.

Parameters:
  • x (jax.typing.ArrayLike) – an array or scalar value to round.

  • rounding_method (RoundingMethod) – the method to use when rounding halfway values (e.g., 0.5). See lax.RoundingMethod for the list of possible values.

Returns:

An array containing the elementwise rounding of x.

Return type:

Array
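
The following is a minimal sketch, not part of the original reference, of how the two rounding methods might differ on halfway values. The input array and the variable names (x, away, even) are illustrative assumptions; lax.RoundingMethod.TO_NEAREST_EVEN is used as the alternative to the default.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input: every element sits exactly halfway between two integers.
x = jnp.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5], dtype=jnp.float32)

# Default method: halfway values are rounded away from zero.
away = lax.round(x)

# Alternative method: halfway values are rounded to the nearest even integer.
even = lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)

print(away)  # halfway values move away from zero, e.g. 0.5 -> 1.0, -2.5 -> -3.0
print(even)  # halfway values move to the even neighbor, e.g. 0.5 -> 0.0, 1.5 -> 2.0
```

Values that are not exactly halfway round the same way under both methods, so the choice only matters for inputs such as 0.5 or 2.5.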