jax.numpy.reciprocal

jax.numpy.reciprocal(x, /)

Calculate element-wise reciprocal of the input.

JAX implementation of numpy.reciprocal.

The reciprocal of each element is computed as 1/x.
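As a quick sanity check, the sketch below compares jnp.reciprocal against explicit division. It assumes JAX is installed with its default 32-bit configuration (x64 disabled). The names x, recip and quotient are illustrative only.

```python
import jax.numpy as jnp

# Sample float inputs, including a negative value.
x = jnp.array([1.0, 2.0, 4.0, -0.5])

recip = jnp.reciprocal(x)   # element-wise reciprocal
quotient = 1 / x            # explicit division for comparison

# The two results should agree element by element.
print(jnp.allclose(recip, quotient))  # True
```

All of these inputs have exactly representable reciprocals, so the comparison does not depend on rounding.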

Parameters:

x (ArrayLike) – input array or scalar.

Returns:

An array of the same shape as x containing the reciprocal of each element of x.

Return type:

Array
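The following minimal sketch illustrates the shape-preserving behavior described under Returns, for both a 2-D array and a Python scalar. It assumes a standard JAX installation. The variables m, out and s are hypothetical names used only for illustration.

```python
import jax.numpy as jnp

# A 2-D input: the result is expected to keep the (2, 3) shape.
m = jnp.array([[1.0, 2.0, 4.0],
               [8.0, 0.5, 0.25]])
out = jnp.reciprocal(m)
print(out.shape)  # (2, 3)

# A Python scalar input gives a 0-d array.
s = jnp.reciprocal(4.0)
print(s.shape)    # ()
```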

Note

For integer inputs, np.reciprocal returns integer output (so any input with absolute value greater than 1 yields 0), while jnp.reciprocal promotes integer inputs to floating point.
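To make the difference concrete, this sketch runs both functions on the same integer data. It assumes NumPy and JAX are both installed and that JAX uses its default 32-bit mode, so the promoted dtype is float32. The list ints and the names np_out and jnp_out are illustrative.

```python
import numpy as np
import jax.numpy as jnp

ints = [1, 2, 4]

# NumPy keeps the integer dtype, so 1/2 and 1/4 truncate to 0.
np_out = np.reciprocal(np.array(ints))

# JAX promotes the integer input to floating point before dividing.
jnp_out = jnp.reciprocal(jnp.array(ints))

print(np_out)         # [1 0 0]
print(jnp_out.dtype)  # float32
```

Zero is deliberately left out of ints, since an integer reciprocal of zero in NumPy is not well defined.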

Examples

>>> import jax.numpy as jnp
>>> jnp.reciprocal(2)
Array(0.5, dtype=float32, weak_type=True)
>>> jnp.reciprocal(0.)
Array(inf, dtype=float32, weak_type=True)
>>> x = jnp.array([1, 5., 4.])
>>> jnp.reciprocal(x)
Array([1.  , 0.2 , 0.25], dtype=float32)