jax.numpy.diff
- jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)
Calculate the n-th discrete difference along the given axis.
LAX-backend implementation of numpy.diff(). Original docstring below.
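Because this is a reimplementation of numpy.diff(), the two should produce the same values for the same input. The following is a minimal sketch of that check, assuming both numpy and jax are installed and that JAX's default 32-bit mode is in effect; the array values are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

# The same integer values, as a NumPy array and as a JAX array.
x_np = np.array([1, 2, 4, 7, 0])
x_jax = jnp.asarray(x_np)

# Compute the first difference with each library.
numpy_result = np.diff(x_np)
jax_result = np.asarray(jnp.diff(x_jax))  # convert the jax.Array back to NumPy

print(numpy_result)  # [ 1  2  3 -7]
print(jax_result)    # [ 1  2  3 -7]

# Values agree; dtypes may not (int64 vs. int32 when JAX's 64-bit mode is off).
print(np.array_equal(numpy_result, jax_result))  # True
```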
The first difference is given by
out[i] = a[i+1] - a[i]
along the given axis. Higher differences are calculated by applying diff recursively.
- Parameters
a (array_like) – Input array.
n (int, optional) – The number of times values are differenced. If zero, the input is returned as-is.
axis (int, optional) – The axis along which the difference is taken; the default is the last axis.
prepend (array_like, optional) – Values to prepend to a along axis prior to performing the difference. Scalar values are expanded to arrays with length 1 in the direction of axis and the shape of the input array along all other axes. Otherwise the dimension and shape must match a except along axis.
append (array_like, optional) – Values to append to a along axis prior to performing the difference. Scalar values are expanded to arrays with length 1 in the direction of axis and the shape of the input array along all other axes. Otherwise the dimension and shape must match a except along axis.
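The parameters above can be exercised together in a short sketch. It writes the first difference out with slicing, shows that n=2 equals applying diff twice and that n=0 is a no-op, differences a 2-D array along each axis, and pads with scalar prepend/append values. The arrays and variable names are illustrative only; the values in comments are what each call should return.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 4, 7, 0])

# First difference written out with slicing: out[i] = a[i+1] - a[i].
first_manual = a[1:] - a[:-1]
first = jnp.diff(a)                        # -> [1, 2, 3, -7]

# n=2 is the difference of the first differences.
second = jnp.diff(a, n=2)                  # -> [1, 1, -10]
second_recursive = jnp.diff(jnp.diff(a))   # same result, applied by hand

# n=0 returns the input unchanged.
unchanged = jnp.diff(a, n=0)

# axis selects the dimension; on a 2-D array, axis=0 differences the rows.
m = jnp.array([[1, 3, 6],
               [0, 5, 10]])
rows = jnp.diff(m, axis=0)                 # -> [[-1, 2, 4]]
cols = jnp.diff(m)                         # default axis=-1 -> [[2, 3], [5, 5]]

# A scalar prepend/append is expanded to length 1 along axis.
# prepend=0 makes the first output element a[0] - 0.
with_prepend = jnp.diff(a, prepend=0)      # -> [1, 1, 2, 3, -7]
# append=10 adds a final element 10 - a[-1].
with_append = jnp.diff(a, append=10)       # -> [1, 2, 3, -7, 10]
```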
- Returns
diff – The n-th differences. The shape of the output is the same as a except along axis, where the dimension is smaller by n. The type of the output is the same as the type of the difference between any two elements of a. This is the same as the type of a in most cases. A notable exception is datetime64, which results in a timedelta64 output array.
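The shape and dtype rules can be checked directly. This sketch assumes JAX's default 32-bit mode (so jnp.arange(12.0) is float32); it shows the output losing n entries along axis, prepend restoring the length for n=1, and an unsigned integer input keeping its dtype, so a negative step wraps around instead of producing a negative number.

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)   # float32 unless 64-bit mode is enabled

# The output loses n entries along `axis` and keeps every other dimension.
print(jnp.diff(x, n=2, axis=1).shape)         # (3, 2)
print(jnp.diff(x, axis=0).shape)              # (2, 4)

# Prepending one value along the axis offsets the loss when n=1.
print(jnp.diff(x, axis=1, prepend=0).shape)   # (3, 4)

# The dtype follows the dtype of an element-wise difference.
print(jnp.diff(x).dtype == x.dtype)           # True

# Unsigned integers stay unsigned, so 0 - 1 wraps around.
u8 = jnp.array([1, 0], dtype=jnp.uint8)
print(jnp.diff(u8), jnp.diff(u8).dtype)       # [255] uint8
```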
- Return type
ndarray