jax.lax.stop_gradient

jax.lax.stop_gradient(x)

Stops gradient computation.

Operationally, stop_gradient is the identity function: it returns its argument x unchanged. However, stop_gradient prevents the flow of gradients through x during forward-mode or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.

For example:

>>> import jax
>>> jax.grad(lambda x: x**2)(3.)
DeviceArray(6., dtype=float32)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
DeviceArray(0., dtype=float32)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
DeviceArray(2., dtype=float32)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
DeviceArray(0., dtype=float32)
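
The examples above wrap the whole input, so every gradient is zero. The sketch below is an illustrative addition rather than part of the original docstring. It wraps only one factor of a product, to show that stop_gradient leaves the forward value untouched and blocks only the derivative path through the wrapped operand, in both reverse mode (jax.grad) and forward mode (jax.jvp). The function name f and the input value are arbitrary choices made for illustration.

```python
import jax


def f(x):
    # Only the second factor is wrapped, so it is treated as a constant
    # during differentiation: d/dx [x * c] = c, where c has the value of x.
    return x * jax.lax.stop_gradient(x)


x = 3.0

# Forward evaluation: stop_gradient acts as the identity, so f(x) equals x * x.
value = f(x)

# Reverse mode: only the unwrapped factor contributes, giving x rather than 2 * x.
reverse_grad = jax.grad(f)(x)

# Forward mode: the tangent flowing through the wrapped factor is dropped too.
_, forward_tangent = jax.jvp(f, (x,), (1.0,))

print(value, reverse_grad, forward_tangent)
```

With x = 3.0, the value is 9 while both derivatives are 3 instead of the 6 that an unwrapped x * x would give. Wrapping a single operand like this is a common way to treat an intermediate quantity as a fixed target while still differentiating the rest of the expression.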