jax.scipy.linalg.expm_frechet

jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) -> tuple[Array, Array]
jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[False]) -> Array
jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]

Compute the Fréchet derivative of the matrix exponential at A in the direction E.

JAX implementation of scipy.linalg.expm_frechet()
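For reference, this is the standard definition of the quantity being computed. The Fréchet derivative L(A, E) of the matrix exponential at A is the linear map in E satisfying the first expansion below; it equals the directional derivative of t ↦ exp(A + tE) at t = 0 and has the integral form shown. This is background from matrix-function theory, not a description of how JAX evaluates it.

```latex
e^{A + E} = e^{A} + L(A, E) + o(\lVert E \rVert),
\qquad
L(A, E) = \left.\frac{d}{dt}\, e^{A + tE}\right|_{t=0}
        = \int_0^1 e^{sA}\, E\, e^{(1-s)A}\, ds .
```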

Parameters:
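  • A – array of shape (..., N, N): the matrix at which the derivative is evaluated. Leading dimensions, if any, are batch dimensions.

  • E – array of shape (..., N, N), with the same shape as A; specifies the direction of the derivative.

  • compute_expm – if True (the default), also compute and return expm(A); if False, return only the derivative.

  • method – ignored by JAX; accepted for compatibility with the SciPy signature.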
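The (..., N, N) shapes above mean that leading dimensions act as batch dimensions. The following sketch, with an arbitrary key and a batch of four 3x3 matrices chosen only for illustration, checks that a batched call returns arrays of the batched shape and that one batch entry agrees with an unbatched call. The tolerances are an assumption meant to absorb float32 rounding differences.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm_frechet

# A batch of four 3x3 matrices A and matching directions E (illustrative sizes).
key_a, key_e = jax.random.split(jax.random.key(0))
A = jax.random.normal(key_a, (4, 3, 3))
E = jax.random.normal(key_e, (4, 3, 3))

# One call handles the whole batch; both outputs keep the (4, 3, 3) shape.
expm_A, frechet_AE = expm_frechet(A, E)
print(expm_A.shape, frechet_AE.shape)  # (4, 3, 3) (4, 3, 3)

# The first batch entry should agree with an unbatched call on A[0], E[0].
expm_0, frechet_0 = expm_frechet(A[0], E[0])
print(jnp.allclose(expm_A[0], expm_0, rtol=1e-4, atol=1e-4))
print(jnp.allclose(frechet_AE[0], frechet_0, rtol=1e-4, atol=1e-4))
```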

Returns:

A tuple (expm_A, expm_frechet_AE) if compute_expm is True, else the array expm_frechet_AE. Both returned arrays have shape (..., N, N).
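A short sketch of the two return forms described above, using arbitrary random inputs: the default call unpacks into two arrays, while compute_expm=False yields a single array that should match the second element of the tuple. The tolerance in the comparison is an assumption for float32 inputs.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm_frechet

key_a, key_e = jax.random.split(jax.random.key(1))
A = jax.random.normal(key_a, (3, 3))
E = jax.random.normal(key_e, (3, 3))

# Default (compute_expm=True): a tuple (expm(A), derivative in direction E).
expm_A, frechet_AE = expm_frechet(A, E)

# compute_expm=False: only the derivative, returned as a single array.
frechet_only = expm_frechet(A, E, compute_expm=False)

print(isinstance(frechet_only, tuple))  # False
print(frechet_only.shape)               # (3, 3)
print(jnp.allclose(frechet_only, frechet_AE, rtol=1e-5, atol=1e-5))
```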

Examples

We can use this API to compute the matrix exponential of A, as well as its derivative in the direction E:

>>> import jax
>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> key1, key2 = jax.random.split(jax.random.key(3372))
>>> A = jax.random.normal(key1, (3, 3))
>>> E = jax.random.normal(key2, (3, 3))
>>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E)
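A further cross-check that does not rely on automatic differentiation uses a classical identity from matrix-function theory (stated here as background, not as how JAX computes the result): the exponential of the 2N-by-2N block matrix [[A, E], [0, A]] has expm(A) in its diagonal blocks and the Fréchet derivative in its top-right block. The sketch rebuilds A and E as in the example above; the loose tolerances are an assumption to absorb float32 rounding in the larger exponential.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm, expm_frechet

key1, key2 = jax.random.split(jax.random.key(3372))
A = jax.random.normal(key1, (3, 3))
E = jax.random.normal(key2, (3, 3))
n = A.shape[-1]

# Block upper-triangular matrix M = [[A, E], [0, A]] of shape (2n, 2n).
M = jnp.block([[A, E], [jnp.zeros_like(A), A]])
expm_M = expm(M)

# Top-left block is expm(A); top-right block is the Frechet derivative L(A, E).
expm_A_block = expm_M[:n, :n]
frechet_block = expm_M[:n, n:]

expm_A, frechet_AE = expm_frechet(A, E)
print(jnp.allclose(expm_A, expm_A_block, rtol=1e-3, atol=1e-3))
print(jnp.allclose(frechet_AE, frechet_block, rtol=1e-3, atol=1e-3))
```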

The same quantities can be computed with JAX’s automatic differentiation. Here jax.jvp() evaluates expm() at A with tangent E, returning both expm(A) and its derivative in the direction E, and both agree with the results above:

>>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,))
>>> jnp.allclose(expmA, expmA2)
Array(True, dtype=bool)
>>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2)
Array(True, dtype=bool)