jax.scipy.signal.convolve(in1, in2, mode='full', method='auto', precision=None)

Convolution of two N-dimensional arrays.

JAX implementation of scipy.signal.convolve().

Parameters:

  • in1 (Array) – left-hand input to the convolution.

  • in2 (Array) – right-hand input to the convolution. Must have in1.ndim == in2.ndim.

  • mode (str) –

    controls the size of the output. Available modes are:

    • "full": (default) output the full convolution of the inputs.

    • "same": return a centered portion of the "full" output which is the same size as in1.

    • "valid": return the portion of the "full" output that does not depend on zero-padding at the array edges.

  • method (str) –

    controls the computation method. Options are:

    • "auto": (default) always uses the "direct" method.

    • "direct": lower to jax.lax.conv_general_dilated().

    • "fft": compute the result via a fast Fourier transform.

  • precision (str | Precision | tuple[str, str] | tuple[Precision, Precision] | None) – Specify the precision of the computation. Refer to jax.lax.Precision for a description of available values.
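As an illustration of how mode determines the output shape, the sketch below convolves a 2-D array with a smaller 2-D kernel. The arrays are arbitrary illustration data, and the shape rules in the comments are stated for the case where in1 is at least as large as in2 along every axis. Because in1.ndim must equal in2.ndim, a 1-D kernel is first lifted to shape (1, 3) so that it filters along rows only.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve

in1 = jnp.arange(30.0).reshape(5, 6)  # 2-D input, shape (5, 6)
in2 = jnp.ones((3, 2))                # 2-D kernel, shape (3, 2); ndim matches in1

for mode in ("full", "same", "valid"):
    out = convolve(in1, in2, mode=mode)
    print(mode, out.shape)
# full:  in1.shape + in2.shape - 1  -> (7, 7)
# same:  in1.shape                  -> (5, 6)
# valid: in1.shape - in2.shape + 1  -> (3, 5)

# A 1-D kernel cannot be passed directly with a 2-D input (ndim must match),
# so add a leading axis of length 1 to smooth along each row.
row_kernel = jnp.array([1.0, 2.0, 1.0])[None, :]  # shape (1, 3)
print(convolve(in1, row_kernel, mode="same").shape)  # (5, 6)
```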


Returns:

Array containing the convolved result.

Return type:

Array


A few 1D convolution examples:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([1, 1, 1])

Full convolution uses implicit zero-padding at the edges:

>>> jax.scipy.signal.convolve(x, y, mode='full')
Array([1., 3., 6., 7., 6., 3., 1.], dtype=float32)
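The phrase "implicit zero-padding" can be made concrete: padding x with len(y) - 1 zeros on each side and taking a "valid" convolution should reproduce the "full" result. The sketch below checks this on the same example arrays; it uses jnp.pad purely for illustration and is not meant to describe how JAX computes the result internally.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([1, 1, 1])

# Pad x with len(y) - 1 zeros on each side; a "valid" convolution of the
# padded input then covers every partial overlap, i.e. the "full" output.
pad = len(y) - 1
x_padded = jnp.pad(x, pad)  # [0, 0, 1, 2, 3, 2, 1, 0, 0]

full = convolve(x, y, mode="full")
full_from_padding = convolve(x_padded, y, mode="valid")
print(jnp.array_equal(full, full_from_padding))  # True

# Conversely, "valid" keeps only the entries of "full" that never touch padding.
print(jnp.array_equal(convolve(x, y, mode="valid"), full[pad:len(x)]))  # True
```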

Specifying mode='same' returns a centered convolution the same size as the first input:

>>> jax.scipy.signal.convolve(x, y, mode='same')
Array([3., 6., 7., 6., 3.], dtype=float32)
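To see what "centered" means here, the following sketch slices the "same" result out of the "full" one for the example arrays. The start offset (len(y) - 1) // 2 is an assumption inferred from this example and SciPy's centering convention rather than a formula documented on this page; for other shapes, compare against the function directly.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([1, 1, 1])

full = convolve(x, y, mode="full")  # length len(x) + len(y) - 1 = 7
start = (len(y) - 1) // 2           # assumed offset of the centered window
same_from_full = full[start:start + len(x)]
print(same_from_full)  # [3. 6. 7. 6. 3.]
print(jnp.array_equal(convolve(x, y, mode="same"), same_from_full))  # True
```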

Specifying mode='valid' returns only the portion where the two arrays fully overlap:

>>> jax.scipy.signal.convolve(x, y, mode='valid')
Array([6., 7., 6.], dtype=float32)
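The method parameter changes how the result is computed, not what is computed. The sketch below compares "direct" and "fft" on the same inputs; the random inputs and the tolerance are arbitrary choices for illustration, and agreement is expected only up to floating-point rounding. Passing precision=jax.lax.Precision.HIGHEST to the direct path assumes one wants full float32 accuracy on accelerators, where the default precision may be reduced; on CPU it likely makes no difference.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
signal = jax.random.normal(key1, (200,))
kernel = jax.random.normal(key2, (15,))

# "direct" lowers to jax.lax.conv_general_dilated; HIGHEST requests full
# float32 accuracy in case the backend would otherwise use reduced precision.
via_direct = convolve(signal, kernel, method="direct",
                      precision=jax.lax.Precision.HIGHEST)
# "fft" multiplies the spectra of the inputs, so it differs by rounding error.
via_fft = convolve(signal, kernel, method="fft")

print(via_direct.shape, via_fft.shape)  # (214,) (214,)
print(jnp.allclose(via_direct, via_fft, atol=1e-3))
```

For long kernels the FFT route is usually the cheaper one, while "direct" avoids FFT round-off; since "auto" always picks "direct", choosing "fft" has to be done explicitly.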