jax.numpy.linalg.eig

jax.numpy.linalg.eig(a)

Computes the eigenvalues and eigenvectors of a square array.

JAX implementation of numpy.linalg.eig().

Parameters:

a (jax.typing.ArrayLike) – array of shape (..., M, M) for which to compute the eigenvalues and eigenvectors.

Returns:

A tuple (eigenvalues, eigenvectors) with

  • eigenvalues: an array w of shape (..., M) containing the eigenvalues.

  • eigenvectors: an array v of shape (..., M, M) whose columns are the eigenvectors: column v[..., :, i] is the eigenvector corresponding to the eigenvalue w[..., i].

Return type:

tuple[Array, Array]
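As a sketch of how the return shapes and the column convention fit together, the snippet below decomposes a small batch of matrices and checks that A @ v matches v with column i scaled by w[..., i]. The matrix values are arbitrary illustrative choices, and the snippet assumes the call runs on the CPU backend (see the note on backend support).

```python
import jax.numpy as jnp

# A batch of two 3x3 matrices: batch shape (2,) and M = 3, i.e. shape (..., M, M).
a = jnp.array([
    [[2., 0., 0.],
     [0., 3., 4.],
     [0., 4., 9.]],
    [[1., 2., 0.],
     [0., 3., 1.],
     [0., 0., 5.]],
])

w, v = jnp.linalg.eig(a)

# Eigenvalues have shape (..., M); eigenvectors have shape (..., M, M).
print(w.shape, v.shape)  # (2, 3) (2, 3, 3)

# Column i of v pairs with entry i of w, so A @ v should equal v with each
# column scaled by its eigenvalue. w[..., None, :] broadcasts w across rows.
lhs = a.astype(w.dtype) @ v
rhs = v * w[..., None, :]
print(jnp.allclose(lhs, rhs, atol=1e-4))  # True
```

Checking the whole batch with one broadcasted comparison avoids looping over batch entries and columns in Python.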

Notes

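  • This differs from numpy.linalg.eig() in that the arrays returned by jax.numpy.linalg.eig() are always complex: complex64 for 32-bit input and complex128 for 64-bit input. For real input whose eigenvalues are all real, NumPy instead returns real-valued arrays.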

  • At present, non-symmetric eigendecomposition is only implemented on the CPU backend.
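To illustrate the dtype difference described in the notes, the sketch below runs the same float32 matrix through NumPy and JAX. It assumes the default JAX configuration; obtaining complex128 output from 64-bit input additionally requires enabling the jax_enable_x64 flag, which this snippet does not do.

```python
import numpy as np
import jax.numpy as jnp

a32 = np.array([[1., 2.],
                [2., 1.]], dtype=np.float32)

# NumPy returns a real-valued array when every eigenvalue of real input is real.
w_np, _ = np.linalg.eig(a32)
print(np.isrealobj(w_np))  # True

# JAX always returns complex output: complex64 for float32 input.
w, v = jnp.linalg.eig(jnp.asarray(a32))
print(w.dtype, v.dtype)  # complex64 complex64

# If the spectrum is known to be real, the (numerically zero) imaginary part
# can be dropped explicitly after checking it.
if jnp.allclose(jnp.imag(w), 0.0):
    w_real = jnp.real(w)
    print(w_real.dtype)  # float32
```

Keeping the output dtype fixed means the result type does not depend on the values of the input, which is what allows the shape and dtype of the output to be known ahead of time.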
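Given the CPU-only restriction, one possible pattern, sketched here as an assumption rather than a prescribed workflow, is to pin the call to a CPU device with jax.default_device and jax.device_put so that it also works in a process whose default backend is an accelerator. For symmetric or Hermitian input, jax.numpy.linalg.eigh is the specialized routine and returns real eigenvalues in ascending order; the comparison below checks that both routines agree on the spectrum.

```python
import jax
import jax.numpy as jnp
import numpy as np

a = jnp.array([[1., 2.],
               [2., 1.]])

# Run the general (non-symmetric) eigendecomposition on the first CPU device.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    w, v = jnp.linalg.eig(jax.device_put(a, cpu))

# For symmetric input, eigh returns real eigenvalues sorted in ascending order.
w_h, v_h = jnp.linalg.eigh(a)

# Both routines agree once eig's eigenvalues are made real and sorted.
eig_sorted = np.sort(np.real(np.asarray(w)))
print(np.allclose(eig_sorted, np.asarray(w_h)))  # True
```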

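See also

  • jax.numpy.linalg.eigh(): eigenvalues and eigenvectors of a Hermitian (or real symmetric) matrix.

  • jax.numpy.linalg.eigvals(): computes only the eigenvalues of a general matrix.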

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[1., 2.],
...                [2., 1.]])
>>> w, v = jnp.linalg.eig(a)
>>> with jax.numpy.printoptions(precision=4):
...   w
Array([ 3.+0.j, -1.+0.j], dtype=complex64)
>>> v
Array([[ 0.70710677+0.j, -0.70710677+0.j],
       [ 0.70710677+0.j,  0.70710677+0.j]], dtype=complex64)
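The example above has only real eigenvalues. As a further sketch, a real non-symmetric matrix such as a 90-degree rotation (an illustrative choice) has a complex-conjugate pair of eigenvalues, which is one case where the always-complex output is necessary. The snippet sorts by imaginary part rather than assuming a particular ordering from the backend.

```python
import jax.numpy as jnp

# A 90-degree rotation: real and non-symmetric, with no real eigenvalues.
r = jnp.array([[0., -1.],
               [1.,  0.]])
w, v = jnp.linalg.eig(r)

# The eigenvalues are the conjugate pair -1j and +1j; sort by imaginary part
# so the comparison does not depend on the order the backend returns.
w_sorted = w[jnp.argsort(jnp.imag(w))]
print(jnp.allclose(w_sorted, jnp.array([-1j, 1j]), atol=1e-5))  # True

# Each column still satisfies r @ v[:, i] == w[i] * v[:, i].
for i in range(2):
    print(jnp.allclose(r @ v[:, i], w[i] * v[:, i], atol=1e-5))
```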