jax.numpy.linalg.solve



jax.numpy.linalg.solve(a, b)

Solve a linear system of equations.

JAX implementation of numpy.linalg.solve().

This solves a (batched) linear system of equations a @ x = b for x given a and b.
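The "(batched)" part means that leading dimensions of a and b are treated as a batch of independent systems. The sketch below is illustrative only: it assumes an arbitrary batch of 4 systems of size 3 with 2 right-hand sides each, builds each matrix as m @ mᵀ + 3·I (a hypothetical construction chosen only so every system is guaranteed to be solvable), and uses jax.vmap purely to show that one batched call agrees with solving each system separately.

```python
import jax
import jax.numpy as jnp

# Illustrative data (assumption): a batch of 4 systems, each 3x3 with 2 right-hand sides.
# m @ m^T is positive semidefinite; adding 3 * I makes every matrix symmetric
# positive definite, so each system has a unique solution.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
m = jax.random.normal(key_a, (4, 3, 3))
a = m @ jnp.swapaxes(m, -1, -2) + 3.0 * jnp.eye(3)
b = jax.random.normal(key_b, (4, 3, 2))

# One batched call solves all four systems at once.
x_batched = jnp.linalg.solve(a, b)

# The same computation written as an explicit map over the leading batch axis.
x_mapped = jax.vmap(jnp.linalg.solve)(a, b)

print(x_batched.shape)                               # (4, 3, 2)
print(jnp.allclose(x_batched, x_mapped, atol=1e-5))  # True
print(jnp.allclose(a @ x_batched, b, atol=1e-4))     # True
```

The vmap line is only there to make the batching semantics concrete; the single batched call is enough in practice.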

Parameters:
  • a (jax.typing.ArrayLike) – array of shape (..., N, N).

  • b (jax.typing.ArrayLike) – array of shape (N,) (for a single 1-dimensional right-hand side) or (..., N, M) (for batched 2-dimensional right-hand sides).

Returns:

An array containing the result of the linear solve. The result has shape (..., N) if b is of shape (N,), and has shape (..., N, M) otherwise.

Return type:

Array
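The shape rules above can be checked directly. This is a minimal sketch with arbitrary sizes (N = 3, M = 4, batch size 2), using a scaled identity matrix as an assumed example coefficient matrix so that every solution is simply b / 2.

```python
import jax.numpy as jnp

# Illustrative inputs (assumption): N = 3, M = 4, batch size 2.
a = 2.0 * jnp.eye(3)                       # shape (N, N)
a_batch = jnp.broadcast_to(a, (2, 3, 3))   # shape (..., N, N)

b_vec = jnp.ones(3)                        # shape (N,): one right-hand side
b_mat = jnp.ones((3, 4))                   # shape (N, M)
b_mat_batch = jnp.ones((2, 3, 4))          # shape (..., N, M)

print(jnp.linalg.solve(a, b_vec).shape)              # (3,)
print(jnp.linalg.solve(a_batch, b_vec).shape)        # (2, 3)
print(jnp.linalg.solve(a, b_mat).shape)              # (3, 4)
print(jnp.linalg.solve(a_batch, b_mat_batch).shape)  # (2, 3, 4)
```

Note that a 1-dimensional b is used for every matrix in the batch, which is why the second call returns shape (2, 3), matching the (..., N) rule.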


Example

A simple 3x3 linear system:

>>> import jax.numpy as jnp
>>> A = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> b = jnp.array([14., 16., 10.])
>>> x = jnp.linalg.solve(A, b)
>>> x
Array([1., 2., 3.], dtype=float32)

Confirming that the result solves the system:

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)
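The same 3x3 matrix can also be solved against several right-hand sides at once by stacking them as columns of an (N, M) array. In this sketch the second column, [1., 2., 3.], is an arbitrary choice: it equals the first column of A, so its solution is [1., 0., 0.].

```python
import jax.numpy as jnp

A = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])
b = jnp.array([14., 16., 10.])

# Stack two right-hand sides as columns: B has shape (N, M) = (3, 2).
B = jnp.stack([b, jnp.array([1., 2., 3.])], axis=1)
X = jnp.linalg.solve(A, B)

print(X.shape)  # (3, 2)
# Each column of X solves the system for the matching column of B.
print(jnp.allclose(X[:, 0], jnp.linalg.solve(A, b), atol=1e-5))  # True
print(jnp.allclose(X, jnp.array([[1., 1.], [2., 0.], [3., 0.]]), atol=1e-5))  # True
```

The explicit atol is used because entries that should be exactly zero may come out as tiny float32 rounding errors, which the default absolute tolerance of allclose would reject.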