jax.scipy.linalg.expm_frechet
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) -> tuple[Array, Array]
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[False]) -> Array
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]
Compute the Fréchet derivative of the matrix exponential.
JAX implementation of scipy.linalg.expm_frechet().
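For reference, the quantity this function returns is the Fréchet derivative of the matrix exponential at A in the direction E. The block below restates the standard textbook definition, an equivalent integral form, and the classical block-matrix identity. These are general mathematical facts, not a description of how JAX computes the result internally.

```latex
L_{\exp}(A, E) = \lim_{h \to 0} \frac{e^{A + hE} - e^{A}}{h}
               = \int_0^1 e^{sA} \, E \, e^{(1-s)A} \, ds,
\qquad
\exp\begin{pmatrix} A & E \\ 0 & A \end{pmatrix}
  = \begin{pmatrix} e^{A} & L_{\exp}(A, E) \\ 0 & e^{A} \end{pmatrix}.
```

L_exp(A, E) is linear in E, and when A and E commute the integral reduces to e^A E.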
- Parameters:
A – array of shape (..., N, N).
E – array of shape (..., N, N); specifies the direction of the derivative.
compute_expm – if True (default), then compute and return expm(A).
method – ignored by JAX.
- Returns:
A tuple (expm_A, expm_frechet_AE) if compute_expm is True, else the array expm_frechet_AE. Both returned arrays have shape (..., N, N).
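A minimal sketch of the two return conventions and of batching over the leading `...` dimensions. The batch size of 4, the matrix size, and the PRNG key are arbitrary choices for illustration; the only API used is the documented signature above.

```python
import jax
import jax.numpy as jnp

# Arbitrary batch of 4 independent 3x3 matrices and matching directions.
key_a, key_e = jax.random.split(jax.random.key(0))
A = jax.random.normal(key_a, (4, 3, 3))
E = jax.random.normal(key_e, (4, 3, 3))

# Default compute_expm=True: returns the tuple (expm(A), derivative).
expm_A, frechet_AE = jax.scipy.linalg.expm_frechet(A, E)

# compute_expm=False: returns only the derivative array.
frechet_only = jax.scipy.linalg.expm_frechet(A, E, compute_expm=False)

print(expm_A.shape, frechet_AE.shape)  # (4, 3, 3) (4, 3, 3)
print(frechet_only.shape)              # (4, 3, 3)
# Both calls should agree on the derivative.
print(jnp.allclose(frechet_AE, frechet_only))
```

Passing compute_expm=False is convenient when only the derivative is needed, since the caller does not have to unpack and discard expm(A).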
See also
Examples
We can use this API to compute the matrix exponential of A, as well as its derivative in the direction E:
>>> import jax
>>> key1, key2 = jax.random.split(jax.random.key(3372))
>>> A = jax.random.normal(key1, (3, 3))
>>> E = jax.random.normal(key2, (3, 3))
>>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E)
The same quantities can be computed with JAX's automatic differentiation: here we compute the derivative of expm() in the direction E using jax.jvp() and obtain the same results:
>>> import jax.numpy as jnp
>>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,))
>>> jnp.allclose(expmA, expmA2)
Array(True, dtype=bool)
>>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2)
Array(True, dtype=bool)