jax.scipy.linalg.sqrtm

jax.scipy.linalg.sqrtm(A, blocksize=1)

Compute the matrix square root of a square matrix.

JAX implementation of scipy.linalg.sqrtm().

Parameters:
  • A (jax.typing.ArrayLike) – array of shape (N, N)

  • blocksize (int) – Not supported in JAX; JAX always uses blocksize=1.

Returns:

An array of shape (N, N) containing the matrix square root of A

Return type:

Array

Example

>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> sqrt_a = jax.scipy.linalg.sqrtm(a)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(sqrt_a)
[[0.92+0.71j 0.54+0.j   0.92-0.71j]
 [0.54+0.j   1.85+0.j   0.54-0.j  ]
 [0.92-0.71j 0.54-0.j   0.92+0.71j]]

By definition, multiplying the matrix square root by itself should recover the input:

>>> jnp.allclose(a, sqrt_a @ sqrt_a)
Array(True, dtype=bool)
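
As a further sanity check, the result can be compared against a square root built from an eigendecomposition. The sketch below assumes a symmetric positive-definite input, for which the principal square root is unique; the matrix `spd` and the names `root` and `root_eigh` are chosen here for illustration and are not part of the original example.

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl

# Symmetric positive-definite input, chosen for illustration only.
spd = jnp.array([[4., 1.],
                 [1., 3.]])

root = jsl.sqrtm(spd)

# Reference: with spd = V diag(w) V^T, a square root is V diag(sqrt(w)) V^T.
w, v = jnp.linalg.eigh(spd)
root_eigh = (v * jnp.sqrt(w)) @ v.T  # scales column k of V by sqrt(w[k])

# Compare with a tolerance suited to single precision.
matches = bool(jnp.allclose(root, root_eigh, atol=1e-4))
```

The tolerance is loosened because JAX uses 32-bit floats by default.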

Notes

This function implements the complex Schur method described in [1]. It does not use recursive blocking to speed up computation, because a Sylvester equation solver is not yet available in JAX.
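
To make the idea concrete, here is a minimal, unblocked sketch of a Schur-based square root. It is an illustration under stated assumptions, not JAX's internal implementation: the helper name `sqrtm_schur_sketch` is hypothetical, the triangular recurrence is run in NumPy with plain Python loops for readability, and it assumes no two diagonal square roots sum to zero (otherwise the division below fails).

```python
import numpy as np
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def sqrtm_schur_sketch(a):
    """Hypothetical helper: unblocked Schur-method matrix square root."""
    a = jnp.asarray(a)
    # Complex Schur form: a = Z @ T @ Z^H, with T upper triangular.
    T, Z = jsl.schur(a, output="complex")
    T = np.asarray(T)
    Z = np.asarray(Z)
    n = T.shape[0]
    U = np.zeros_like(T)
    # Solve U @ U = T for upper triangular U, one column at a time.
    for j in range(n):
        U[j, j] = np.sqrt(T[j, j])
        for i in range(j - 1, -1, -1):
            # Contribution of the entries strictly between i and j.
            s = U[i, i + 1:j] @ U[i + 1:j, j]
            U[i, j] = (T[i, j] - s) / (U[i, i] + U[j, j])
    # Transform back: sqrt(a) = Z @ U @ Z^H.
    return jnp.asarray(Z @ U @ Z.conj().T)


a = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])
r = sqrtm_schur_sketch(a)
ok = bool(jnp.allclose(r @ r, a, atol=1e-3))
```

A blocked variant would instead split the triangular factor into sub-blocks and solve a Sylvester equation for each off-diagonal block, which is the step the note above says is not yet available.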

References