jax.numpy.gradient

jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)

Return the gradient of an N-dimensional array.

LAX-backend implementation of numpy.gradient().
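Because this is a LAX-backend implementation, `jnp.gradient` can be traced and compiled like other `jax.numpy` functions. The sketch below is illustrative only. It assumes samples of sin(2πx) on a uniform grid with step 0.1 and uses a wrapper named `dfdx`, which is a hypothetical name. It compiles the call with `jax.jit` and compares the result with `numpy.gradient`.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Illustrative data: sin(2*pi*x) sampled on a uniform grid with step 0.1.
x = np.linspace(0.0, 1.0, 11)
samples = np.sin(2 * np.pi * x)

@jax.jit
def dfdx(f):
    # Same positional spacing argument as numpy.gradient.
    return jnp.gradient(f, 0.1)

result = dfdx(jnp.asarray(samples))
expected = np.gradient(samples, 0.1)

# JAX defaults to float32, so compare against NumPy's float64 with a tolerance.
print(np.allclose(result, expected, atol=1e-4))
```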

Original docstring below.

The gradient is computed using second-order accurate central differences at the interior points and either first- or second-order accurate one-sided (forward or backward) differences at the boundaries. The returned gradient therefore has the same shape as the input array.
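The scheme above can be checked by hand for unit spacing. This is a sketch on made-up sample values, not the library's internal code. It rebuilds the result from central differences in the interior and first-order one-sided differences at the two ends. The JAX version does not implement `edge_order`, and its boundary values match the first-order case here.

```python
import jax.numpy as jnp

f = jnp.array([1.0, 2.0, 4.0, 7.0, 11.0])

# Interior points: central differences (f[i+1] - f[i-1]) / 2 with unit spacing.
interior = (f[2:] - f[:-2]) / 2.0
# Boundaries: forward difference at the start, backward difference at the end.
first = f[1:2] - f[0:1]
last = f[-1:] - f[-2:-1]
manual = jnp.concatenate([first, interior, last])

g = jnp.gradient(f)  # unit spacing by default
# Both give 1.0, 1.5, 2.5, 3.5, 4.0 and have the same shape as f.
print(g)
print(manual)
```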

Parameters:
  • f (array_like) – An N-dimensional array containing samples of a scalar function.

  • varargs (list of scalar or array, optional) –

    Spacing between f values. Default: unit spacing for all dimensions. Spacing can be specified using:

    1. A single scalar to specify a sample distance for all dimensions.

    2. N scalars to specify a constant sample distance for each dimension, e.g. dx, dy, dz, …

    3. N arrays to specify the coordinates of the values along each dimension of f. The length of each array must match the size of the corresponding dimension.

    4. Any combination of N scalars/arrays with the meaning of 2. and 3.

    If axis is given, the number of varargs must equal the number of axes. Default: 1.

  • axis (None or int or tuple of ints, optional) – Gradient is calculated only along the given axis or axes. The default (axis = None) is to calculate the gradient for all the axes of the input array. axis may be negative, in which case it counts from the last to the first axis.

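  • edge_order (int | None) – Not implemented in JAX; leave as None. In NumPy, this sets the order of accuracy (1 or 2, default 1) of the one-sided differences at the boundaries.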

Returns:

gradient – A list of ndarrays (or a single ndarray if there is only one dimension) corresponding to the derivatives of f with respect to each dimension. Each derivative has the same shape as f.

Return type:

ndarray or list of ndarray
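The return type depends on how many axes are differentiated. In this sketch, f(i, j) = i**2 + 3j on a 4-by-5 unit grid is an illustrative choice. The default `axis=None` yields one derivative per axis, each with the shape of f. Requesting a single axis (here with a negative index) yields one array.

```python
import jax.numpy as jnp

# f(i, j) = i**2 + 3*j sampled on a 4 x 5 unit grid.
i = jnp.arange(4.0)[:, None]
j = jnp.arange(5.0)[None, :]
f = i ** 2 + 3.0 * j

# axis=None (default): one derivative per axis, each with f's shape.
grads = jnp.gradient(f)
df_di, df_dj = grads
print(len(grads), df_di.shape, df_dj.shape)  # 2 (4, 5) (4, 5)

# Along axis 0 each column is [0, 1, 4, 9] plus a constant:
# the edges give 1 and 5, the interior central differences give 2 and 4.
print(df_di[:, 0])

# A single axis returns one array; axis=-1 refers to the last axis (j).
only_j = jnp.gradient(f, axis=-1)
```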
