jax.scipy.linalg.rsf2csf

jax.scipy.linalg.rsf2csf(T, Z, check_finite=True)

Convert real Schur form to complex Schur form.

JAX implementation of scipy.linalg.rsf2csf().
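Because this is a JAX implementation, the conversion can be traced and compiled together with the decomposition that produces its inputs. The sketch below wraps `jax.scipy.linalg.schur` and `rsf2csf` in one `jax.jit`-compiled function. The helper name `complex_schur` is hypothetical, and we assume a CPU backend, since the Schur routines are, as far as we know, only implemented there.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg


@jax.jit
def complex_schur(a):
    # Hypothetical helper (not part of JAX): real Schur decomposition
    # followed by conversion to complex Schur form, compiled as one function.
    t_real, z_real = jax.scipy.linalg.schur(a)
    return jax.scipy.linalg.rsf2csf(t_real, z_real)


A = jnp.array([[0., 3., 3.],
               [0., 1., 2.],
               [2., 0., 1.]])
Tc, Zc = complex_schur(A)

# Reconstruct A from the complex factors, to float32 precision.
print(jnp.allclose(Zc @ Tc @ Zc.conj().T, A, atol=1e-5))
```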

Parameters:
  • T (jax.typing.ArrayLike) – array of shape (..., N, N) containing the real Schur form of the input matrix.

  • Z (jax.typing.ArrayLike) – array of shape (..., N, N) containing the corresponding Schur transformation matrix.

  • check_finite (bool) – unused by JAX

Returns:

A tuple of arrays (T, Z) of the same shape as the inputs, containing the complex Schur form and the corresponding unitary Schur transformation matrix.

Return type:

tuple[Array, Array]
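A short check of the return contract described above, using the same matrix as the example that follows. The outputs keep the N×N shape of the inputs and come back with a complex dtype even though the inputs are real (SciPy's `rsf2csf` also returns complex arrays); treat the printed dtypes as illustrative, since they depend on whether 64-bit mode is enabled.

```python
import jax.numpy as jnp
import jax.scipy.linalg

A = jnp.array([[0., 3., 3.],
               [0., 1., 2.],
               [2., 0., 1.]])
Tr, Zr = jax.scipy.linalg.schur(A)
Tc, Zc = jax.scipy.linalg.rsf2csf(Tr, Zr)

# Shapes are preserved: both outputs are N x N, like the inputs.
print(Tc.shape, Zc.shape)
# The real inputs are promoted to a complex dtype in the outputs.
print(Tc.dtype, Zc.dtype)
```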

See also

jax.scipy.linalg.schur(): Schur decomposition

Example

>>> import jax
>>> import jax.numpy as jnp
>>> A = jnp.array([[0., 3., 3.],
...                [0., 1., 2.],
...                [2., 0., 1.]])
>>> Tr, Zr = jax.scipy.linalg.schur(A)
>>> Tc, Zc = jax.scipy.linalg.rsf2csf(Tr, Zr)

Both the real and the complex forms can be used to reconstruct the input matrix to within float32 precision:

>>> jnp.allclose(Zr @ Tr @ Zr.T, A, atol=1E-5)
Array(True, dtype=bool)
>>> jnp.allclose(Zc @ Tc @ Zc.conj().T, A, atol=1E-5)
Array(True, dtype=bool)
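The reconstruction works because the real transformation matrix is orthogonal and the complex one is unitary, so `Zr.T` and `Zc.conj().T` act as their inverses. A minimal check of that property, reusing the example matrix and the same float32 tolerance:

```python
import jax.numpy as jnp
import jax.scipy.linalg

A = jnp.array([[0., 3., 3.],
               [0., 1., 2.],
               [2., 0., 1.]])
Tr, Zr = jax.scipy.linalg.schur(A)
Tc, Zc = jax.scipy.linalg.rsf2csf(Tr, Zr)

eye = jnp.eye(3)
# Real Schur vectors are orthogonal: Zr^T Zr = I.
print(jnp.allclose(Zr.T @ Zr, eye, atol=1e-5))
# Complex Schur vectors are unitary: Zc^H Zc = I.
print(jnp.allclose(Zc.conj().T @ Zc, eye, atol=1e-5))
```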

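The real Schur form is only quasi-upper-triangular: each complex-conjugate pair of eigenvalues occupies a 2×2 block on the diagonal, which shows up here as the nonzero entry below the diagonal in the last row: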

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(Tr)
[[ 3.76 -2.17  1.38]
 [ 0.   -0.88 -0.35]
 [ 0.    2.37 -0.88]]
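The 2×2 blocks can be located programmatically from the subdiagonal of the real form: a block starts at row i wherever `Tr[i + 1, i]` is nonzero. The sketch below does this and computes each block's eigenvalues, which should form a complex-conjugate pair. The relative tolerance is an illustrative choice, and we assume a CPU backend for `jnp.linalg.eigvals`.

```python
import jax.numpy as jnp
import jax.scipy.linalg

A = jnp.array([[0., 3., 3.],
               [0., 1., 2.],
               [2., 0., 1.]])
Tr, _ = jax.scipy.linalg.schur(A)

# Subdiagonal entries Tr[i + 1, i]; nonzero ones mark the start of a 2x2 block.
sub = jnp.diagonal(Tr, offset=-1)
tol = 1e-6 * jnp.linalg.norm(Tr)  # illustrative relative tolerance
block_starts = jnp.nonzero(jnp.abs(sub) > tol)[0]
print(block_starts)

# Each 2x2 block carries one complex-conjugate eigenvalue pair of A.
for i in block_starts.tolist():
    block = Tr[i:i + 2, i:i + 2]
    print(jnp.linalg.eigvals(block))
```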

By contrast, the complex form is truly upper-triangular:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(Tc)
[[ 3.76+0.j    1.29-0.78j  2.02-0.5j ]
 [ 0.  +0.j   -0.88+0.91j -2.02+0.j  ]
 [ 0.  +0.j    0.  +0.j   -0.88-0.91j]]
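Since `Zc` is unitary, `Tc = Zc^H A Zc` is similar to `A`, so the diagonal of the triangular `Tc` holds the eigenvalues of `A`. A sketch that checks both properties against `jnp.linalg.eigvals` (again assuming a CPU backend); matching each diagonal entry to its nearest eigenvalue is an illustrative choice that avoids depending on eigenvalue order, and the tolerances are illustrative too.

```python
import jax.numpy as jnp
import jax.scipy.linalg

A = jnp.array([[0., 3., 3.],
               [0., 1., 2.],
               [2., 0., 1.]])
Tr, Zr = jax.scipy.linalg.schur(A)
Tc, Zc = jax.scipy.linalg.rsf2csf(Tr, Zr)

# Everything strictly below the diagonal of Tc is (numerically) zero.
print(jnp.allclose(jnp.tril(Tc, k=-1), 0, atol=1e-5))

# Distance from each diagonal entry of Tc to the nearest eigenvalue of A.
diag = jnp.diagonal(Tc)
eigs = jnp.linalg.eigvals(A)
dist = jnp.abs(diag[:, None] - eigs[None, :]).min(axis=1)
print(jnp.all(dist < 1e-4))
```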