jax.numpy.nextafter

jax.numpy.nextafter(x, y, /)

Return, element-wise, the next representable floating-point value after x in the direction of y.

JAX implementation of numpy.nextafter.

Parameters:
  • x (ArrayLike) – scalar or array. Specifies the value after which the next number is found.

  • y (ArrayLike) – scalar or array. Specifies the direction towards which the next number is found. x and y must either have the same shape or be broadcast-compatible.

Returns:

An array containing, for each element of x, the next representable number after it in the direction of y.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> jnp.nextafter(2, 1)
Array(1.9999999, dtype=float32, weak_type=True)
>>> x = jnp.array([3, -2, 1])
>>> y = jnp.array([2, -1, 2])
>>> jnp.nextafter(x, y)
Array([ 2.9999998, -1.9999999,  1.0000001], dtype=float32)
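
As a further illustration, which is not part of the original docstring, the sketch below assumes float32 inputs and shows two things: the gap between 1.0 and its next larger neighbor should equal the float32 machine epsilon reported by jnp.finfo, and a scalar x broadcasts against an array of directions y. The variable names are chosen for this example only.

```python
import jax.numpy as jnp

# Assumption for this sketch: everything is kept in float32.
one = jnp.float32(1.0)

# Step from 1.0 towards 2.0: the gap should match float32 machine epsilon.
up = jnp.nextafter(one, jnp.float32(2.0))
eps = jnp.finfo(jnp.float32).eps
gap_above_one = up - one

# Broadcasting: a scalar x against an array of directions y.
# Direction -inf steps down, direction +inf steps up, and y == x leaves x unchanged.
directions = jnp.array([-jnp.inf, 1.0, jnp.inf], dtype=jnp.float32)
neighbors = jnp.nextafter(one, directions)

print(gap_above_one == eps)
print(neighbors)
```

Because the neighbors of 1.0 differ from it only in the last bits, printing them at default precision may not make the difference obvious; comparing them against 1.0 with < and > is a more reliable check.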