jax.scipy.linalg.expm_frechet

jax.scipy.linalg.expm_frechet(A, E, *, method=None, compute_expm=True)

Frechet derivative of the matrix exponential of A in the direction E.
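One way to read this definition: the Frechet derivative L(A, E) is the directional derivative of the matrix exponential, d/dh expm(A + hE) evaluated at h = 0. The sketch below checks that numerically against a central finite difference. The matrices, the step size h, and the tolerance are arbitrary choices for illustration, and float64 is enabled so the difference quotient is accurate.

```python
import jax
jax.config.update("jax_enable_x64", True)  # float64 keeps the finite difference accurate

import jax.numpy as jnp
from jax.scipy.linalg import expm, expm_frechet

# Arbitrary illustrative inputs.
A = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
E = jnp.array([[1.0, 0.5], [0.0, 2.0]])

# Central finite difference of h -> expm(A + h E) around h = 0.
h = 1e-6
fd = (expm(A + h * E) - expm(A - h * E)) / (2 * h)

# Frechet derivative only (skip returning expm(A)).
L = expm_frechet(A, E, compute_expm=False)
print(jnp.max(jnp.abs(L - fd)))  # small: limited by finite-difference error
```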

LAX-backend implementation of scipy.linalg._expm_frechet.expm_frechet().
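Because this is a reimplementation of the SciPy routine, a quick sanity check is to compare the two on the same inputs. A minimal sketch, assuming SciPy is installed and float64 is enabled in JAX so both run at the same precision; the random inputs are arbitrary.

```python
import numpy as np
import scipy.linalg

import jax
jax.config.update("jax_enable_x64", True)  # match SciPy's float64 precision
import jax.scipy.linalg

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
E = rng.standard_normal((3, 3))

# Both return (expm_A, expm_frechet_AE) by default.
expA_sp, L_sp = scipy.linalg.expm_frechet(A, E)
expA_jx, L_jx = jax.scipy.linalg.expm_frechet(A, E)

print(np.allclose(expA_sp, expA_jx), np.allclose(L_sp, L_jx))
```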

Does not currently support the SciPy argument check_finite, because jax.numpy.asarray_chkfinite does not exist at the moment. Does not support the method='blockEnlarge' argument; only the default method ('SPS') is available.
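Since SciPy's check_finite option has no JAX counterpart here, callers who want that behavior can validate the inputs themselves before calling. The helper expm_frechet_checked below is hypothetical (it is not part of JAX). It checks eagerly with jnp.isfinite and converts the result to a Python bool, so it will not work inside jax.jit.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm_frechet

def expm_frechet_checked(A, E, compute_expm=True):
    """Hypothetical helper: emulates SciPy's check_finite=True by
    rejecting inputs that contain inf or NaN before delegating to
    jax.scipy.linalg.expm_frechet. Eager only (not jit-compatible)."""
    A = jnp.asarray(A)
    E = jnp.asarray(E)
    if not (bool(jnp.all(jnp.isfinite(A))) and bool(jnp.all(jnp.isfinite(E)))):
        raise ValueError("array must not contain infs or NaNs")
    return expm_frechet(A, E, compute_expm=compute_expm)

# A NaN in the direction matrix is caught before the computation runs.
try:
    expm_frechet_checked(jnp.eye(2), jnp.full((2, 2), jnp.nan))
except ValueError as err:
    print(err)  # array must not contain infs or NaNs
```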

Original docstring below.

Parameters:
  • A ((N, N) array_like) – Matrix of which to take the matrix exponential.

  • E ((N, N) array_like) – Matrix direction in which to take the Frechet derivative.

  • method (str, optional) –

    Choice of algorithm. Should be one of

    • SPS (default)

    • blockEnlarge

  • compute_expm (bool, optional) – Whether to also compute expm_A in addition to expm_frechet_AE. Default is True.
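The blockEnlarge option listed above is not available in JAX, but its idea can be reproduced with a standard identity: the exponential of the 2N x 2N block matrix [[A, E], [0, A]] equals [[expm(A), L(A, E)], [0, expm(A)]], so its upper-right block is the Frechet derivative. A sketch only; the helper name expm_frechet_block_enlarge is hypothetical, and the approach costs one expm of a matrix twice the size of A.

```python
import jax
jax.config.update("jax_enable_x64", True)  # float64 for a tight comparison

import jax.numpy as jnp
from jax.scipy.linalg import expm, expm_frechet

def expm_frechet_block_enlarge(A, E):
    """Hypothetical helper: Frechet derivative via the block matrix
    M = [[A, E], [0, A]], using expm(M) = [[expm(A), L(A, E)], [0, expm(A)]]."""
    n = A.shape[0]
    M = jnp.block([[A, E], [jnp.zeros_like(A), A]])
    expm_M = expm(M)
    # Top-left block is expm(A); top-right block is L(A, E).
    return expm_M[:n, :n], expm_M[:n, n:]

A = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
E = jnp.array([[1.0, 0.5], [0.0, 2.0]])

expA_blk, L_blk = expm_frechet_block_enlarge(A, E)
expA, L = expm_frechet(A, E)  # default method ('SPS')
print(jnp.allclose(expA_blk, expA), jnp.allclose(L_blk, L))
```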

Return type:

Array | tuple[Array, Array]

Returns:

  • expm_A (ndarray) – Matrix exponential of A.

  • expm_frechet_AE (ndarray) – Frechet derivative of the matrix exponential of A in the direction E.

  • If compute_expm is False, only expm_frechet_AE is returned.
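The form of the return value depends on compute_expm. A minimal sketch using A = 0, chosen because expm(0) is the identity and the Frechet derivative of the exponential at 0 in any direction E is E itself, so the results are easy to check by eye:

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm_frechet

A = jnp.zeros((2, 2))
E = jnp.eye(2)

# Default (compute_expm=True): a pair (expm_A, expm_frechet_AE).
expm_A, expm_frechet_AE = expm_frechet(A, E)

# compute_expm=False: a single array holding only the derivative.
only_deriv = expm_frechet(A, E, compute_expm=False)

# Here expm(0) = I and L(0, E) = E.
print(expm_A)
print(expm_frechet_AE)
print(only_deriv)
```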

References