jax.numpy.linalg.eigvals

jax.numpy.linalg.eigvals(a)

Computes the eigenvalues of a general matrix.

JAX implementation of numpy.linalg.eigvals().
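Two standard identities give a quick sanity check on the output: the eigenvalues of any square matrix sum to its trace and multiply to its determinant. The sketch below applies both to an arbitrary 3x3 non-symmetric matrix. The matrix and the tolerances are illustrative choices, and the sketch assumes it runs on a backend where the non-symmetric solver is available (for example, the CPU).

```python
import jax.numpy as jnp

# Arbitrary non-symmetric matrix chosen for illustration. It has one real
# eigenvalue and a complex-conjugate pair, so the result is genuinely complex.
a = jnp.array([[2., 1., 0.],
               [0., 3., 4.],
               [1., 0., 1.]])

w = jnp.linalg.eigvals(a)

# Sum of eigenvalues == trace(a); product of eigenvalues == det(a) (here 6 and 10).
# Tolerances are loose because the computation runs in single precision by default.
trace_ok = jnp.allclose(jnp.sum(w), jnp.trace(a), rtol=1e-4, atol=1e-4)
det_ok = jnp.allclose(jnp.prod(w), jnp.linalg.det(a), rtol=1e-4, atol=1e-4)
print(trace_ok, det_ok)
```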

Parameters:

a (jax.typing.ArrayLike) – array of shape (..., M, M) for which to compute the eigenvalues.

Returns:

An array of shape (..., M) containing the eigenvalues.

Return type:

Array
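The shapes above mean that any leading dimensions act as batch dimensions: an input of shape (..., M, M) yields one length-M vector of eigenvalues per matrix. A minimal sketch, using an arbitrary stack of matrices chosen only for illustration; the per-matrix comparison sorts both results so that it does not depend on the order in which eigenvalues are returned.

```python
import jax.numpy as jnp

# Batch shape (2, 2) of 3x3 matrices, i.e. input shape (..., M, M) with M = 3.
# Adding 10 * I (broadcast over the batch) is an arbitrary illustrative choice.
a = jnp.arange(36.0).reshape(2, 2, 3, 3) + 10.0 * jnp.eye(3)

w = jnp.linalg.eigvals(a)
print(w.shape)  # (2, 2, 3)

# A batch entry should match a separate call on the corresponding matrix.
single = jnp.linalg.eigvals(a[1, 0])
print(jnp.allclose(jnp.sort_complex(w[1, 0]), jnp.sort_complex(single),
                   rtol=1e-4, atol=1e-4))
```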

Notes

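  • This differs from numpy.linalg.eigvals() in that the return type of jax.numpy.linalg.eigvals() is always complex: complex64 for 32-bit input and complex128 for 64-bit input. numpy.linalg.eigvals() instead returns a real-valued array when every eigenvalue of a real input is real.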

  • At present, non-symmetric eigendecomposition is only implemented on the CPU backend.
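The sketch below illustrates both notes on a symmetric float32 matrix whose eigenvalues are all real. NumPy returns a real dtype for it, while JAX returns complex64. The computation is also pinned to the CPU device with jax.default_device; whether that step is needed depends on your installation and JAX version, so treat it as an assumption about your setup rather than a requirement.

```python
import numpy as np
import jax
import jax.numpy as jnp

a = np.array([[1., 2.],
              [2., 1.]], dtype=np.float32)

# NumPy: all eigenvalues are real, so the result has a real dtype.
w_np = np.linalg.eigvals(a)
print(w_np.dtype)

# JAX: the result is complex64 for 32-bit input regardless of the values.
# Run on the CPU device, where non-symmetric eigendecomposition is implemented.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    w_jax = jnp.linalg.eigvals(a)
print(w_jax.dtype)  # complex64
```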

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1., 2.],
...                [2., 1.]])
>>> w = jnp.linalg.eigvals(a)
>>> with jnp.printoptions(precision=2):
...  w
Array([ 3.+0.j, -1.+0.j], dtype=complex64)
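The matrix above is symmetric, so its eigenvalues happen to be real. For a general, non-symmetric real matrix they can be complex. A 90-degree rotation matrix, chosen here for illustration, has eigenvalues 1j and -1j. To avoid depending on the order in which the eigenvalues are returned, the sketch sorts them by imaginary part before comparing.

```python
import jax.numpy as jnp

# Rotation by 90 degrees: real-valued but not symmetric.
a = jnp.array([[0., -1.],
               [1.,  0.]])

w = jnp.linalg.eigvals(a)

# Sort by imaginary part so the check does not depend on output order.
w_sorted = w[jnp.argsort(w.imag)]
print(jnp.allclose(w_sorted, jnp.array([-1j, 1j]), atol=1e-6))
```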