jax.scipy.linalg.expm
- jax.scipy.linalg.expm(A, *, upper_triangular=False, max_squarings=16)
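Compute the matrix exponential.
JAX implementation of scipy.linalg.expm().
- Parameters:
  - A (ArrayLike): array of shape (..., N, N).
  - upper_triangular (bool): if True, assume that A is upper triangular. Default: False.
  - max_squarings (int): the number of squarings in the scaling-and-squaring approximation method. Default: 16.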
- Returns:
  An array of shape (..., N, N) containing the matrix exponential of A.
- Return type:
  Array
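A short sketch of the (..., N, N) shape contract and the upper_triangular flag, assuming JAX's default float32 precision. The values in batch are arbitrary illustrations, not from the original page; upper_triangular=True should only be passed when every matrix in the input really is upper triangular.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# A stack of three upper-triangular 2x2 matrices, shape (3, 2, 2).
batch = jnp.array([
    [[1.0, 2.0], [0.0, 3.0]],
    [[0.0, 1.0], [0.0, 0.0]],
    [[-1.0, 0.5], [0.0, 2.0]],
])

# Leading (batch) dimensions are preserved: one 2x2 exponential per matrix.
out = expm(batch)
print(out.shape)  # (3, 2, 2)

# Same inputs, but let expm assume upper-triangular structure.
out_tri = expm(batch, upper_triangular=True)

# The exponential of an upper-triangular matrix is upper triangular,
# and its diagonal holds exp of the original diagonal entries.
diag_in = jnp.diagonal(batch, axis1=-2, axis2=-1)
diag_out = jnp.diagonal(out_tri, axis1=-2, axis2=-1)
print(bool(jnp.allclose(diag_out, jnp.exp(diag_in), rtol=1e-4)))  # True
```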
Notes
This uses the scaling-and-squaring approximation method, with computational complexity controlled by the optional max_squarings argument. Theoretically, the number of required squarings is max(0, ceil(log2(norm(A))) - c), where norm(A) is the L1 norm of A and c = 2.42 for float64/complex128, or c = 1.97 for float32/complex64.
See also
Examples
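To make the scaling-and-squaring idea from the Notes concrete, here is a minimal sketch. It relies on the identity exp(A) = exp(A / 2^s)^(2^s): scale A down until its L1 norm is at most 1, approximate the exponential of the scaled matrix, then square the result s times. The helper expm_taylor_sketch is hypothetical. It uses a truncated Taylor series in place of the Padé approximants that JAX's expm uses, picks s with an illustrative threshold rather than the c constants quoted in the Notes, and only handles a single real 2-D matrix in float32.

```python
import math

import jax.numpy as jnp
from jax.scipy.linalg import expm


def expm_taylor_sketch(A, degree=12):
    """Hypothetical scaling-and-squaring sketch with a truncated Taylor series.

    Not JAX's algorithm: jax.scipy.linalg.expm uses Pade approximants.
    """
    A = jnp.asarray(A, dtype=jnp.float32)
    n = A.shape[-1]
    # Scaling: choose s so that ||A / 2**s||_1 <= 1 (illustrative threshold).
    norm = float(jnp.linalg.norm(A, ord=1))
    s = max(0, math.ceil(math.log2(norm))) if norm > 0 else 0
    X = A / 2.0**s
    # Approximation: exp(X) ~ sum_{k=0}^{degree} X**k / k!, built term by term.
    term = jnp.eye(n, dtype=A.dtype)
    result = term
    for k in range(1, degree + 1):
        term = term @ X / k
        result = result + term
    # Squaring: exp(A) = exp(X) ** (2**s), i.e. s repeated matrix squarings.
    for _ in range(s):
        result = result @ result
    return result


# Compare the sketch with jax.scipy.linalg.expm on an arbitrary matrix.
M = jnp.array([[1.0, 2.0], [-3.0, 0.5]])
approx = expm_taylor_sketch(M)
print(bool(jnp.allclose(approx, expm(M), rtol=1e-4, atol=1e-4)))  # True
```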
expm is the matrix exponential and shares many properties with the more familiar scalar exponential. For scalars a and b, \(e^{a + b} = e^a e^b\). For square matrices A and B, however, this identity is not guaranteed in general; it does hold when A and B commute (AB = BA). In that case, expm(A+B) = expm(A) @ expm(B):
>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> A = jnp.array([[2, 0],
...                [0, 1]])
>>> B = jnp.array([[3, 0],
...                [0, 4]])
>>> jnp.allclose(jax.scipy.linalg.expm(A+B),
...              jax.scipy.linalg.expm(A) @ jax.scipy.linalg.expm(B),
...              rtol=0.0001)
Array(True, dtype=bool)
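The commutativity condition matters. As a hedged illustration (the matrices C and D are arbitrary choices, not taken from the original page), two nilpotent matrices that do not commute give different results for expm(C + D) and expm(C) @ expm(D):

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Two nilpotent matrices that do not commute:
# C @ D = [[1, 0], [0, 0]] but D @ C = [[0, 0], [0, 1]].
C = jnp.array([[0.0, 1.0], [0.0, 0.0]])
D = jnp.array([[0.0, 0.0], [1.0, 0.0]])
commute = bool(jnp.allclose(C @ D, D @ C))

# C + D = [[0, 1], [1, 0]], whose exponential is [[cosh 1, sinh 1], [sinh 1, cosh 1]].
lhs = expm(C + D)
# C and D square to zero, so expm(C) = I + C and expm(D) = I + D;
# their product is [[2, 1], [1, 1]].
rhs = expm(C) @ expm(D)
same = bool(jnp.allclose(lhs, rhs, rtol=1e-4))

print(commute, same)  # False False
```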
If a matrix X is invertible, then expm(X @ A @ inv(X)) = X @ expm(A) @ inv(X):
>>> X = jnp.array([[3, 1],
...                [2, 5]])
>>> X_inv = jax.scipy.linalg.inv(X)
>>> jnp.allclose(jax.scipy.linalg.expm(X @ A @ X_inv),
...              X @ jax.scipy.linalg.expm(A) @ X_inv)
Array(True, dtype=bool)
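The similarity rule is also why a symmetric matrix can be exponentiated through its eigendecomposition. The sketch below assumes a real symmetric input, so that jnp.linalg.eigh applies and the orthogonal eigenvector matrix plays the role of X; S is an arbitrary example matrix. It is only a cross-check, not how expm computes its result.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# A symmetric matrix factors as S = Q @ diag(w) @ Q.T with orthogonal Q, so inv(Q) == Q.T.
S = jnp.array([[2.0, 1.0], [1.0, 3.0]])
w, Q = jnp.linalg.eigh(S)

# Similarity rule with X = Q: expm(Q @ D @ Q.T) = Q @ expm(D) @ Q.T,
# and expm of the diagonal matrix D = diag(w) is simply diag(exp(w)).
via_eigh = Q @ jnp.diag(jnp.exp(w)) @ Q.T
direct = expm(S)

print(bool(jnp.allclose(via_eigh, direct, rtol=1e-4, atol=1e-4)))  # True
```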