jax.scipy.linalg.expm



jax.scipy.linalg.expm(A, *, upper_triangular=False, max_squarings=16)

Compute the matrix exponential of an array.

LAX-backend implementation of scipy.linalg._matfuncs.expm().
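The matrix exponential is defined by the power series exp(A) = sum over k >= 0 of A^k / k!. The sketch below compares jax.scipy.linalg.expm, wrapped in jax.jit, against a truncated version of that series on a small rotation generator. The helper expm_taylor is hypothetical and only serves as a readable reference. It is not how expm works internally, which uses Padé approximation with scaling and squaring, as described below.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def expm_taylor(A, terms=30):
    """Hypothetical reference: truncated power series sum_{k<terms} A^k / k!."""
    n = A.shape[-1]
    term = jnp.eye(n, dtype=A.dtype)  # k = 0 term: the identity
    total = term
    for k in range(1, terms):
        term = term @ A / k  # A^k / k! built from A^(k-1) / (k-1)!
        total = total + term
    return total


# Generator of a 2-D rotation: exp(A) = [[cos 1, sin 1], [-sin 1, cos 1]].
A = jnp.array([[0.0, 1.0],
               [-1.0, 0.0]])

expm_jit = jax.jit(expm)  # compiled like any other LAX-based function
print(expm_jit(A))        # approximately [[0.5403, 0.8415], [-0.8415, 0.5403]]
print(expm_taylor(A))     # agrees to float32 precision for this small-norm input
```

A truncated series like this is only reliable for matrices of small norm. For larger norms it needs many terms and loses accuracy, which is why library implementations scale the matrix down first.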

In addition to the original SciPy argument(s) listed below, this function also accepts the optional boolean argument upper_triangular, which specifies whether A is upper triangular, and the optional argument max_squarings, which specifies the maximum number of squarings allowed in the scaling-and-squaring approximation method. Returns NaN if the number of squarings actually required exceeds max_squarings.

The number of required squarings is max(0, ceil(log2(norm(A)) - c)), where norm() denotes the matrix L1 norm (the maximum absolute column sum), and

  • c=2.42 for float64 or complex128,

  • c=1.97 for float32 or complex64
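The formula above can be evaluated ahead of time to estimate whether a given max_squarings is large enough. The sketch below is a hypothetical helper, required_squarings, that transcribes the documented formula with NumPy. It handles only the four dtypes listed above and special-cases the zero matrix to avoid log2(0). Because it mirrors the docstring rather than JAX's internal code, treat its result as an estimate.

```python
import math

import numpy as np

# Constants c quoted in the docstring, keyed by dtype (hypothetical table).
_C = {
    np.dtype(np.float64): 2.42,
    np.dtype(np.complex128): 2.42,
    np.dtype(np.float32): 1.97,
    np.dtype(np.complex64): 1.97,
}


def required_squarings(A):
    """Hypothetical helper: max(0, ceil(log2(norm(A)) - c)) per the docstring."""
    A = np.asarray(A)
    c = _C[A.dtype]  # KeyError for dtypes the docstring does not list
    norm_1 = np.linalg.norm(A, 1)  # matrix 1-norm: maximum absolute column sum
    if norm_1 == 0:
        return 0  # exp(0) = I needs no squarings; also avoids log2(0)
    return max(0, math.ceil(math.log2(norm_1) - c))


print(required_squarings(20 * np.eye(2, dtype=np.float32)))  # 3
print(required_squarings(20 * np.eye(2)))                    # 2 (float64)
```

Under this estimate, a call with required_squarings(A) > max_squarings is expected to return NaN.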

Original docstring below.

Parameters:

A (ndarray) – Input array whose last two dimensions are square, i.e. with shape (..., n, n).

Returns:

eA – The resulting matrix exponential, with the same shape as A.

Return type:

ndarray
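Because A may have shape (..., n, n), a stack of matrices can be passed in one call, and each matrix is exponentiated independently. The sketch below uses three 2x2 matrices whose exponentials have simple closed forms: a rotation generator scaled by pi (exp gives -I), a diagonal matrix (exp acts elementwise on the diagonal), and a nilpotent matrix N with N @ N = 0 (exp gives I + N). The matrices are illustrative choices.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

batch = jnp.stack([
    jnp.array([[0.0, jnp.pi], [-jnp.pi, 0.0]]),  # rotation by pi: exp = -I
    jnp.diag(jnp.array([1.0, -1.0])),            # diagonal: exp = diag(e, 1/e)
    jnp.array([[0.0, 1.0], [0.0, 0.0]]),         # nilpotent: exp = I + N
])                                               # shape (3, 2, 2)

result = expm(batch)
print(result.shape)  # (3, 2, 2)
```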


Parameters:
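  • upper_triangular (bool) – Whether A is upper triangular. Defaults to False.

  • max_squarings (int) – Maximum number of squarings allowed in the scaling-and-squaring method; NaN is returned if more are required. Defaults to 16.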
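For an upper-triangular input, upper_triangular=True declares that structure to expm. The sketch below checks both the default call and the upper_triangular=True call against the closed form for a 2x2 upper-triangular matrix [[a, b], [0, d]] with a != d, which is [[e^a, b (e^d - e^a) / (d - a)], [0, e^d]]. The specific matrix is an illustrative choice. Any speed benefit from the flag is not measured here.

```python
import math

import jax.numpy as jnp
from jax.scipy.linalg import expm

a, b, d = 1.0, 2.0, 3.0
A = jnp.array([[a, b],
               [0.0, d]])  # upper triangular

general = expm(A)                            # no structure assumed
triangular = expm(A, upper_triangular=True)  # caller asserts A is upper triangular

# Closed form for [[a, b], [0, d]] with a != d.
expected = jnp.array([[math.exp(a), b * (math.exp(d) - math.exp(a)) / (d - a)],
                      [0.0, math.exp(d)]])

print(jnp.allclose(general, triangular, rtol=1e-4))  # both agree with each other
```

Passing upper_triangular=True for a matrix that is not actually upper triangular is a caller error, and the result should not be trusted.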