# jax.scipy.linalg.inv

jax.scipy.linalg.inv(a, overwrite_a=False, check_finite=True)

Return the inverse of a square matrix.

JAX implementation of `scipy.linalg.inv()`.

Parameters:
• a (jax.typing.ArrayLike) – array of shape `(..., N, N)` specifying square array(s) to be inverted.

• overwrite_a (bool) – unused in JAX

• check_finite (bool) – unused in JAX

Returns:

Array of shape `(..., N, N)` containing the inverse of the input.

Return type:

Array
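Because `a` may have leading batch dimensions (shape `(..., N, N)`), a single call inverts a whole stack of matrices, with the inversion applied over the last two axes. The sketch below is illustrative only; the batch values are arbitrary choices made for this example.

```python
import jax
import jax.numpy as jnp

# A batch of two 2x2 matrices: shape (2, 2, 2), i.e. (..., N, N) with N = 2.
batch = jnp.array([[[2., 0.],
                    [0., 4.]],
                   [[1., 2.],
                    [3., 4.]]])

# inv acts on the last two axes, so each matrix in the batch is inverted separately.
batch_inv = jax.scipy.linalg.inv(batch)
print(batch_inv.shape)  # (2, 2, 2)

# Batched matmul: each product should be (approximately) the 2x2 identity,
# which broadcasts against the leading batch dimension.
products = batch @ batch_inv
print(jnp.allclose(products, jnp.eye(2), atol=1e-5))
```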

Notes

In most cases, explicitly computing the inverse of a matrix is ill-advised. For example, rather than computing `x = inv(A) @ b`, it is faster and more numerically accurate to solve the linear system `A @ x = b` directly, for example with `jax.scipy.linalg.solve()`.
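The same advice applies when there are several right-hand sides: `jax.scipy.linalg.solve()` also accepts a matrix `B` of shape `(N, K)` and solves for all `K` columns at once, so `inv(A) @ B` can be replaced by `solve(A, B)`. A minimal sketch, where `A` and `B` are arbitrary, well-conditioned values chosen for illustration:

```python
import jax
import jax.numpy as jnp

# Illustrative well-conditioned 3x3 system with two right-hand sides (the columns of B).
A = jnp.array([[4., 1., 0.],
               [1., 3., 1.],
               [0., 1., 2.]])
B = jnp.array([[1., 0.],
               [2., 1.],
               [3., 0.]])

# Discouraged: form the inverse explicitly, then multiply.
X_inv = jax.scipy.linalg.inv(A) @ B

# Preferred: solve A @ X = B directly, without forming inv(A).
X_solve = jax.scipy.linalg.solve(A, B)

# For this small, well-conditioned example both approaches agree,
# and X_solve satisfies A @ X ≈ B.
print(jnp.allclose(A @ X_solve, B, atol=1e-5))
print(jnp.allclose(X_inv, X_solve, atol=1e-5))
```

For small, well-conditioned inputs like this one the two results agree; the advantage of `solve` shows up in cost and accuracy as `N` grows or as `A` becomes ill-conditioned.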

Example

Compute the inverse of a 3x3 matrix:

```
>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> a_inv = jax.scipy.linalg.inv(a)
>>> a_inv
Array([[ 0.        , -0.25      ,  0.5       ],
       [-0.25      ,  0.5       , -0.25000003],
       [ 0.5       , -0.25      ,  0.        ]], dtype=float32)
```

Check that multiplying by the inverse gives the identity matrix:

```
>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
```

Multiply the inverse by a vector `b` to find a solution to `a @ x = b`:

```
>>> b = jnp.array([1., 4., 2.])
>>> a_inv @ b
Array([ 0.  ,  1.25, -0.5 ], dtype=float32)
```

Note, however, that explicitly computing the inverse in such a case can lead to poor performance and loss of precision as the size of the problem grows. Instead, you should use a direct solver like `jax.scipy.linalg.solve()`:

```
>>> jax.scipy.linalg.solve(a, b)
Array([ 0.  ,  1.25, -0.5 ], dtype=float32)
```
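If the reason for forming the inverse is to apply it to many right-hand sides that become available at different times, a common alternative is to factor `a` once and reuse the factorization. A minimal sketch using `jax.scipy.linalg.lu_factor()` and `jax.scipy.linalg.lu_solve()`, with the same `a` and `b` as above; the second vector `b2` is an arbitrary extra right-hand side added for illustration:

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])

# Factor once: lu_factor returns a (lu, piv) tuple holding the packed
# LU factors and the row-pivot indices.
lu_and_piv = jax.scipy.linalg.lu_factor(a)

# Reuse the same factorization for several right-hand sides.
b1 = jnp.array([1., 4., 2.])
b2 = jnp.array([0., 1., 0.])  # hypothetical second right-hand side
x1 = jax.scipy.linalg.lu_solve(lu_and_piv, b1)
x2 = jax.scipy.linalg.lu_solve(lu_and_piv, b2)

print(x1)  # approximately [0., 1.25, -0.5]
print(x2)  # approximately the second column of inv(a)
```

Factoring once and solving repeatedly avoids both the cost of refactoring `a` for every vector and the extra rounding error introduced by forming `inv(a)` explicitly.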