jax.scipy.linalg.expm

jax.scipy.linalg.expm(A, *, upper_triangular=False, max_squarings=16)

Compute the matrix exponential.

JAX implementation of scipy.linalg.expm().
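
As a quick sanity check, the skew-symmetric matrix [[0, -t], [t, 0]] exponentiates to the 2-D rotation matrix by angle t, so expm can be compared against that closed form. This is a minimal sketch that assumes JAX's default float32 precision; the angle t = 0.7 and the tolerance are arbitrary illustrative choices.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Generator of 2-D rotations: exp([[0, -t], [t, 0]]) is a rotation by angle t.
t = 0.7
A = jnp.array([[0.0, -t],
               [t, 0.0]])

R = expm(A)

# Closed-form rotation matrix for comparison.
R_expected = jnp.array([[jnp.cos(t), -jnp.sin(t)],
                        [jnp.sin(t), jnp.cos(t)]])

print(jnp.allclose(R, R_expected, atol=1e-5))
```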

Parameters:
  • A (jax.typing.ArrayLike) – input array of shape (..., N, N); any leading dimensions are treated as batch dimensions.

  • upper_triangular (bool) – if True, assume that A is upper triangular (default: False).

  • max_squarings (int) – the maximum number of squarings used by the scaling-and-squaring approximation method (default: 16).
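
The keyword arguments can be exercised on a small upper-triangular matrix whose exponential has a closed form: for [[a, b], [0, d]] with a ≠ d, exp(A) = [[e^a, b(e^d - e^a)/(d - a)], [0, e^d]]. The sketch below assumes that, for an input that really is upper triangular, setting upper_triangular=True or raising max_squarings to 32 should not change the result beyond floating-point tolerance. The matrix entries and tolerances are illustrative choices.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# An upper-triangular matrix [[a, b], [0, d]] with a != d.
a, b, d = 0.5, 1.0, 1.5
A = jnp.array([[a, b],
               [0.0, d]])

# Closed form for the 2x2 upper-triangular case:
# exp(A) = [[e^a, b * (e^d - e^a) / (d - a)], [0, e^d]]
expected = jnp.array([[jnp.exp(a), b * (jnp.exp(d) - jnp.exp(a)) / (d - a)],
                      [0.0, jnp.exp(d)]])

E_general = expm(A)                            # no structural assumption
E_triangular = expm(A, upper_triangular=True)  # valid only because A is upper triangular
E_more_squarings = expm(A, max_squarings=32)   # raise the cap on squarings

for E in (E_general, E_triangular, E_more_squarings):
    print(jnp.allclose(E, expected, rtol=1e-5, atol=1e-6))
```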

Returns:

An array of shape (..., N, N) containing the matrix exponential of A.

Return type:

Array
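
Because A may carry leading batch dimensions, a stack of matrices can be exponentiated in a single call. This sketch builds a hypothetical batch of four 3x3 diagonal matrices and relies on the fact that the exponential of a diagonal matrix is the diagonal matrix of elementwise exponentials; the values and tolerances are arbitrary.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Four 3x3 diagonal matrices stacked along a leading batch axis: shape (4, 3, 3).
diags = jnp.arange(12.0).reshape(4, 3) / 10.0
A = diags[..., None] * jnp.eye(3)

E = expm(A)  # one matrix exponential per batch entry
print(E.shape)  # (4, 3, 3)

# The exponential of a diagonal matrix is the diagonal matrix of elementwise exponentials.
expected = jnp.exp(diags)[..., None] * jnp.eye(3)
print(jnp.allclose(E, expected, rtol=1e-5, atol=1e-6))
```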

Notes

This function uses the scaling-and-squaring approximation method; its computational cost is controlled by the optional max_squarings argument. In theory, the number of squarings required is max(0, ceil(log2(norm(A))) - c), where norm(A) is the L1 norm of A and c = 2.42 for float64/complex128 or c = 1.97 for float32/complex64.
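
Scaling and squaring rests on the identity exp(A) = (exp(A / 2^s))^(2^s): scale A down by 2^s, approximate the exponential of the small scaled matrix, then square the result s times. The sketch below is illustrative only. The helper names squaring_estimate and expm_scaling_squaring are hypothetical. The base approximation is a truncated Taylor series chosen for simplicity, not necessarily the approximation jax.scipy.linalg.expm uses internally. The sketch picks s so that the scaled matrix has L1 norm at most 1, a conservative choice rather than the formula above. The formula itself is evaluated separately with the float32 constant c = 1.97.

```python
import math

import jax.numpy as jnp
from jax.scipy.linalg import expm


def squaring_estimate(A, c):
    """Evaluate max(0, ceil(log2(norm(A))) - c) from the Notes (hypothetical helper)."""
    l1 = float(jnp.linalg.norm(A, ord=1))  # L1 norm: maximum absolute column sum
    if l1 == 0.0:
        return 0.0
    return max(0.0, math.ceil(math.log2(l1)) - c)


def expm_scaling_squaring(A, num_terms=18):
    """Illustrative scaling and squaring with a truncated Taylor series (hypothetical helper)."""
    l1 = float(jnp.linalg.norm(A, ord=1))
    # Pick s so that the scaled matrix A / 2**s has L1 norm at most 1.
    s = max(0, math.ceil(math.log2(l1))) if l1 > 0.0 else 0
    X = A / (2.0 ** s)
    # Truncated Taylor series: I + X + X^2/2! + ... + X^(num_terms-1)/(num_terms-1)!
    term = jnp.eye(A.shape[-1], dtype=A.dtype)
    result = term
    for k in range(1, num_terms):
        term = term @ X / k
        result = result + term
    # exp(A) = exp(A / 2**s) ** (2**s): square the result s times.
    for _ in range(s):
        result = result @ result
    return result


A = jnp.array([[1.0, 2.0],
               [-3.0, 4.0]])  # float32 by default; L1 norm is 6

print(squaring_estimate(A, c=1.97))  # c for float32/complex64
print(jnp.allclose(expm_scaling_squaring(A), expm(A), rtol=1e-4, atol=1e-4))
```

Because this sketch uses Python-level loops and float() conversions, it is meant for eager execution and would need restructuring (for example with jax.lax.fori_loop) to run under jax.jit.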