# jax.scipy.linalg.solve

jax.scipy.linalg.solve(a, b, lower=False, overwrite_a=False, overwrite_b=False, debug=False, check_finite=True, assume_a='gen')

Solve a linear system of equations.

JAX implementation of `scipy.linalg.solve()`.

This solves a (batched) linear system of equations `a @ x = b` for `x` given `a` and `b`.

Parameters:
• a (jax.typing.ArrayLike) – array of shape `(..., N, N)`.

• b (jax.typing.ArrayLike) – array of shape `(..., N)` or `(..., N, M)`.

• lower (bool) – Referenced only if `assume_a != 'gen'`. If True, only the lower triangle of the input is used; if False (default), only the upper triangle is used.

• assume_a (str) – Specifies which properties of `a` can be assumed. Options are:

• `"gen"`: generic matrix (default)

• `"sym"`: symmetric matrix

• `"her"`: Hermitian matrix

• `"pos"`: positive-definite matrix

• overwrite_a (bool) – unused by JAX

• overwrite_b (bool) – unused by JAX

• debug (bool) – unused by JAX

• check_finite (bool) – unused by JAX

Returns:

An array of the same shape as `b` containing the solution to the linear system.

Return type:

Array

Examples

A simple 3x3 linear system:

```
>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jax.scipy.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)
```

Confirming that the result solves the system:

```
>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)
```
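
The other input shapes and the `assume_a` option described under Parameters can be sketched in the same way. The following is a minimal illustration, not part of the original documentation: the matrix `S`, the right-hand sides `B` and `rhs`, and all variable names are assumptions chosen so that the exact solutions are easy to verify by hand. The batched case uses a right-hand side with an explicit trailing column axis, i.e. the `(..., N, M)` form.

```python
import jax.numpy as jnp
import jax.scipy.linalg  # ensures jax.scipy.linalg.solve is importable

A = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])

# Matrix right-hand side of shape (N, M): each column of B is its own system.
# The columns were built from the solutions [1, 2, 3] and [1, 1, 1].
B = jnp.array([[14., 6.],
               [16., 8.],
               [10., 6.]])
X = jax.scipy.linalg.solve(A, B)
print(X.shape)  # same shape as B

# Batched system: a has shape (2, N, N) and b has shape (2, N, 1).
A_batch = jnp.stack([A, 2.0 * A])
b_batch = jnp.stack([B[:, :1], B[:, :1]])
x_batch = jax.scipy.linalg.solve(A_batch, b_batch)
print(x_batch.shape)  # same shape as b_batch

# Symmetric positive-definite system (hypothetical example matrix).
S = jnp.array([[4., 1., 0.],
               [1., 3., 1.],
               [0., 1., 2.]])
rhs = jnp.array([6., 10., 8.])  # equals S @ [1, 2, 3]

# lower=False (default) reads the upper triangle; lower=True reads the lower.
x_pos = jax.scipy.linalg.solve(S, rhs, assume_a='pos')
x_pos_lower = jax.scipy.linalg.solve(S, rhs, assume_a='pos', lower=True)
print(jnp.allclose(S @ x_pos, rhs, atol=1e-4))
```

Because `S` is symmetric, both settings of `lower` see the same matrix and should agree up to floating-point error. Passing `assume_a='pos'` for a matrix that is not positive-definite is not checked here and may give meaningless results.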