jax.scipy.signal.correlate2d
- jax.scipy.signal.correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0, precision=None)
Cross-correlation of two 2-dimensional arrays.
JAX implementation of scipy.signal.correlate2d().
- Parameters:
in1 (Array) – left-hand input to the cross-correlation. Must have in1.ndim == 2.
in2 (Array) – right-hand input to the cross-correlation. Must have in2.ndim == 2.
mode (str) –
controls the size of the output. Available options are:
"full": (default) output the full cross-correlation of the inputs.
"same": return a centered portion of the "full" output that is the same size as in1.
"valid": return the portion of the "full" output that does not depend on padding at the array edges.
boundary (str) – only "fill" is supported.
fillvalue (float) – only 0 is supported.
method –
controls the computation method. Options are:
"auto": (default) always uses the "direct" method.
"direct": lowers to jax.lax.conv_general_dilated().
"fft": computes the result via a fast Fourier transform.
precision (PrecisionLike | None) – Specify the precision of the computation. Refer to jax.lax.Precision for a description of available values.
- Returns:
Array containing the cross-correlation result.
- Return type:
Array
See also
jax.numpy.correlate(): 1D cross-correlation
jax.scipy.signal.correlate(): ND cross-correlation
jax.scipy.signal.convolve(): ND convolution
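The following is a minimal sketch of how these related routines connect for 2D inputs in "full" mode. It assumes real-valued inputs, and the array values are illustrative only. It compares correlate2d against the ND routine and against jax.scipy.signal.convolve2d() applied to a kernel reversed along both axes.

```python
import jax.numpy as jnp
from jax.scipy import signal

# Illustrative real-valued inputs (assumed for this sketch).
x = jnp.array([[2, 1, 3],
               [1, 3, 1],
               [4, 1, 2]], dtype=jnp.float32)
y = jnp.array([[1, 3],
               [4, 2]], dtype=jnp.float32)

corr2d = signal.correlate2d(x, y, mode='full')

# The N-dimensional routine also accepts 2D inputs.
corr_nd = signal.correlate(x, y, mode='full')

# For real inputs, cross-correlation equals convolution with the
# second input reversed along every axis.
conv_flipped = signal.convolve2d(x, jnp.flip(y), mode='full')

print(bool(jnp.allclose(corr2d, corr_nd)))
print(bool(jnp.allclose(corr2d, conv_flipped)))
```

Only "full" mode is compared here; the sketch makes no claim about how the "same" window is chosen by each routine.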
Examples
A few 2D correlation examples:
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[2, 1, 3],
...                [1, 3, 1],
...                [4, 1, 2]])
>>> y = jnp.array([[1, 3],
...                [4, 2]])
Full 2D correlation uses implicit zero-padding at the edges:
>>> jax.scipy.signal.correlate2d(x, y, mode='full')
Array([[ 4., 10., 10., 12.],
       [ 8., 15., 24.,  7.],
       [11., 28., 14.,  9.],
       [12.,  7.,  7.,  2.]], dtype=float32)
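To make the implicit zero-padding concrete, here is a rough reference sketch that pads the first input explicitly and slides the second input over it without flipping. The helper manual_full_correlate2d is hypothetical and written only for illustration; it is not how the library computes the result.

```python
import jax.numpy as jnp
from jax.scipy.signal import correlate2d

x = jnp.array([[2, 1, 3],
               [1, 3, 1],
               [4, 1, 2]], dtype=jnp.float32)
y = jnp.array([[1, 3],
               [4, 2]], dtype=jnp.float32)

def manual_full_correlate2d(a, k):
    """Hypothetical reference loop: zero-pad `a`, then slide `k` over it."""
    kh, kw = k.shape
    # Pad each side by (kernel size - 1) so every partial overlap is covered.
    padded = jnp.pad(a, ((kh - 1, kh - 1), (kw - 1, kw - 1)))
    out_h = a.shape[0] + kh - 1
    out_w = a.shape[1] + kw - 1
    rows = []
    for i in range(out_h):
        # Elementwise product of the current window with the unflipped kernel.
        row = [jnp.sum(padded[i:i + kh, j:j + kw] * k) for j in range(out_w)]
        rows.append(jnp.stack(row))
    return jnp.stack(rows)

manual = manual_full_correlate2d(x, y)
library = correlate2d(x, y, mode='full')
print(bool(jnp.allclose(manual, library)))
```

The Python loop is meant for reading rather than speed; it would be slow for large arrays.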
Specifying mode='same' returns a centered 2D correlation of the same size as the first input:
>>> jax.scipy.signal.correlate2d(x, y, mode='same')
Array([[15., 24.,  7.],
       [28., 14.,  9.],
       [ 7.,  7.,  2.]], dtype=float32)
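As a small check of the "same" behavior, the sketch below confirms that the output shape matches the first input and, for these particular inputs, locates the "same" result inside the "full" result. The slice used is read off the outputs shown in this section and is an assumption specific to these shapes, not a general rule.

```python
import jax.numpy as jnp
from jax.scipy.signal import correlate2d

x = jnp.array([[2, 1, 3],
               [1, 3, 1],
               [4, 1, 2]], dtype=jnp.float32)
y = jnp.array([[1, 3],
               [4, 2]], dtype=jnp.float32)

full = correlate2d(x, y, mode='full')
same = correlate2d(x, y, mode='same')

# The "same" output has the shape of the first input.
shape_matches = same.shape == x.shape

# For these inputs only: the documented "same" output corresponds to
# the "full" output with its first row and first column dropped.
same_is_window = bool(jnp.array_equal(same, full[1:, 1:]))

print(shape_matches, same_is_window)
```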
Specifying mode='valid' returns only the portion of the 2D correlation where the two arrays fully overlap:
>>> jax.scipy.signal.correlate2d(x, y, mode='valid')
Array([[15., 24.],
       [28., 14.]], dtype=float32)
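The sketch below illustrates the "valid" output shape and shows one way to pass the precision argument and wrap the call in jax.jit. The choice of jax.lax.Precision.HIGHEST and the jitted wrapper name are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import correlate2d

x = jnp.array([[2, 1, 3],
               [1, 3, 1],
               [4, 1, 2]], dtype=jnp.float32)
y = jnp.array([[1, 3],
               [4, 2]], dtype=jnp.float32)

valid = correlate2d(x, y, mode='valid')

# When in1 is the larger array, each output axis has length
# in1_size - in2_size + 1 (only fully overlapping positions).
expected_shape = (x.shape[0] - y.shape[0] + 1,
                  x.shape[1] - y.shape[1] + 1)

# Request the highest available precision for the computation.
valid_highest = correlate2d(x, y, mode='valid',
                            precision=jax.lax.Precision.HIGHEST)

# Hypothetical jitted wrapper; mode is fixed inside the closure.
valid_jit = jax.jit(lambda a, b: correlate2d(a, b, mode='valid'))
valid_compiled = valid_jit(x, y)

print(valid.shape == expected_shape)
```

With small integer-valued inputs like these, the precision setting is not expected to change the result; it matters mainly for larger floating-point inputs on accelerators.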