jax.numpy.linalg.matrix_power
jax.numpy.linalg.matrix_power(a, n)
Raise a square matrix to an integer power.
JAX implementation of numpy.linalg.matrix_power(), implemented via repeated squarings.
Parameters:
- a (ArrayLike) – array of shape (..., M, M) to be raised to the power n.
- n (int) – the integer exponent to which the matrix should be raised.
Returns:
Array of shape (..., M, M) containing the matrix power of a to the n.
Return type:
Array
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jnp.linalg.matrix_power(a, 3)
Array([[ 37.,  54.],
       [ 81., 118.]], dtype=float32)
>>> a @ a @ a  # equivalent evaluated directly
Array([[ 37.,  54.],
       [ 81., 118.]], dtype=float32)
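Because a may have shape (..., M, M), leading dimensions are treated as batch dimensions. The following is a small illustrative sketch (the array batch and its values are made up for this example) showing that each matrix in a stack is raised to the power independently:

```python
import jax.numpy as jnp

# Illustrative input: a stack of two 2x2 matrices, shape (2, 2, 2).
batch = jnp.array([[[1., 2.],
                    [3., 4.]],
                   [[2., 0.],
                    [0., 2.]]])

# The power is applied to each trailing (M, M) matrix separately.
cubed = jnp.linalg.matrix_power(batch, 3)

print(cubed.shape)  # (2, 2, 2)

# Compare each slice against the directly evaluated product.
first_matches = jnp.allclose(cubed[0], batch[0] @ batch[0] @ batch[0])
second_matches = jnp.allclose(cubed[1], 8.0 * jnp.eye(2))  # (2 * I)^3 = 8 * I
print(first_matches, second_matches)
```

The output has the same shape as the input, matching the Returns description above.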
A zero power is also supported and returns the identity matrix:
>>> jnp.linalg.matrix_power(a, 0)
Array([[1., 0.],
       [0., 1.]], dtype=float32)
Negative powers are also supported:
>>> with jnp.printoptions(precision=3):
...   jnp.linalg.matrix_power(a, -2)
Array([[ 5.5 , -2.5 ],
       [-3.75,  1.75]], dtype=float32)
A negative power -n is equivalent to multiplying together n copies of the inverse of a:
>>> inv_a = jnp.linalg.inv(a)
>>> with jnp.printoptions(precision=3):
...   inv_a @ inv_a
Array([[ 5.5 , -2.5 ],
       [-3.75,  1.75]], dtype=float32)
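The description above says the function is implemented via repeated squarings. As a rough illustration of that idea, and not the library's actual source, here is a minimal sketch of binary exponentiation. The helper name matrix_power_sketch is hypothetical, and it assumes n is a plain Python int and that a is invertible whenever n is negative:

```python
import jax.numpy as jnp

def matrix_power_sketch(a, n):
    """Hypothetical helper: compute a**n by binary exponentiation."""
    a = jnp.asarray(a)
    if n < 0:
        # Negative power: invert once, then raise the inverse to |n|.
        a = jnp.linalg.inv(a)
        n = -n
    # Start from the identity, broadcast over any leading batch dimensions.
    result = jnp.broadcast_to(jnp.eye(a.shape[-1], dtype=a.dtype), a.shape)
    while n > 0:
        if n & 1:
            # The current bit of n is set: fold this power of a into the result.
            result = result @ a
        n >>= 1
        if n:
            # More bits remain: square to get the next power a^(2^k).
            a = a @ a
    return result

a = jnp.array([[1., 2.],
               [3., 4.]])
same_cube = jnp.allclose(matrix_power_sketch(a, 3),
                         jnp.linalg.matrix_power(a, 3))
same_neg = jnp.allclose(matrix_power_sketch(a, -2),
                        jnp.linalg.matrix_power(a, -2), atol=1e-4)
print(same_cube, same_neg)
```

This approach needs on the order of log2(|n|) matrix multiplications rather than |n| - 1, which is the usual motivation for repeated squaring.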