jax.scipy.integrate.trapezoid

jax.scipy.integrate.trapezoid(y, x=None, dx=1.0, axis=-1)

Integrate along the given axis using the composite trapezoidal rule.

JAX implementation of scipy.integrate.trapezoid().

The trapezoidal rule approximates the integral under a curve by summing the areas of trapezoids formed between adjacent data points.
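To make the rule concrete: for 1-D inputs, each interval [x[i], x[i+1]] contributes a trapezoid of width x[i+1] - x[i] and average height (y[i] + y[i+1]) / 2, and the integral estimate is the sum of those areas. The sketch below writes that sum out by hand and compares it with the library call. The helper trapezoid_manual is hypothetical and only handles 1-D arrays; it is an illustration, not the library's implementation.

```python
import jax.numpy as jnp
import jax.scipy.integrate

def trapezoid_manual(y, x):
    """Hypothetical helper: composite trapezoidal rule for 1-D y and x."""
    widths = jnp.diff(x)               # x[i+1] - x[i] for each interval
    heights = 0.5 * (y[1:] + y[:-1])   # mean of the two adjacent samples
    return jnp.sum(widths * heights)   # total area of all trapezoids

x = jnp.array([0.0, 2.0, 5.0, 7.0, 10.0, 15.0, 20.0])
y = jnp.array([1.0, 2.0, 3.0, 2.0, 3.0, 2.0, 1.0])

manual = trapezoid_manual(y, x)
library = jax.scipy.integrate.trapezoid(y, x)
print(manual, library)  # both 43.0
```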

Parameters:
  • y (ArrayLike) – array of data to integrate.

  • x (ArrayLike | None) – optional array of sample points corresponding to the y values. If not provided, the samples are assumed to be equally spaced, with spacing given by dx.

  • dx (ArrayLike) – the spacing between sample points when x is None (default: 1.0).

  • axis (int) – the axis along which to integrate (default: -1).
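The following sketch illustrates how the parameters interact on a small 2-D array (the data are made up for illustration): passing dx for equally spaced samples should give the same result as passing the equivalent explicit x, and changing axis selects which dimension is integrated.

```python
import jax.numpy as jnp
import jax.scipy.integrate

# Two rows of samples; with the default axis=-1 each row is a separate curve.
y = jnp.array([[0.0, 1.0, 4.0, 9.0],
               [1.0, 1.0, 1.0, 1.0]])

# Equally spaced samples described by their spacing ...
by_dx = jax.scipy.integrate.trapezoid(y, dx=0.5)
# ... or by the explicit sample points 0.0, 0.5, 1.0, 1.5 (1-D x broadcasts over rows).
by_x = jax.scipy.integrate.trapezoid(y, x=jnp.array([0.0, 0.5, 1.0, 1.5]))

# axis=0 integrates down the columns instead, giving one value per column.
by_col = jax.scipy.integrate.trapezoid(y, axis=0)

print(by_dx)   # [4.75 1.5 ]
print(by_col)  # [0.5 1.  2.5 5. ]
```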

Returns:

The definite integral approximated by the trapezoidal rule. The integration axis is removed from the result, so a 1-D y yields a scalar (0-d) array.

Return type:

Array
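A short check of the output shape and dtype, using an arbitrary 3-D integer input: integrating along axis=1 should drop that axis, and integer data are promoted to a floating-point dtype (float32 unless 64-bit mode is enabled).

```python
import jax.numpy as jnp
import jax.scipy.integrate

y = jnp.ones((2, 3, 4), dtype=jnp.int32)   # 3 samples along axis 1
result = jax.scipy.integrate.trapezoid(y, axis=1)

print(result.shape)  # (2, 4): axis 1 has been integrated away
print(result.dtype)  # a floating-point dtype, e.g. float32
# Each entry covers 2 unit-width intervals of height 1, so every value is 2.0.
```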

See also

jax.numpy.trapezoid(): NumPy-style API for trapezoidal integration
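The two entry points are expected to compute the same quantity; the sketch below compares them on x² over [0, 1] with 11 samples. The composite rule overestimates the exact value 1/3 here, giving 0.335 (the error is h²/6 for step h = 0.1).

```python
import jax.numpy as jnp
import jax.scipy.integrate

x = jnp.linspace(0.0, 1.0, 11)   # step h = 0.1
y = x ** 2

a = jax.scipy.integrate.trapezoid(y, x)  # SciPy-style API
b = jnp.trapezoid(y, x)                  # NumPy-style API
print(a, b)  # both approximately 0.335
```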

Examples

Integrate over a regular grid with spacing 1.0:

>>> import jax
>>> import jax.numpy as jnp
>>> y = jnp.array([1, 2, 3, 2, 3, 2, 1])
>>> jax.scipy.integrate.trapezoid(y, dx=1.0)
Array(13., dtype=float32)

Integrate over an irregular grid:

>>> x = jnp.array([0, 2, 5, 7, 10, 15, 20])
>>> jax.scipy.integrate.trapezoid(y, x)
Array(43., dtype=float32)

Approximate \(\int_0^{2\pi} \sin^2(x)\,dx\), whose exact value is \(\pi\):

>>> x = jnp.linspace(0, 2 * jnp.pi, 1000)
>>> y = jnp.sin(x) ** 2
>>> result = jax.scipy.integrate.trapezoid(y, x)
>>> jnp.allclose(result, jnp.pi)
Array(True, dtype=bool)
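Because the result is an ordinary JAX array computed from jax.numpy operations, it should compose with transformations such as jax.grad and jax.jit. The sketch below (the function name area is illustrative) differentiates the integral with respect to the samples: with unit spacing, the end points get weight 0.5 and the interior points get weight 1, which are exactly the trapezoidal-rule weights.

```python
import jax
import jax.numpy as jnp
import jax.scipy.integrate

def area(y):
    # Integral of the samples y on a unit-spaced grid.
    return jax.scipy.integrate.trapezoid(y, dx=1.0)

y = jnp.array([1.0, 2.0, 3.0, 2.0, 1.0])

weights = jax.grad(area)(y)   # d(area)/d(y[i]) = quadrature weight of sample i
fast_area = jax.jit(area)(y)  # compiled evaluation of the same integral

print(weights)    # [0.5 1.  1.  1.  0.5]
print(fast_area)  # 8.0
```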