# jax.numpy.correlate

jax.numpy.correlate(a, v, mode='valid', *, precision=None, preferred_element_type=None)

Correlation of two one-dimensional arrays.

JAX implementation of numpy.correlate().

Correlation of one-dimensional arrays is defined as:

$c_k = \sum_j a_{k + j} \overline{v_j}$

where $\overline{v_j}$ is the complex conjugate of $v_j$ and, for the "full" and "same" modes, entries of $a$ outside its index range are treated as zero.
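To make the definition concrete, the sketch below evaluates the formula directly with a loop. It is a minimal illustration, not how JAX computes the result (JAX lowers the operation to a convolution primitive), and the helper name `correlate_full_reference` is hypothetical. It visits every lag $k$ from $-(M-1)$ to $N-1$, where $N$ and $M$ are the lengths of `a` and `v`, and treats out-of-range entries of `a` as zero, which reproduces the ordering of `mode='full'`.

```python
import jax.numpy as jnp

def correlate_full_reference(a, v):
    """Hypothetical helper (not part of JAX): evaluate c_k = sum_j a[k + j] * conj(v[j])
    for every lag k from -(M - 1) to N - 1, matching the ordering of mode='full'."""
    a = jnp.asarray(a)
    v = jnp.asarray(v)
    n, m = a.shape[0], v.shape[0]
    out = []
    for k in range(-(m - 1), n):
        total = 0
        for j in range(m):
            # Indices that fall outside a correspond to zero padding, so skip them.
            if 0 <= k + j < n:
                total = total + a[k + j] * jnp.conj(v[j])
        out.append(total)
    return jnp.array(out)

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 5, 6])
ref = correlate_full_reference(x, y)
print(ref)  # values 6, 17, 32, 35, 28, 13, 4
print(jnp.allclose(ref, jnp.correlate(x, y, mode='full')))  # True
```

The double loop is quadratic and slow under JAX's eager dispatch; it is meant only as a readable check of the formula on small inputs.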

Parameters:
• a (ArrayLike) – left-hand input to the correlation. Must have a.ndim == 1.

• v (ArrayLike) – right-hand input to the correlation. Must have v.ndim == 1.

• mode (str) –

controls the size of the output. Available options are:

• "full": output the full correlation of the inputs.

• "same": return a centered portion of the "full" output with the same size as the longer input (the size of a when len(a) >= len(v)).

• "valid": (default) return only the portion of the "full" output that does not depend on padding at the array edges.

• precision (PrecisionLike) – specifies the precision of the computation. Refer to jax.lax.Precision for a description of available values.

• preferred_element_type (DTypeLike | None) – the datatype in which to accumulate results and return the output. Default is None, which means the default accumulation type for the input types.
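As a quick check of the `mode` options, the sketch below prints output shapes for inputs of lengths $N = 5$ and $M = 3$. Following NumPy's conventions, which `jax.numpy.correlate` mirrors, "full" yields $N + M - 1$ values, "same" yields $\max(N, M)$, and "valid" yields $\max(N, M) - \min(N, M) + 1$. The `precision` and `preferred_element_type` values at the end are illustrative choices only.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 2.0, 1.0])  # N = 5
y = jnp.array([4.0, 5.0, 6.0])            # M = 3

for mode in ('full', 'same', 'valid'):
    print(mode, jnp.correlate(x, y, mode=mode).shape)
# full (7,)
# same (5,)
# valid (3,)

# 'same' follows the longer input even when it is passed second.
print(jnp.correlate(y, x, mode='same').shape)  # (5,)

# Illustrative settings: highest available precision, float32 accumulation.
out = jnp.correlate(x, y, mode='full',
                    precision=jax.lax.Precision.HIGHEST,
                    preferred_element_type=jnp.float32)
print(out.dtype)  # float32
```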

Returns:

Array containing the cross-correlation result.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 5, 6])


Since the default is mode='valid', jax.numpy.correlate returns only the portion of the correlation where the two arrays fully overlap:

>>> jnp.correlate(x, y)
Array([32., 35., 28.], dtype=float32)
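For real-valued inputs, each "valid" entry is the dot product of `y` with one window of `x` that `y` fully overlaps. The sketch below rebuilds those windows by hand (the names `windows` and `manual` are just illustrative) and compares the result with the default mode.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 5, 6])

# Stack every length-M window of x that y fully overlaps, then dot each with y.
m = y.shape[0]
windows = jnp.stack([x[k:k + m] for k in range(x.shape[0] - m + 1)])
manual = windows @ y
print(manual)  # values 32, 35, 28
print(jnp.allclose(manual, jnp.correlate(x, y)))  # True
```

For complex `y` the windows would need to be dotted with `jnp.conj(y)` instead, per the definition.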


Specifying mode='full' returns the full correlation, using implicit zero-padding at the edges:

>>> jnp.correlate(x, y, mode='full')
Array([ 6., 17., 32., 35., 28., 13.,  4.], dtype=float32)
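The implicit zero-padding can be made explicit: padding `x` with `len(y) - 1` zeros on each side and then taking the "valid" correlation should reproduce the "full" result. A minimal sketch using `jnp.pad`:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 5, 6])

# Pad both ends of x with len(y) - 1 zeros, then keep only fully overlapping windows.
padded = jnp.pad(x, len(y) - 1)
print(padded)  # values 0, 0, 1, 2, 3, 2, 1, 0, 0
full_via_padding = jnp.correlate(padded, y, mode='valid')
print(jnp.allclose(full_via_padding, jnp.correlate(x, y, mode='full')))  # True
```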


Specifying mode='same' returns a centered correlation the same size as the first input:

>>> jnp.correlate(x, y, mode='same')
Array([17., 32., 35., 28., 13.], dtype=float32)
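Because "same" is a centered slice of "full", it can be recovered by trimming the full output. This sketch assumes `len(x) >= len(y)` and an odd `len(y)`, in which case $(M - 1)/2$ entries are dropped from each end; the centering convention for even-length `y` is not covered here.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 5, 6])

full = jnp.correlate(x, y, mode='full')
# Assumes odd len(y): drop (len(y) - 1) // 2 entries from each end of the full output.
start = (len(y) - 1) // 2
same_via_slice = full[start:start + len(x)]
print(same_via_slice)  # values 17, 32, 35, 28, 13
print(jnp.allclose(same_via_slice, jnp.correlate(x, y, mode='same')))  # True
```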


If both input arrays are real-valued and symmetric, the result is also symmetric and equal to the result of jax.numpy.convolve:

>>> x1 = jnp.array([1, 2, 3, 2, 1])
>>> y1 = jnp.array([4, 5, 4])
>>> jnp.correlate(x1, y1, mode='full')
Array([ 4., 13., 26., 31., 26., 13.,  4.], dtype=float32)
>>> jnp.convolve(x1, y1, mode='full')
Array([ 4., 13., 26., 31., 26., 13.,  4.], dtype=float32)
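The symmetric case is a special instance of a more general identity that follows from the definition: the "full" correlation of `a` with `v` equals the "full" convolution of `a` with the reversed, conjugated `v`. When `v` is real and symmetric, reversing and conjugating leave it unchanged, which is why `jnp.convolve` gives the same answer above. A sketch checking the identity on non-symmetric and complex inputs:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 5, 6])          # real but not symmetric
x2 = jnp.array([3+1j, 2, 2-3j])
y2 = jnp.array([4, 2-5j, 1])      # complex

# Reversing and conjugating v undoes the flip convolution applies to its second
# argument and supplies the conjugate that correlation takes.
for a, v in ((x, y), (x2, y2)):
    corr = jnp.correlate(a, v, mode='full')
    conv = jnp.convolve(a, jnp.conj(v[::-1]), mode='full')
    print(jnp.allclose(corr, conv))  # True
```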


For complex-valued inputs:

>>> x2 = jnp.array([3+1j, 2, 2-3j])
>>> y2 = jnp.array([4, 2-5j, 1])
>>> jnp.correlate(x2, y2, mode='full')
Array([ 3. +1.j,  3.+17.j, 18.+11.j, 27. +4.j,  8.-12.j], dtype=complex64)
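Two consequences of the conjugation in the definition can be checked on the same complex inputs. With equal lengths, "valid" keeps only the zero-lag term $c_0 = \sum_j a_j \overline{v_j}$, which `jnp.vdot` also computes because it conjugates its first argument. Swapping the arguments conjugates the "full" output and reverses its order. A minimal sketch:

```python
import jax.numpy as jnp

x2 = jnp.array([3+1j, 2, 2-3j])
y2 = jnp.array([4, 2-5j, 1])

# Equal lengths: 'valid' returns a single element, the zero-lag term c_0.
c0 = jnp.correlate(x2, y2, mode='valid')
print(c0)  # one element, 18+11j
# jnp.vdot conjugates its first argument, so vdot(y2, x2) = sum_j x2[j] * conj(y2[j]).
print(jnp.allclose(c0[0], jnp.vdot(y2, x2)))  # True

# Swapping the inputs gives the conjugated, reversed full correlation.
swapped = jnp.correlate(y2, x2, mode='full')
print(jnp.allclose(swapped, jnp.conj(jnp.correlate(x2, y2, mode='full'))[::-1]))  # True
```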