jax.numpy.diff

jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)

Calculate the n-th order difference between array elements along a given axis.

JAX implementation of numpy.diff().

The first-order difference is computed as a[i+1] - a[i], and the n-th order difference is obtained by applying the first-order difference recursively n times.
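A minimal sketch of that recursion using plain slicing, ignoring `prepend` and `append`. The helper names `first_order_diff` and `nth_order_diff` are hypothetical and are not part of the JAX API. The sketch only illustrates the definition and is not the library's implementation.

```python
import jax.numpy as jnp


def first_order_diff(x, axis=-1):
    """Hypothetical helper: x[i+1] - x[i] along `axis`."""
    x = jnp.asarray(x)
    axis = axis % x.ndim  # normalize negative axes such as -1
    upper = [slice(None)] * x.ndim
    lower = [slice(None)] * x.ndim
    upper[axis] = slice(1, None)   # elements 1 .. len-1
    lower[axis] = slice(None, -1)  # elements 0 .. len-2
    return x[tuple(upper)] - x[tuple(lower)]


def nth_order_diff(x, n=1, axis=-1):
    """Hypothetical helper: apply the first-order difference n times."""
    x = jnp.asarray(x)
    for _ in range(n):  # n=0 leaves x unchanged
        x = first_order_diff(x, axis=axis)
    return x


a = jnp.array([[1, 5, 2, 9],
               [3, 8, 7, 4]])
# The sketch agrees with jnp.diff for several orders.
for n in range(4):
    assert (nth_order_diff(a, n) == jnp.diff(a, n=n)).all()
```

Each pass shortens the chosen axis by one element, which is why the second-order result in the examples has two columns instead of four.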

Parameters:
  • a (ArrayLike) – input array. Must have a.ndim >= 1.

  • n (int) – int, optional, default=1. Order of the difference, i.e. the number of times the first-order difference is applied. If n=0, no difference is computed and the input is returned as is.

  • axis (int) – int, optional, default=-1. The axis along which the difference is computed; by default, the last axis.

  • prepend (ArrayLike | None) – scalar or array, optional, default=None. Values to prepend to a along axis before the difference is computed. A scalar is broadcast to a single slice along axis.

  • append (ArrayLike | None) – scalar or array, optional, default=None. Values to append to a along axis before the difference is computed. A scalar is broadcast to a single slice along axis.
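One way to read `prepend` and `append` is as concatenation along `axis` followed by an ordinary difference. The sketch below checks that reading against `jnp.diff` on a small array. `diff_with_padding` is a hypothetical helper, not a JAX function. It assumes that a scalar is broadcast to the shape of `a` with `axis` shortened to length 1, and that array-valued `prepend`/`append` already match `a` on every other axis.

```python
import jax.numpy as jnp


def diff_with_padding(a, n=1, axis=-1, prepend=None, append=None):
    """Hypothetical helper: concatenate prepend/append, then difference."""
    a = jnp.asarray(a)
    axis = axis % a.ndim
    pad_shape = list(a.shape)
    pad_shape[axis] = 1  # a scalar becomes one slice along `axis`

    def as_block(v):
        v = jnp.asarray(v)
        # Scalars are broadcast; arrays are assumed to have compatible shape.
        return jnp.broadcast_to(v, tuple(pad_shape)) if v.ndim == 0 else v

    parts = [a]
    if prepend is not None:
        parts.insert(0, as_block(prepend))
    if append is not None:
        parts.append(as_block(append))
    return jnp.diff(jnp.concatenate(parts, axis=axis), n=n, axis=axis)


a = jnp.array([[1, 5, 2, 9],
               [3, 8, 7, 4]])
col = jnp.array([[3], [1]])
assert (diff_with_padding(a, prepend=2) == jnp.diff(a, prepend=2)).all()
assert (diff_with_padding(a, append=col) == jnp.diff(a, append=col)).all()
```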

Returns:

An array containing the n-th order difference between the elements of a.

Return type:

Array
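Because each difference pass drops one element, the result has the same shape as a except along axis. Along that axis, the length is the padded length (a.shape[axis] plus any prepended and appended slices) minus n. A quick shape check on an arbitrary array, kept to n no larger than the padded length:

```python
import jax.numpy as jnp

a = jnp.zeros((3, 7))

# Each pass shortens `axis` by one; the other axis is untouched.
for n in range(3):
    assert jnp.diff(a, n=n, axis=1).shape == (3, 7 - n)
    assert jnp.diff(a, n=n, axis=0).shape == (3 - n, 7)

# A scalar prepend adds one slice along `axis` before differencing.
assert jnp.diff(a, n=2, axis=1, prepend=0.0).shape == (3, 7 + 1 - 2)
```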

See also

Examples

By default, jnp.diff computes the first-order difference along the last axis.

>>> a = jnp.array([[1, 5, 2, 9],
...                [3, 8, 7, 4]])
>>> jnp.diff(a)
Array([[ 4, -3,  7],
       [ 5, -1, -3]], dtype=int32)
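The axis parameter selects a different direction. As a small sketch that is not one of the original examples, differencing the same array along axis=0 subtracts the first row from the second:

```python
import jax.numpy as jnp

a = jnp.array([[1, 5, 2, 9],
               [3, 8, 7, 4]])
row_diff = jnp.diff(a, axis=0)  # a[1, :] - a[0, :], kept as a 2-D array
print(row_diff.tolist())  # [[2, 3, 5, -5]]
```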

With n=2, the second-order difference is computed along the last axis.

>>> jnp.diff(a, n=2)
Array([[-7, 10],
       [-6, -2]], dtype=int32)
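Unrolling the recursion gives a closed form for the n-th order difference as a binomial-weighted sum of neighboring elements. This is the standard forward-difference identity, stated here for reference rather than taken from the JAX source. For n=2 it reduces to the second expression. On the first row of a, i=0 gives 2 - 2·5 + 1 = -7, which matches the output above.

```latex
\Delta^{n} a_i = \sum_{k=0}^{n} (-1)^{n-k} \binom{n}{k}\, a_{i+k},
\qquad
\Delta^{2} a_i = a_{i+2} - 2\,a_{i+1} + a_i
```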

With prepend=2, the scalar 2 is prepended to each row of a (that is, along the last axis) before the difference is computed.

>>> jnp.diff(a, prepend=2)
Array([[-1,  4, -3,  7],
       [ 1,  5, -1, -3]], dtype=int32)
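A scalar prepend is also a convenient way to invert a running sum. Prepending 0 keeps the output the same length as the input, and jnp.cumsum along the same axis then reconstructs a. A short check of that identity:

```python
import jax.numpy as jnp

a = jnp.array([[1, 5, 2, 9],
               [3, 8, 7, 4]])

# With 0 prepended, the first difference in each row is the row's first
# element itself, so the output keeps the input's length along the axis.
steps = jnp.diff(a, prepend=0)         # [[1, 4, -3, 7], [3, 5, -1, -3]]
restored = jnp.cumsum(steps, axis=-1)  # the running sum undoes the difference
assert (restored == a).all()
```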

With append=jnp.array([[3],[1]]), the values 3 and 1 are appended to the first and second rows, respectively, before the difference is computed.

>>> jnp.diff(a, append=jnp.array([[3],[1]]))
Array([[ 4, -3,  7, -6],
       [ 5, -1, -3, -3]], dtype=int32)