jax.numpy.corrcoef
- jax.numpy.corrcoef(x, y=None, rowvar=True)
Compute the Pearson correlation coefficients.
JAX implementation of numpy.corrcoef().
This is a normalized version of the sample covariance computed by jax.numpy.cov(). For a sample covariance \(C_{ij}\), the correlation coefficients are
\[R_{ij} = \frac{C_{ij}}{\sqrt{C_{ii}C_{jj}}}\]
They are constructed such that the values satisfy \(-1 \le R_{ij} \le 1\).
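As a rough illustration of the normalization above, the sketch below rebuilds the correlation matrix by hand from jax.numpy.cov() and compares it with jnp.corrcoef(). The input values and the names C, d, and R_manual are made up for this example; this is not a description of how the library computes the result internally.

```python
import jax.numpy as jnp

# Illustrative data: three variables (rows), five observations each (columns).
x = jnp.array([[1.0, 2.0, 4.0, 7.0, 11.0],
               [2.0, 1.0, 5.0, 3.0, 8.0],
               [9.0, 7.0, 6.0, 2.0, 1.0]])

C = jnp.cov(x)                  # sample covariance, shape (3, 3)
d = jnp.sqrt(jnp.diag(C))       # sqrt(C_ii) for each variable
R_manual = C / jnp.outer(d, d)  # R_ij = C_ij / sqrt(C_ii * C_jj)

R = jnp.corrcoef(x)
# Compare with a small tolerance, since the two paths may round differently.
print(jnp.allclose(R_manual, R, atol=1e-5))
```

The comparison uses a tolerance rather than exact equality because float32 rounding can differ slightly between the two computations.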
- Parameters:
x (ArrayLike) – array of shape (M, N) (if rowvar is True), or (N, M) (if rowvar is False) representing N observations of M variables. x may also be one-dimensional, representing N observations of a single variable.
y (ArrayLike | None) – optional set of additional observations, with the same form as x. If specified, then y is combined with x, i.e. for the default rowvar = True case, x becomes jnp.vstack([x, y]).
rowvar (bool) – if True (default) then each row of x represents a variable. If False, then each column represents a variable.
- Returns:
A correlation matrix of shape (M, M).
- Return type:
Array
See also
jax.numpy.cov(): compute the covariance matrix.
Examples
Consider these observations of two variables that correlate perfectly. The correlation matrix in this case is a 2x2 matrix of ones:
>>> x = jnp.array([[0, 1, 2],
...                [0, 1, 2]])
>>> jnp.corrcoef(x)
Array([[1., 1.],
       [1., 1.]], dtype=float32)
Now consider these observations of two variables that are perfectly anti-correlated. The correlation matrix in this case has -1 in the off-diagonal:
>>> x = jnp.array([[-1,  0,  1],
...                [ 1,  0, -1]])
>>> jnp.corrcoef(x)
Array([[ 1., -1.],
       [-1.,  1.]], dtype=float32)
Equivalently, these sequences can be specified as separate arguments, in which case they are stacked before continuing the computation.
>>> x = jnp.array([-1, 0, 1])
>>> y = jnp.array([1, 0, -1])
>>> jnp.corrcoef(x, y)
Array([[ 1., -1.],
       [-1.,  1.]], dtype=float32)
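The stacking of y and the effect of rowvar can also be checked numerically. The following is a minimal sketch with made-up data; the variable names (R_xy, R_stacked, R_cols, R_rows) are only for illustration.

```python
import jax.numpy as jnp

# Illustrative data: x holds 2 variables with 4 observations each;
# y holds one additional variable with the same 4 observations.
x = jnp.array([[1.0, 2.0, 4.0, 7.0],
               [2.0, 1.0, 5.0, 3.0]])
y = jnp.array([9.0, 7.0, 6.0, 2.0])

# Passing y separately should match stacking it under x by hand.
R_xy = jnp.corrcoef(x, y)
R_stacked = jnp.corrcoef(jnp.vstack([x, y]))
print(R_xy.shape)
print(jnp.allclose(R_xy, R_stacked, atol=1e-5))

# With rowvar=False each column is a variable, so the transposed
# array should describe the same data as x with the default rowvar=True.
R_rows = jnp.corrcoef(x)
R_cols = jnp.corrcoef(x.T, rowvar=False)
print(jnp.allclose(R_rows, R_cols, atol=1e-5))
```

Note that with y supplied, the result covers the variables of x and y together, so here it has one row and column per stacked variable.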
The entries of the correlation matrix are normalized such that they lie within the range -1 to +1, where +1 indicates perfect correlation and -1 indicates perfect anti-correlation. For example, here is the correlation of 100 points drawn from a 3-dimensional standard normal distribution:
>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, shape=(3, 100))
>>> with jnp.printoptions(precision=2):
...   print(jnp.corrcoef(x))
[[1.   0.03 0.12]
 [0.03 1.   0.01]
 [0.12 0.01 1.  ]]
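The structural properties described above (a square matrix with a unit diagonal, symmetric, and entries between -1 and +1) can be verified for the same random sample. This is a sketch only; the tolerances are assumptions chosen to absorb float32 rounding.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jax.random.normal(key, shape=(3, 100))  # 3 variables, 100 observations
R = jnp.corrcoef(x)

is_square = R.shape == (3, 3)
# Each variable correlates perfectly with itself.
unit_diagonal = bool(jnp.allclose(jnp.diag(R), 1.0, atol=1e-5))
# R_ij and R_ji describe the same pair of variables.
symmetric = bool(jnp.allclose(R, R.T, atol=1e-5))
# All entries should lie in [-1, 1]; a tiny slack guards against rounding.
bounded = bool(jnp.all(jnp.abs(R) <= 1.0 + 1e-6))

print(is_square, unit_diagonal, symmetric, bounded)
```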