jax.scipy.linalg.eigh

jax.scipy.linalg.eigh(a: Array | ndarray | bool_ | number | bool | int | float | complex, b: ArrayLike | None = None, lower: bool = True, eigvals_only: Literal[False] = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) → tuple[Array, Array]
jax.scipy.linalg.eigh(a: Array | ndarray | bool_ | number | bool | int | float | complex, b: ArrayLike | None = None, lower: bool = True, *, eigvals_only: Literal[True], overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) → Array
jax.scipy.linalg.eigh(a: Array | ndarray | bool_ | number | bool | int | float | complex, b: ArrayLike | None, lower: bool, eigvals_only: Literal[True], overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) → Array
jax.scipy.linalg.eigh(a: Array | ndarray | bool_ | number | bool | int | float | complex, b: ArrayLike | None = None, lower: bool = True, eigvals_only: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) → Array | tuple[Array, Array]

Compute eigenvalues and eigenvectors for a Hermitian matrix.

JAX implementation of scipy.linalg.eigh().
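For reference, the standard (b = None) problem corresponds to the spectral decomposition below. This is the textbook linear-algebra statement rather than something taken from the JAX source: the columns v_i of V are the returned eigenvectors and the diagonal of Λ holds the returned (real) eigenvalues.

```latex
A = A^{H}, \qquad A = V \Lambda V^{H}, \qquad V^{H} V = I, \qquad
\Lambda = \operatorname{diag}(\lambda_1, \dots, \lambda_N), \quad \lambda_i \in \mathbb{R},
\qquad A v_i = \lambda_i v_i .
```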

Parameters:
  • a – Hermitian input array of shape (..., N, N).

  • b – optional Hermitian input of shape (..., N, N). If specified, solve the generalized eigenvalue problem selected by type.

  • lower – if True (default), access only the lower triangle of the input matrix; otherwise access only the upper triangle.

  • eigvals_only – if True, compute only the eigenvalues. If False (default), compute both eigenvalues and eigenvectors.

  • type –

    if b is specified, type selects which generalized eigenvalue problem to solve. Denoting (λ, v) as an eigenvalue, eigenvector pair:

    • type = 1 solves a @ v = λ * b @ v (default)

    • type = 2 solves a @ b @ v = λ * v

    • type = 3 solves b @ a @ v = λ * v

  • eigvals – a (low, high) tuple specifying which eigenvalues to compute.

  • overwrite_a – unused by JAX.

  • overwrite_b – unused by JAX.

  • turbo – unused by JAX.

  • check_finite – unused by JAX.
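When b is also positive definite, the type = 1 problem can be reduced to a standard Hermitian problem through the Cholesky factorization b = L @ L^H. The sketch below performs that reduction by hand with jnp.linalg.cholesky, jax.scipy.linalg.solve_triangular, and the standard (b = None) form of eigh. The helper name generalized_eigh_type1 is hypothetical, it handles only a single (N, N) matrix, and it assumes b is positive definite; it illustrates the math and is not a description of how eigh itself treats the b argument.

```python
import jax.numpy as jnp
from jax.scipy.linalg import eigh, solve_triangular


def generalized_eigh_type1(a, b):
    """Hypothetical helper: solve a @ v = lam * b @ v for a single (N, N)
    Hermitian a and an (assumed) Hermitian positive-definite b."""
    # Cholesky factor: b = L @ L^H with L lower triangular.
    L = jnp.linalg.cholesky(b)
    # tmp = L^{-1} @ a
    tmp = solve_triangular(L, a, lower=True)
    # c = L^{-1} @ a @ L^{-H}, using (L^{-1} @ a)^H = a @ L^{-H} for Hermitian a.
    c = solve_triangular(L, tmp.conj().T, lower=True)
    # Guard against round-off asymmetry before the Hermitian solve.
    c = (c + c.conj().T) / 2
    # Standard problem with the same eigenvalues: c @ y = lam * y.
    lam, y = eigh(c)
    # Map back: v = L^{-H} @ y, i.e. solve the upper-triangular system L^H @ v = y.
    v = solve_triangular(L.conj().T, y, lower=False)
    return lam, v


a = jnp.array([[2., 1.],
               [1., 2.]])
b = jnp.array([[2., 0.],
               [0., 1.]])
lam, v = generalized_eigh_type1(a, b)

# Column j of v satisfies a @ v[:, j] = lam[j] * b @ v[:, j].
print(jnp.allclose(a @ v, (b @ v) * lam, atol=1e-4))
# The eigenvectors come out b-orthonormal: v^H @ b @ v = I.
print(jnp.allclose(v.conj().T @ b @ v, jnp.eye(2), atol=1e-4))
```

For this pair, det(a - λ b) = 2λ² - 6λ + 3, so the two printed checks should hold with eigenvalues (3 ± √3) / 2.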

Returns:

A tuple of arrays (eigvals, eigvecs) if eigvals_only is False, otherwise an array eigvals.

  • eigvals: array of shape (..., N) containing the eigenvalues in ascending order.

  • eigvecs: array of shape (..., N, N) whose columns are the eigenvectors: eigvecs[..., :, i] corresponds to eigvals[..., i].
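The return conventions above can be checked on a batched input. This sketch builds a hypothetical batch of two symmetric 3x3 matrices (shape (2, 3, 3), i.e. (..., N, N)), compares the tuple return with eigvals_only=True, and verifies that column i of the eigenvector array pairs with eigenvalue i.

```python
import jax.numpy as jnp
from jax.scipy.linalg import eigh

# Hypothetical batch of two symmetric 3x3 matrices, shape (2, 3, 3).
x = jnp.arange(18.0).reshape(2, 3, 3)
batch = x + jnp.swapaxes(x, -1, -2)      # symmetrize each matrix in the batch

w, v = eigh(batch)                       # default eigvals_only=False: a tuple
print(w.shape, v.shape)                  # (2, 3) (2, 3, 3)

w_only = eigh(batch, eigvals_only=True)  # a single array of eigenvalues
print(jnp.allclose(w_only, w, atol=1e-4))

# Column v[..., :, i] is the eigenvector for w[..., i], so
# batch @ v equals v with each column i scaled by w[..., i].
lhs = batch @ v
rhs = v * w[..., None, :]
print(jnp.allclose(lhs, rhs, atol=1e-3))
```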

Examples

Compute the standard eigenvalue decomposition of a simple 2x2 matrix:

>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[2., 1.],
...                [1., 2.]])
>>> eigvals, eigvecs = jax.scipy.linalg.eigh(a)
>>> eigvals
Array([1., 3.], dtype=float32)
>>> eigvecs
Array([[-0.70710677,  0.70710677],
       [ 0.70710677,  0.70710677]], dtype=float32)

The eigenvectors are orthonormal:

>>> jnp.allclose(eigvecs.T @ eigvecs, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
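The .T check above is valid because a is real. For a complex Hermitian input, orthonormality must be checked with the conjugate transpose instead. The sketch below uses a hypothetical 2x2 complex Hermitian matrix whose eigenvalues are 1 and 3, and also shows that the returned eigenvalues are real-valued.

```python
import jax.numpy as jnp
from jax.scipy.linalg import eigh

# Hypothetical complex Hermitian matrix: h equals its conjugate transpose.
h = jnp.array([[2.0 + 0.0j, 0.0 - 1.0j],
               [0.0 + 1.0j, 2.0 + 0.0j]])
w, v = eigh(h)

# Eigenvalues of a Hermitian matrix are real, so w has a real dtype.
print(jnp.isrealobj(w))                                # True
# For complex eigenvectors, use v.conj().T rather than v.T.
print(jnp.allclose(v.conj().T @ v, jnp.eye(2), atol=1e-5))
```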

The solution satisfies the eigenvalue problem, a @ eigvecs == eigvecs @ diag(eigvals):

>>> jnp.allclose(a @ eigvecs, eigvecs @ jnp.diag(eigvals))
Array(True, dtype=bool)
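Because a = V diag(λ) V^H, the same pair can be reused to apply a function to the matrix through its eigenvalues. As a hypothetical use, this sketch forms the principal square root of the example matrix, which is positive semidefinite (an assumption the method relies on), and checks that squaring it recovers a.

```python
import jax.numpy as jnp
from jax.scipy.linalg import eigh

a = jnp.array([[2., 1.],
               [1., 2.]])
w, v = eigh(a)

# Hypothetical use: s = v @ diag(sqrt(w)) @ v^H, assuming a is positive
# semidefinite. jnp.maximum clips tiny negative eigenvalues from round-off.
# Multiplying v by a length-N vector scales column i by sqrt(w[i]).
s = (v * jnp.sqrt(jnp.maximum(w, 0.0))) @ v.conj().T
print(jnp.allclose(s @ s, a, atol=1e-4))   # True
```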