jax.numpy.diff(a, n=1, axis=-1)[source]ΒΆ

Calculate the n-th discrete difference along the given axis.

LAX-backend implementation of diff(). Original docstring below.

The first difference is given by out[i] = a[i+1] - a[i] along the given axis, higher differences are calculated by using diff recursively.

  • a (array_like) – Input array

  • n (int, optional) – The number of times values are differenced. If zero, the input is returned as-is.

  • axis (int, optional) – The axis along which the difference is taken, default is the last axis.


diff – The n-th differences. The shape of the output is the same as a except along axis where the dimension is smaller by n. The type of the output is the same as the type of the difference between any two elements of a. This is the same as the type of a in most cases. A notable exception is datetime64, which results in a timedelta64 output array.

Return type



Type is preserved for boolean arrays, so the result will contain False when consecutive elements are the same and True when they differ.

For unsigned integer arrays, the results will also be unsigned. This should not be surprising, as the result is consistent with calculating the difference directly:

>>> u8_arr = np.array([1, 0], dtype=np.uint8)
>>> np.diff(u8_arr)
array([255], dtype=uint8)
>>> u8_arr[1,...] - u8_arr[0,...]

If this is not desirable, then the array should be cast to a larger integer type first:

>>> i16_arr = u8_arr.astype(np.int16)
>>> np.diff(i16_arr)
array([-1], dtype=int16)


>>> x = np.array([1, 2, 4, 7, 0])
>>> np.diff(x)
array([ 1,  2,  3, -7])
>>> np.diff(x, n=2)
array([  1,   1, -10])
>>> x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]])
>>> np.diff(x)
array([[2, 3, 4],
       [5, 1, 2]])
>>> np.diff(x, axis=0)
array([[-1,  2,  0, -2]])
>>> x = np.arange('1066-10-13', '1066-10-16', dtype=np.datetime64)
>>> np.diff(x)
array([1, 1], dtype='timedelta64[D]')