jax.scipy.linalg.lu_solve

jax.scipy.linalg.lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True)

Solve a linear system using an LU factorization.

JAX implementation of scipy.linalg.lu_solve(). Uses the output of jax.scipy.linalg.lu_factor().
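As a rough sketch of what the (lu, piv) pair returned by jax.scipy.linalg.lu_factor() contains, the snippet below unpacks the two triangular factors and rebuilds the row-permuted input from them. It assumes the LAPACK/SciPy convention for piv (0-based indices applied in order: at step i, row i was swapped with row piv[i]); the names L, U and pa are illustrative only.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor

# The first column forces a row interchange under partial pivoting.
a = jnp.array([[1., 2.],
               [3., 4.]])
lu, piv = lu_factor(a)

# U occupies the upper triangle of lu (diagonal included); L occupies the
# strictly lower triangle, with its unit diagonal left implicit.
n = a.shape[0]
L = jnp.tril(lu, k=-1) + jnp.eye(n)
U = jnp.triu(lu)

# Assumed LAPACK-style pivots: replay the swaps on a copy of `a` to get P @ a.
pa = np.array(a)
for i, p in enumerate(np.asarray(piv)):
    pa[[i, p]] = pa[[p, i]]

print(np.allclose(pa, L @ U, atol=1e-6))  # True
```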

Parameters:
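  • lu_and_piv (tuple[Array, jax.typing.ArrayLike]) – (lu, piv), the output of lu_factor(). lu is an array of shape (..., M, N) containing U in its upper triangle (including the diagonal) and L in its strictly lower triangle; the unit diagonal of L is not stored. piv is an array of shape (..., K), with K = min(M, N), which encodes the row interchanges (pivots) applied during the factorization.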

  • b (jax.typing.ArrayLike) – right-hand side of the linear system. Must have shape (..., M).

  • trans (int) –

    type of system to solve. Options are:

    • 0: \(A x = b\)

    • 1: \(A^Tx = b\)

    • 2: \(A^Hx = b\)

  • overwrite_b (bool) – unused by JAX

  • check_finite (bool) – unused by JAX
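A minimal sketch of the three trans modes, using a complex matrix so that the transpose and the conjugate transpose define different systems. The matrix and right-hand side are arbitrary illustrative values; each solution is checked by substituting it back into the equation its mode is documented to solve.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

# Complex entries make A^T (trans=1) and A^H (trans=2) genuinely different.
a = jnp.array([[2. + 1.j, 1. + 0.j],
               [1. - 2.j, 3. + 0.j]])
b = jnp.array([1. + 0.j, 2. + 1.j])

lu_piv = lu_factor(a)  # factor once; all three modes reuse it

x0 = lu_solve(lu_piv, b, trans=0)  # solves a @ x = b
x1 = lu_solve(lu_piv, b, trans=1)  # solves a.T @ x = b
x2 = lu_solve(lu_piv, b, trans=2)  # solves a.conj().T @ x = b

# Substitute each solution back into its own system.
print(jnp.allclose(a @ x0, b, atol=1e-5))
print(jnp.allclose(a.T @ x1, b, atol=1e-5))
print(jnp.allclose(a.conj().T @ x2, b, atol=1e-5))
```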

Returns:

Array of shape (..., N) representing the solution of the linear system.

Return type:

Array
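The leading "..." in the shapes above are batch dimensions. The sketch below factors a stack of three 2x2 systems and solves them in a single call, with one right-hand-side vector per matrix as described for b. The matrices are arbitrary examples chosen only to be nonsingular.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

# A batch of three nonsingular 2x2 matrices: shape (3, 2, 2).
a = jnp.array([[[2., 1.], [1., 2.]],
               [[4., 1.], [2., 3.]],
               [[1., 2.], [3., 4.]]])
# One right-hand-side vector per matrix: shape (3, 2).
b = jnp.array([[3., 4.],
               [1., 0.],
               [0., 1.]])

lu_piv = lu_factor(a)    # lu: (3, 2, 2), piv: (3, 2)
x = lu_solve(lu_piv, b)  # solution: (3, 2), one row per system

# Substitute back: a[i] @ x[i] should reproduce b[i] for every i.
residual = jnp.einsum('bij,bj->bi', a, x) - b
print(x.shape)  # (3, 2)
print(jnp.max(jnp.abs(residual)))
```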

Example

Solving a small linear system via LU factorization:

>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[2., 1.],
...                [1., 2.]])

Compute the LU factorization with lu_factor(), then pass the result to lu_solve() to solve the system:

>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)

Check that the solution satisfies the original system:

>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)
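Because the factorization is computed separately from the solve, a single lu_factor() result can be reused for many right-hand sides. A sketch, assuming you want to batch those solves with jax.vmap and compile them with jax.jit (neither is required by lu_solve itself); the extra right-hand sides are arbitrary illustrative values:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

a = jnp.array([[2., 1.],
               [1., 2.]])
lufac = lu_factor(a)  # factor once

# Several right-hand sides stacked along the leading axis.
bs = jnp.array([[3., 4.],
                [1., 0.],
                [0., 1.]])

# vmap maps the solve over the rows of bs; lufac is shared by every solve.
solve_many = jax.jit(jax.vmap(lambda b: lu_solve(lufac, b)))
xs = solve_many(bs)

# Row i of xs solves a @ x = bs[i], so xs @ a.T should reproduce bs.
print(jnp.allclose(xs @ a.T, bs, atol=1e-5))  # True
```

Each reuse costs roughly O(N^2) work (a row permutation plus two triangular solves), compared with O(N^3) for the factorization itself, which is why factoring once pays off when the same matrix appears with many right-hand sides.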