jax.scipy.linalg.schur

jax.scipy.linalg.schur(a, output='real')

Compute the Schur decomposition of a square matrix.

JAX implementation of scipy.linalg.schur().

The Schur form T of a matrix A satisfies:

\[A = Z T Z^H\]

where Z is unitary, and T is upper-triangular for the complex-valued Schur decomposition (i.e. output="complex") and quasi-upper-triangular for the real-valued Schur decomposition (i.e. output="real"). In the quasi-triangular case, the diagonal may include 2x2 blocks, each associated with a complex-conjugate pair of eigenvalues of A.
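The difference between the two output modes is easiest to see on a matrix whose eigenvalues are not real. The following is a minimal sketch using a 90-degree rotation matrix (chosen here for illustration; the names rot, has_block and is_triangular are not part of the API), whose eigenvalues are the conjugate pair +1j and -1j:

```python
import jax.numpy as jnp
from jax.scipy.linalg import schur

# A 90-degree rotation matrix: its eigenvalues are the complex pair +1j and -1j.
rot = jnp.array([[0., -1.],
                 [1.,  0.]])

# Real Schur form: the complex pair appears as a 2x2 diagonal block,
# so the entry below the diagonal is nonzero.
T_real, Z_real = schur(rot, output="real")

# Complex Schur form: T is upper-triangular with the eigenvalues on its diagonal.
T_cplx, Z_cplx = schur(rot, output="complex")

has_block = bool(jnp.abs(T_real[1, 0]) > 1e-3)
is_triangular = bool(jnp.allclose(jnp.tril(T_cplx, k=-1), 0, atol=1e-5))
eigs = jnp.diag(T_cplx)
```

Here has_block and is_triangular should both be true, and eigs should hold the two eigenvalues of rot in some order.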

Parameters:
  • a (jax.typing.ArrayLike) – input array of shape (..., N, N)

  • output (str) – Specify whether to compute the "real" (default) or "complex" Schur decomposition.

Returns:

A tuple of arrays (T, Z)

  • T is a shape (..., N, N) array containing the Schur form of the input: upper-triangular for output="complex", quasi-upper-triangular for output="real".

  • Z is a shape (..., N, N) array containing the unitary Schur transformation matrix.

Return type:

tuple[Array, Array]

Example

A Schur decomposition of a 3x3 matrix:

>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[1., 2., 3.],
...                [1., 4., 2.],
...                [3., 2., 1.]])
>>> T, Z = jax.scipy.linalg.schur(a)

The Schur form T is quasi-upper-triangular in general, but is truly upper-triangular in this case because all eigenvalues of the input matrix are real, so no 2x2 blocks are needed:

>>> T  
Array([[-2.0000005 ,  0.5066295 , -0.43360388],
       [ 0.        ,  1.5505103 ,  0.74519426],
       [ 0.        ,  0.        ,  6.449491  ]], dtype=float32)
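When T is truly triangular, its diagonal entries are the eigenvalues of the input. The sketch below checks both properties for the same matrix. The closed-form eigenvalues -2 and 4 ± sqrt(6) are worked out by hand from the characteristic polynomial and are not part of the original example; the variable names are illustrative.

```python
import jax.numpy as jnp
from jax.scipy.linalg import schur

a = jnp.array([[1., 2., 3.],
               [1., 4., 2.],
               [3., 2., 1.]])
T, Z = schur(a)

# No 2x2 blocks: everything below the diagonal is (numerically) zero.
is_upper_triangular = bool(jnp.allclose(jnp.tril(T, k=-1), 0.0, atol=1e-5))

# The diagonal of a triangular T holds the eigenvalues of a.
eigs_from_schur = jnp.sort(jnp.diag(T))

# Hand-derived eigenvalues of this particular matrix: -2 and 4 +/- sqrt(6).
expected = jnp.array([-2.0, 4.0 - jnp.sqrt(6.0), 4.0 + jnp.sqrt(6.0)])
eigs_match = bool(jnp.allclose(eigs_from_schur, expected, atol=1e-4))
```

The values are sorted before comparison because the order in which eigenvalues appear on the diagonal of T is not something this sketch relies on.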

The transformation matrix Z is unitary; because Z is real here, this means it is orthogonal, so Z.T @ Z is the identity:

>>> jnp.allclose(Z.T @ Z, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)

The input can be reconstructed from the outputs:

>>> jnp.allclose(Z @ T @ Z.T, a)
Array(True, dtype=bool)
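The checks above use Z.T, which is enough only because Z is real in that example. For a complex Schur decomposition the defining relation uses the conjugate transpose of Z. A minimal sketch, assuming a small complex-valued input chosen purely for illustration:

```python
import jax.numpy as jnp
from jax.scipy.linalg import schur

# A complex-valued input (illustrative values only).
a = jnp.array([[1. + 1.j, 2.],
               [3., 4. - 1.j]])

T, Z = schur(a, output="complex")

# Conjugate transpose of Z; a plain Z.T is not the inverse of a complex unitary matrix.
Z_H = Z.conj().T

is_unitary = bool(jnp.allclose(Z_H @ Z, jnp.eye(2), atol=1e-5))
reconstructs = bool(jnp.allclose(Z @ T @ Z_H, a, atol=1e-4))
is_triangular = bool(jnp.allclose(jnp.tril(T, k=-1), 0.0, atol=1e-5))
```

The tolerances are loosened relative to the defaults because the computation runs in single precision unless 64-bit mode is enabled.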