Stops gradient computation.

Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.

For example:

>>> jax.grad(lambda x: x**2)(3.)
DeviceArray(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
DeviceArray(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
DeviceArray(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
DeviceArray(0., dtype=float32, weak_type=True)

x (~T) –

Return type