# jax.scipy.linalg.expm_frechet

jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True)
jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[False])
jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True)

Compute the Fréchet derivative of the matrix exponential.

JAX implementation of `scipy.linalg.expm_frechet()`.
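For reference, here is the standard textbook definition of the quantity being computed. This is general matrix-function theory, not a description of how JAX implements it. The Fréchet derivative $L(A, E)$ is the map, linear in $E$, satisfying the first line. The second line gives its directional-derivative and integral forms. The third line is the block-triangular identity, which offers an independent way to check results.

```latex
\begin{aligned}
\exp(A + E) &= \exp(A) + L(A, E) + o(\lVert E \rVert), \\
L(A, E) &= \left.\frac{d}{dt}\exp(A + tE)\right|_{t=0}
         = \int_0^1 e^{(1-s)A}\, E\, e^{sA}\, ds, \\
\exp\begin{pmatrix} A & E \\ 0 & A \end{pmatrix}
  &= \begin{pmatrix} \exp(A) & L(A, E) \\ 0 & \exp(A) \end{pmatrix}.
\end{aligned}
```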

Parameters:
• A – array of shape `(..., N, N)`

• E – array of shape `(..., N, N)`; specifies the direction of the derivative.

• compute_expm – if True (default), also compute and return `expm(A)`.

• method – ignored by JAX.
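The `(..., N, N)` shapes above mean that leading dimensions act as batch dimensions. The following is a minimal sketch of that shape contract. The batch size of 4 and the random inputs are arbitrary choices. The `jax.vmap` call is only an illustrative cross-check that applies the unbatched function to each matrix in turn.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm_frechet

# A stack of 4 independent 3x3 problems; the leading axis is a batch axis.
key_a, key_e = jax.random.split(jax.random.key(0))
A = jax.random.normal(key_a, (4, 3, 3))
E = jax.random.normal(key_e, (4, 3, 3))

# Both outputs keep the full (..., N, N) shape.
expm_A, frechet_AE = expm_frechet(A, E)
print(expm_A.shape, frechet_AE.shape)  # (4, 3, 3) (4, 3, 3)

# Cross-check: map the unbatched call over the leading axis explicitly.
expm_A_v, frechet_AE_v = jax.vmap(expm_frechet)(A, E)
print(jnp.allclose(frechet_AE, frechet_AE_v, rtol=1e-4, atol=1e-4))
```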

Returns:

A tuple `(expm_A, expm_frechet_AE)` if `compute_expm` is True, else the array `expm_frechet_AE`. Both returned arrays have shape `(..., N, N)`.
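The two return forms can be illustrated with a short sketch. The seed and the 3x3 size are arbitrary. The sketch only demonstrates the documented output structure and makes no claim about whether `compute_expm=False` saves any work internally.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm_frechet

key_a, key_e = jax.random.split(jax.random.key(1))
A = jax.random.normal(key_a, (3, 3))
E = jax.random.normal(key_e, (3, 3))

# Default (compute_expm=True): a 2-tuple (expm(A), Frechet derivative).
both = expm_frechet(A, E)

# compute_expm=False: only the Frechet derivative, as a single array.
frechet_only = expm_frechet(A, E, compute_expm=False)

print(isinstance(both, tuple), frechet_only.shape)  # True (3, 3)
print(jnp.allclose(both[1], frechet_only))
```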

Examples

We can use this API to compute the matrix exponential of `A`, as well as its derivative in the direction `E`:

```
>>> import jax
>>> import jax.numpy as jnp
>>> key1, key2 = jax.random.split(jax.random.key(3372))
>>> A = jax.random.normal(key1, (3, 3))
>>> E = jax.random.normal(key2, (3, 3))
>>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E)
```

This can be equivalently computed using JAX's automatic differentiation methods; here we'll compute the derivative of `expm()` in the direction of `E` using `jax.jvp()`, and find the same results:

```
>>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,))
>>> jnp.allclose(expmA, expmA2)
Array(True, dtype=bool)
>>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2)
Array(True, dtype=bool)
```
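An independent check that does not use automatic differentiation relies on the standard block-triangular identity: the exponential of the 2N x 2N matrix `[[A, E], [0, A]]` contains `expm(A)` on its diagonal blocks and the Fréchet derivative in its upper-right block. This is a sketch for verification only, not a claim about how JAX computes the result. It enables float64 so tight tolerances can be used. That is an assumption of this example; `expm_frechet` also works in float32.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm, expm_frechet

# Use float64 so the comparison can use tight tolerances
# (a choice for this sketch; not required by expm_frechet).
jax.config.update("jax_enable_x64", True)

key1, key2 = jax.random.split(jax.random.key(3372))
A = jax.random.normal(key1, (3, 3))
E = jax.random.normal(key2, (3, 3))
n = A.shape[-1]

# Block upper-triangular 2N x 2N matrix [[A, E], [0, A]].
M = jnp.block([[A, E], [jnp.zeros_like(A), A]])

# expm(M) == [[expm(A), L(A, E)], [0, expm(A)]]: read off the blocks.
expm_M = expm(M)
expm_A_block = expm_M[:n, :n]
frechet_block = expm_M[:n, n:]

expm_A, frechet_AE = expm_frechet(A, E)
print(jnp.allclose(expm_A, expm_A_block, rtol=1e-7, atol=1e-9))
print(jnp.allclose(frechet_AE, frechet_block, rtol=1e-7, atol=1e-9))
```

The cost is one exponential of a matrix twice the size. This makes the identity convenient for testing, but it is not necessarily efficient.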