jax.scipy.linalg.solve

jax.scipy.linalg.solve(a, b, lower=False, overwrite_a=False, overwrite_b=False, debug=False, check_finite=True, assume_a='gen')

Solve a linear system of equations.

JAX implementation of scipy.linalg.solve().
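Because the function mirrors the SciPy signature, the same call can usually be pointed at either library. The sketch below assumes NumPy and SciPy are installed (SciPy is a JAX dependency) and that JAX runs at its default 32-bit precision, so the two results are compared with a tolerance rather than exactly. The small 2x2 system is an illustrative choice, not taken from the original page.

```python
import numpy as np
import scipy.linalg
import jax.scipy.linalg

# The same system solved with SciPy (float64) and JAX (float32 by default).
a = np.array([[3., 1.],
              [1., 2.]])
b = np.array([9., 8.])

x_scipy = scipy.linalg.solve(a, b)
x_jax = jax.scipy.linalg.solve(a, b)

# overwrite_a and check_finite are accepted for SciPy compatibility but unused by JAX,
# so the NumPy input `a` is left untouched.
x_jax_compat = jax.scipy.linalg.solve(a, b, overwrite_a=True, check_finite=False)

print(np.allclose(x_scipy, x_jax, atol=1e-5))  # True
```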

This function solves a (batched) linear system of equations a @ x = b for x, given a and b.
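Leading dimensions of a act as batch dimensions, so several independent systems can be solved in one call. The sketch below assumes a batch of two 2x2 systems, with a of shape (2, 2, 2) and b of shape (2, 2), matching the (..., N, N) and (..., N) shapes listed under Parameters; the specific matrices are illustrative.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# A batch of two independent 2x2 systems: a has shape (2, 2, 2), b has shape (2, 2).
a = jnp.array([[[2., 0.], [0., 4.]],
               [[1., 1.], [0., 1.]]])
b = jnp.array([[2., 8.],
               [5., 2.]])

x = jax.scipy.linalg.solve(a, b)
print(x.shape)  # (2, 2)

# Each batch entry should satisfy a[i] @ x[i] == b[i].
residual = jnp.einsum('bij,bj->bi', a, x) - b
print(jnp.allclose(residual, 0., atol=1e-5))  # True
```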

Parameters:
  • a (jax.typing.ArrayLike) – array of shape (..., N, N).

  • b (jax.typing.ArrayLike) – array of shape (..., N) or (..., N, M).

  • lower (bool) – Referenced only if assume_a != 'gen'. If True, only the lower triangle of the input is used. If False (default), only the upper triangle is used.

  • assume_a (str) –

    Specifies which properties of a can be assumed. Options are:

    • "gen": generic matrix (default)

    • "sym": symmetric matrix

    • "her": hermitian matrix

    • "pos": positive-definite matrix

  • overwrite_a (bool) – unused by JAX

  • overwrite_b (bool) – unused by JAX

  • debug (bool) – unused by JAX

  • check_finite (bool) – unused by JAX
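When a is known to be positive-definite, assume_a="pos" lets the solver exploit that structure. The sketch below assumes a small symmetric, strictly diagonally dominant matrix with a positive diagonal (which makes it positive-definite) and checks that the "pos" path agrees with the generic one. Since the example matrix is symmetric, lower=True and lower=False are expected to give the same answer; the matrix and right-hand side are illustrative.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Symmetric and strictly diagonally dominant with a positive diagonal,
# hence positive-definite.
a = jnp.array([[4., 1., 0.],
               [1., 3., 1.],
               [0., 1., 2.]])
b = jnp.array([5., 5., 3.])

# Generic solve (assume_a="gen" is the default).
x_gen = jax.scipy.linalg.solve(a, b)

# Positive-definite solve; lower=False (default) reads the upper triangle.
x_pos = jax.scipy.linalg.solve(a, b, assume_a="pos")

# Same solve, reading the lower triangle instead.
x_pos_lower = jax.scipy.linalg.solve(a, b, assume_a="pos", lower=True)

print(x_gen, x_pos, x_pos_lower)
```

All three results should be close to [1., 1., 1.], up to float32 rounding.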

Returns:

An array of the same shape as b containing the solution to the linear system.
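When b has shape (N, M), each of its M columns is treated as a separate right-hand side, and the result has the same (N, M) shape as b. The sketch below builds b from a known solution matrix (the hypothetical name x_true is introduced only for this illustration) so the recovered solution can be checked directly.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[2., 1.],
               [1., 3.]])

# Known solution with M = 3 columns; b = a @ x_true has shape (N, M) = (2, 3).
x_true = jnp.array([[1., 0., 1.],
                    [0., 1., 1.]])
b = a @ x_true

x = jax.scipy.linalg.solve(a, b)
print(x.shape)  # (2, 3), the same shape as b

# Solving for a single column gives the matching column of the batched result.
x_col0 = jax.scipy.linalg.solve(a, b[:, 0])
print(jnp.allclose(x_col0, x[:, 0], atol=1e-5))  # True
```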

Return type:

Array

Example

A simple 3x3 linear system:

>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jax.scipy.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)

Confirming that the result solves the system:

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)