jax.lax.shift_right_arithmetic
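A minimal sketch of the elementwise behavior, assuming the two-operand form `shift_right_arithmetic(x, y)` with signed integer arrays of matching dtype (the sample values are illustrative only). An arithmetic right shift copies the sign bit into the vacated high bits, so negative inputs stay negative. For comparison, the example also calls `jax.lax.shift_right_logical`, which fills the vacated bits with zeros instead.

```python
import jax.numpy as jnp
from jax import lax

# Signed 32-bit inputs; both operands share a dtype, as lax ops expect.
x = jnp.array([-8, -7, 8, 7], dtype=jnp.int32)
shift = jnp.full(x.shape, 1, dtype=jnp.int32)

# Arithmetic shift replicates the sign bit, so the result equals
# floor division by 2**shift (for example, -7 >> 1 == -4).
arith = lax.shift_right_arithmetic(x, shift)

# Logical shift fills with zeros, so negative inputs become large positives.
logical = lax.shift_right_logical(x, shift)

print(arith)    # [-4 -4  4  3]
print(logical)
```

For signed integer dtypes, `jnp.right_shift` performs the same arithmetic shift.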