# jax.scipy.signal.fftconvolve

jax.scipy.signal.fftconvolve(in1, in2, mode='full', axes=None)

Convolve two N-dimensional arrays using Fast Fourier Transform (FFT).

JAX implementation of `scipy.signal.fftconvolve()`.
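The core idea can be sketched in a few lines of `jax.numpy`: zero-pad both inputs to the full output length, multiply their FFTs, and invert the product (the convolution theorem). The helper `fft_convolve_1d` below is hypothetical and covers only real 1D inputs in `"full"` mode. It is not JAX's implementation, which also handles N-dimensional inputs and the `mode` and `axes` options.

```python
import jax.numpy as jnp


def fft_convolve_1d(a, b):
    """Hypothetical helper: full 1D linear convolution via the FFT."""
    a = jnp.asarray(a, dtype=jnp.float32)
    b = jnp.asarray(b, dtype=jnp.float32)
    # The full linear convolution has len(a) + len(b) - 1 samples.
    n = a.shape[0] + b.shape[0] - 1
    # Passing n zero-pads both inputs to that length, so the circular
    # convolution computed by the FFT equals the linear convolution.
    spectrum = jnp.fft.rfft(a, n=n) * jnp.fft.rfft(b, n=n)
    # A pointwise product in frequency space is a convolution in signal space.
    return jnp.fft.irfft(spectrum, n=n)


x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([1, 1, 1])
result = fft_convolve_1d(x, y)

# Compare with direct (non-FFT) convolution; a tolerance absorbs round-off.
direct = jnp.convolve(x, y, mode="full")
print(jnp.allclose(result, direct, atol=1e-4))  # True
```

The FFT route costs O(N log N) instead of the O(N·M) of a direct sliding sum, which is why it pays off for long inputs or kernels.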

Parameters:
• in1 (jax.typing.ArrayLike) – left-hand input to the convolution.

• in2 (jax.typing.ArrayLike) – right-hand input to the convolution. Must have `in1.ndim == in2.ndim`.

• mode (str) – controls the size of the output. Available modes are:

• `"full"`: (default) output the full convolution of the inputs.

• `"same"`: return a centered portion of the `"full"` output which is the same size as `in1`.

• `"valid"`: return the portion of the `"full"` output that does not depend on padding at the array edges.

• axes (Sequence[int] | None) – optional sequence of axes along which to apply the convolution.
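For 1D inputs of lengths `n` (for `in1`) and `m` (for `in2`), the three modes correspond to simple length formulas. The helper `expected_length` below is hypothetical. Its `"valid"` branch assumes `n >= m`, i.e. that `in1` is at least as long as `in2`. The sketch checks the formulas against the shapes `fftconvolve` returns:

```python
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve


def expected_length(n, m, mode):
    """Hypothetical helper: output length along one convolved axis.

    n = len(in1), m = len(in2); the 'valid' branch assumes n >= m.
    """
    if mode == "full":
        return n + m - 1  # every shift at which the inputs overlap at all
    if mode == "same":
        return n          # matches the length of in1
    if mode == "valid":
        return n - m + 1  # shifts at which in2 lies entirely inside in1
    raise ValueError(f"unknown mode: {mode!r}")


x = jnp.ones(8)
y = jnp.ones(3)
for mode in ("full", "same", "valid"):
    out = fftconvolve(x, y, mode=mode)
    print(mode, out.shape, expected_length(8, 3, mode))
# full (10,) 10
# same (8,) 8
# valid (6,) 6

# 'same' follows in1, even when in1 is the shorter input.
print(fftconvolve(y, x, mode="same").shape)  # (3,)
```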

Returns:

Array containing the convolved result.

Return type:

Array

Examples

A few 1D convolution examples. Because FFT-based convolution is subject to floating-point round-off error, we use `jax.numpy.printoptions()` below to limit the printed precision:

```
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([1, 1, 1])
```

Full convolution uses implicit zero-padding at the edges:

```
>>> with jax.numpy.printoptions(precision=3):
...   print(jax.scipy.signal.fftconvolve(x, y, mode='full'))
[1. 3. 6. 7. 6. 3. 1.]
```
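To make the implicit zero-padding concrete, the full output can be reproduced directly. Pad `x` with `len(y) - 1` zeros on each side, flip the kernel, and take a dot product at every shift. This explicit loop is for illustration only (it is O(N·M), unlike the FFT approach). The comparison uses a tolerance because the FFT result carries small round-off error:

```python
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve

x = jnp.array([1., 2., 3., 2., 1.])
y = jnp.array([1., 1., 1.])

m = y.shape[0]
# Zero-pad x by len(y) - 1 on both sides, as 'full' mode implicitly does.
padded = jnp.pad(x, (m - 1, m - 1))
# Convolution flips the kernel, then slides it across the padded input.
flipped = y[::-1]
direct = jnp.array([
    jnp.dot(padded[i:i + m], flipped) for i in range(padded.shape[0] - m + 1)
])
print(direct)  # [1. 3. 6. 7. 6. 3. 1.]

fft_result = fftconvolve(x, y, mode="full")
print(jnp.allclose(direct, fft_result, atol=1e-4))  # True
```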

Specifying `mode = 'same'` returns a centered convolution the same size as the first input:

```
>>> with jax.numpy.printoptions(precision=3):
...   print(jax.scipy.signal.fftconvolve(x, y, mode='same'))
[3. 6. 7. 6. 3.]
```
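"Centered" here means a slice of the `"full"` output. The sketch below assumes the centering rule SciPy uses: start at `(len(full) - len(in1)) // 2`. With these inputs the start is unambiguous (index 1), so the check does not depend on how odd length differences are rounded:

```python
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([1, 1, 1])

full = fftconvolve(x, y, mode="full")  # length 5 + 3 - 1 = 7
n_out = x.shape[0]                     # 'same' keeps the length of in1
# Assumed centering rule (SciPy's): drop equal amounts from each end.
start = (full.shape[0] - n_out) // 2   # 1 for these inputs
centered = full[start:start + n_out]

same = fftconvolve(x, y, mode="same")
print(jnp.allclose(centered, same, atol=1e-4))  # True
```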

Specifying `mode = 'valid'` returns only the portion where the two arrays fully overlap:

```
>>> with jax.numpy.printoptions(precision=3):
...   print(jax.scipy.signal.fftconvolve(x, y, mode='valid'))
[6. 7. 6.]
```
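None of the examples above use the `axes` parameter. The following sketch performs a batched 1D convolution by convolving only along axis 1. It assumes, as in SciPy, that axes not listed in `axes` are paired element-wise between the two inputs, so row `i` of `in1` is convolved with row `i` of `in2`. Both inputs are given the same size on the non-convolved axis, because broadcasting over that axis may not be supported. The arrays are made up for illustration:

```python
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve

# Two 1D signals stored as rows (axis 0 = batch, axis 1 = samples).
signals = jnp.array([[1., 2., 3., 2., 1.],
                     [0., 1., 0., 1., 0.]])
kernels = jnp.array([[1., 1., 1.],
                     [1., -1., 0.]])

# Convolve along axis 1 only; axis 0 is treated as a batch dimension.
out = fftconvolve(signals, kernels, mode="full", axes=(1,))
print(out.shape)  # (2, 7)

# Row-by-row reference using direct convolution.
reference = jnp.stack([jnp.convolve(s, k, mode="full")
                       for s, k in zip(signals, kernels)])
print(jnp.allclose(out, reference, atol=1e-4))  # True
```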