jax.numpy.gradient
- jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)
Return the gradient of an N-dimensional array.
LAX-backend implementation of numpy.gradient(). Original docstring below.
The gradient is computed using second-order accurate central differences at the interior points and either first- or second-order accurate one-sided (forward or backward) differences at the boundaries. The returned gradient therefore has the same shape as the input array.
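For a 1-D array with unit spacing, the stencil described above can be reproduced by hand. The sketch below is illustrative only: the sample values in `f` are arbitrary, and it assumes the default boundary treatment is the first-order one-sided difference, then checks that assumption against `jnp.gradient`.

```python
import jax.numpy as jnp

# Arbitrary 1-D samples with unit spacing (h = 1).
f = jnp.array([1.0, 2.0, 4.0, 7.0, 11.0])

# Interior points: second-order central difference (f[i+1] - f[i-1]) / (2h).
interior = (f[2:] - f[:-2]) / 2.0

# Boundaries: first-order one-sided differences (assumed default).
start = f[1:2] - f[0:1]   # forward difference at the first sample
end = f[-1:] - f[-2:-1]   # backward difference at the last sample

manual = jnp.concatenate([start, interior, end])
result = jnp.gradient(f)

print(manual)  # values: 1.0, 1.5, 2.5, 3.5, 4.0
print(result)  # same shape as f, same values as manual
```

Because the one-sided edge values are concatenated onto the interior, the result keeps the length of `f`, which is why the gradient has the same shape as the input.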
- Parameters:
f (array_like) – An N-dimensional array containing samples of a scalar function.
varargs (list of scalar or array, optional) –
Spacing between f values. Default: unit spacing for all dimensions. Spacing can be specified using:
1. a single scalar to specify a sample distance for all dimensions.
2. N scalars to specify a constant sample distance for each dimension, i.e. dx, dy, dz, …
3. N arrays to specify the coordinates of the values along each dimension of f. The length of each array must match the size of the corresponding dimension.
4. Any combination of N scalars/arrays with the meaning of 2. and 3.
If axis is given, the number of varargs must equal the number of axes. Default: 1.
axis (None or int or tuple of ints, optional) – Gradient is calculated only along the given axis or axes. The default (axis=None) is to calculate the gradient for all the axes of the input array. axis may be negative, in which case it counts from the last to the first axis.
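A short sketch of axis selection, assuming an illustrative 2-D array `g` whose values grow by exactly 3 per step along axis 1, so the expected derivative along that axis is easy to check.

```python
import jax.numpy as jnp

# Illustrative 2-D array: g[i, j] = i**2 + 3 * j.
i = jnp.arange(4.0)[:, None]
j = jnp.arange(5.0)[None, :]
g = i ** 2 + 3.0 * j

full = jnp.gradient(g)                 # axis=None: one derivative per axis
along_1 = jnp.gradient(g, axis=1)      # a single array for one axis
along_last = jnp.gradient(g, axis=-1)  # -1 counts from the last axis, i.e. axis 1

# When axis is given, pass one spacing value per selected axis.
along_1_scaled = jnp.gradient(g, 0.5, axis=1)

print(along_1)         # 3.0 everywhere, since g grows by 3 per step along axis 1
print(along_1_scaled)  # 6.0 everywhere (3.0 / 0.5)
```

Restricting axis avoids computing derivatives that are not needed, and the result along a selected axis matches the corresponding entry of the full gradient.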
- Returns:
gradient – A list of ndarrays (or a single ndarray if the gradient is computed along only one axis) corresponding to the derivatives of f with respect to each dimension. Each derivative has the same shape as f.
- Return type:
ndarray or list of ndarray
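The return type depends on how many axes are differentiated. This sketch, using illustrative arrays `v` and `t`, checks that a 1-D input yields a single `jax.Array` while a 3-D input yields one array per axis, each with the input's shape.

```python
import jax
import jax.numpy as jnp

# 1-D input: a single array comes back.
v = jnp.array([0.0, 1.0, 4.0, 9.0])
dv = jnp.gradient(v)
print(isinstance(dv, jax.Array))  # True

# 3-D input: t[a, b, c] = 20*a + 5*b + c, so the derivatives
# along the three axes are the constants 20, 5 and 1.
t = jnp.arange(60.0).reshape(3, 4, 5)
dt = jnp.gradient(t)
print(len(dt))                # 3
print([d.shape for d in dt])  # [(3, 4, 5), (3, 4, 5), (3, 4, 5)]
```

Code that accepts arbitrary input rank may need to handle both the single-array and the list form of the result.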
- Parameters:
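edge_order (int | None) – Not supported by the JAX implementation; passing any value other than None raises NotImplementedError.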