jax.numpy.linalg.matrix_power

jax.numpy.linalg.matrix_power(a, n)

Raise a square matrix to an integer power.

JAX implementation of numpy.linalg.matrix_power(), computed via repeated squaring.

Parameters:
  • a (jax.typing.ArrayLike) – array of shape (..., M, M) to be raised to the power n.

  • n (int) – the integer exponent to which the matrix should be raised; may be positive, zero, or negative.

Returns:

Array of shape (..., M, M) containing a raised to the power n.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jnp.linalg.matrix_power(a, 3)
Array([[ 37.,  54.],
       [ 81., 118.]], dtype=float32)
>>> a @ a @ a  # equivalent evaluated directly
Array([[ 37.,  54.],
       [ 81., 118.]], dtype=float32)
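
Because a may have shape (..., M, M), leading dimensions act as batch dimensions and each trailing M x M matrix is raised to the power independently. A short sketch, assuming a stack of two 2x2 matrices (the variable names are illustrative only):

```python
import jax.numpy as jnp

# Stack of two 2x2 matrices: shape (2, 2, 2).
batch = jnp.array([[[1., 2.],
                    [3., 4.]],
                   [[2., 0.],
                    [0., 2.]]])

# The power is applied to each trailing 2x2 matrix; the result keeps shape (2, 2, 2).
powers = jnp.linalg.matrix_power(batch, 3)

# Direct evaluation of each slice, for comparison.
first = batch[0] @ batch[0] @ batch[0]
second = batch[1] @ batch[1] @ batch[1]
```

Here powers[0] should agree with first and powers[1] with second.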

matrix_power also supports zero powers, which return the identity matrix:

>>> jnp.linalg.matrix_power(a, 0)
Array([[1., 0.],
       [0., 1.]], dtype=float32)

Negative powers are also supported:

>>> with jnp.printoptions(precision=3):
...   jnp.linalg.matrix_power(a, -2)
Array([[ 5.5 , -2.5 ],
       [-3.75,  1.75]], dtype=float32)

A negative power is equivalent to repeated matrix multiplication of the inverse:

>>> inv_a = jnp.linalg.inv(a)
>>> with jnp.printoptions(precision=3):
...   inv_a @ inv_a
Array([[ 5.5 , -2.5 ],
       [-3.75,  1.75]], dtype=float32)
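
The repeated-squaring approach named in the description can be illustrated with a minimal sketch. The helper matrix_power_sketch is hypothetical and is not JAX's actual implementation; it assumes a non-negative Python integer n and an input of shape (..., M, M).

```python
import jax.numpy as jnp

def matrix_power_sketch(a, n):
    """Hypothetical helper: a**n for integer n >= 0 via binary exponentiation."""
    a = jnp.asarray(a)
    if n < 0:
        raise ValueError("this sketch handles only n >= 0")
    # Start from the identity, broadcast over any leading batch dimensions.
    result = jnp.broadcast_to(jnp.eye(a.shape[-1], dtype=a.dtype), a.shape)
    base = a
    while n > 0:
        if n & 1:        # lowest bit of n is set: fold the current power in
            result = result @ base
        n >>= 1
        if n:            # square only while higher bits remain
            base = base @ base
    return result

a = jnp.array([[1., 2.],
               [3., 4.]])
sketch = matrix_power_sketch(a, 3)
reference = jnp.linalg.matrix_power(a, 3)
```

Under these assumptions, sketch should match reference. The loop performs on the order of log2(n) matrix multiplications, rather than the n - 1 needed when multiplying a by itself one factor at a time.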