jax.numpy.linalg.eigh

jax.numpy.linalg.eigh(a, UPLO=None, symmetrize_input=True)

Computes the eigenvalues and eigenvectors of a Hermitian matrix.

JAX implementation of numpy.linalg.eigh().
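As a quick check of the correspondence with NumPy, the sketch below decomposes the same small real symmetric matrix with both numpy.linalg.eigh and jnp.linalg.eigh. It assumes both libraries are installed. The matrix is an arbitrary choice with distinct eigenvalues. Agreement is only expected up to float32 precision, because JAX uses 32-bit floats by default.

```python
import numpy as np
import jax.numpy as jnp

# A small real symmetric matrix with distinct eigenvalues (3 - sqrt(3), 3, 3 + sqrt(3)).
a = np.array([[2.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 4.0]])

w_np, v_np = np.linalg.eigh(a)     # NumPy reference (float64)
w_jax, v_jax = jnp.linalg.eigh(a)  # JAX result (float32 by default)

# Both libraries return eigenvalues in ascending order.
print(np.allclose(w_np, np.asarray(w_jax), atol=1e-4))

# Each eigenvector is only defined up to sign, so compare absolute values.
print(np.allclose(np.abs(v_np), np.abs(np.asarray(v_jax)), atol=1e-4))
```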

Parameters:
  • a (jax.typing.ArrayLike) – array of shape (..., M, M), containing the Hermitian (if complex) or symmetric (if real) matrix.

  • UPLO (str | None) – specifies whether the calculation is done with the lower triangular part of a ('L', default) or the upper triangular part ('U').

  • symmetrize_input (bool) – if True (default), the input is symmetrized by replacing a with (a + a^H) / 2, where a^H is the conjugate transpose, before the decomposition. This leads to better behavior under automatic differentiation.
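The interplay of UPLO and symmetrize_input can be illustrated with a deliberately non-symmetric input. With the default symmetrize_input=True, the decomposition is of the symmetrized matrix, so UPLO should make no difference. With symmetrize_input=False, only the triangle selected by UPLO should be read. That second behavior is assumed here for the LAPACK-backed CPU path and is checked against numpy.linalg.eigh, which documents the same UPLO semantics. Treat this as a sketch, not a guarantee for every backend.

```python
import numpy as np
import jax.numpy as jnp

# Deliberately non-symmetric: the lower triangle holds 2, the upper holds 0.
a = np.array([[4.0, 0.0],
              [2.0, 1.0]], dtype=np.float32)

# Default symmetrize_input=True: the matrix decomposed is (a + a.T) / 2,
# so choosing UPLO="L" or UPLO="U" should give the same eigenvalues.
w_ref = np.linalg.eigh((a + a.T) / 2)[0]
w_lower = jnp.linalg.eigh(a).eigenvalues
w_upper = jnp.linalg.eigh(a, UPLO="U").eigenvalues
print(np.allclose(w_lower, w_ref, atol=1e-5), np.allclose(w_upper, w_ref, atol=1e-5))

# symmetrize_input=False: only the chosen triangle is read (assumption: CPU/LAPACK),
# i.e. [[4, 2], [2, 1]] for UPLO="L" and [[4, 0], [0, 1]] for UPLO="U".
w_l_raw = jnp.linalg.eigh(a, UPLO="L", symmetrize_input=False).eigenvalues
w_u_raw = jnp.linalg.eigh(a, UPLO="U", symmetrize_input=False).eigenvalues
ref_l = np.linalg.eigh(a, UPLO="L")[0]
ref_u = np.linalg.eigh(a, UPLO="U")[0]
print(np.allclose(w_l_raw, ref_l, atol=1e-5), np.allclose(w_u_raw, ref_u, atol=1e-5))
```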

Returns:

A namedtuple (eigenvalues, eigenvectors) where

  • eigenvalues: an array of shape (..., M) containing the eigenvalues, sorted in ascending order.

  • eigenvectors: an array of shape (..., M, M), where column eigenvectors[..., :, i] is the normalized eigenvector corresponding to the eigenvalue eigenvalues[..., i].
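The sketch below checks the return-value properties listed above on a batch of random complex Hermitian matrices. It verifies the shapes, the ascending order of the eigenvalues, the pairing of each eigenvector column with its eigenvalue, and orthonormality. The batch size, matrix size, seed, and random construction are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# A batch of 4 random complex Hermitian 3x3 matrices: h = (x + x^H) / 2.
key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
x = (jax.random.normal(key_re, (4, 3, 3))
     + 1j * jax.random.normal(key_im, (4, 3, 3)))
h = (x + jnp.conj(jnp.swapaxes(x, -1, -2))) / 2

result = jnp.linalg.eigh(h)  # EighResult namedtuple
w, v = result.eigenvalues, result.eigenvectors
print(w.shape, v.shape)  # (4, 3) (4, 3, 3)

# Column i of v pairs with w[..., i]: h @ v == v * w, with w broadcast across rows.
print(jnp.allclose(h @ v, v * w[..., None, :], atol=1e-4))

# Eigenvalues are sorted in ascending order along the last axis.
print(bool(jnp.all(jnp.diff(w, axis=-1) >= 0)))

# Eigenvectors are orthonormal: v^H v == I for every matrix in the batch.
vh_v = jnp.conj(jnp.swapaxes(v, -1, -2)) @ v
print(jnp.allclose(vh_v, jnp.eye(3), atol=1e-4))
```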

Return type:

EighResult

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, -2j],
...                [2j, 1]])
>>> w, v = jnp.linalg.eigh(a)
>>> w
Array([-1.,  3.], dtype=float32)
>>> with jnp.printoptions(precision=3):
...   v
Array([[-0.707+0.j   , -0.707+0.j   ],
       [ 0.   +0.707j,  0.   -0.707j]], dtype=complex64)
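Because eigh is differentiable, it composes with jax.grad. For a real symmetric matrix whose largest eigenvalue is simple (not repeated), first-order perturbation theory gives the gradient of that eigenvalue with respect to the matrix entries as the outer product v v^T of its unit eigenvector. The sketch below checks this numerically. The matrix is an arbitrary choice with distinct eigenvalues, and the helper largest_eigenvalue is hypothetical, defined only for this example.

```python
import jax
import jax.numpy as jnp

# Real symmetric matrix with distinct eigenvalues (3 - sqrt(3), 3, 3 + sqrt(3)).
a = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])

def largest_eigenvalue(m):
    # Hypothetical helper: eigenvalues are ascending, so the last one is the largest.
    return jnp.linalg.eigh(m).eigenvalues[-1]

g = jax.grad(largest_eigenvalue)(a)

# Compare with the outer product of the matching unit eigenvector
# (the sign ambiguity of the eigenvector cancels in the outer product).
v_max = jnp.linalg.eigh(a).eigenvectors[:, -1]
print(jnp.allclose(g, jnp.outer(v_max, v_max), atol=1e-4))
```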