jax.scipy.linalg.solve_triangular

jax.scipy.linalg.solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True)

Solve a triangular linear system of equations.

JAX implementation of scipy.linalg.solve_triangular().

This solves a (batched) linear system of equations a @ x = b for x given a triangular matrix a and a vector or matrix b.
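For an upper-triangular a (the default, lower=False), such a system can be solved by back substitution: the last row determines the last unknown directly, and each earlier row then introduces only one new unknown. The sketch below is a hypothetical, unbatched reference written with a plain Python loop, for illustration only; it is not how JAX implements this function, and `back_substitute` is a name introduced here. It checks the hand-rolled result against solve_triangular on the same 3x3 system used in the Example section.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def back_substitute(a, b):
    """Hypothetical reference: solve a @ x = b for an upper-triangular 2-D a and 1-D b."""
    n = a.shape[-1]
    x = jnp.zeros_like(b)
    # Walk from the last row upward; row i involves only x[i] and the
    # already-solved unknowns x[i+1:].
    for i in reversed(range(n)):
        # Subtract the known contributions, then divide by the diagonal entry.
        # JAX arrays are immutable, so .at[i].set returns an updated copy.
        x = x.at[i].set((b[i] - a[i, i + 1:] @ x[i + 1:]) / a[i, i])
    return x


A = jnp.array([[1., 2., 3.],
               [0., 3., 2.],
               [0., 0., 5.]])
b = jnp.array([10., 8., 5.])

x_ref = back_substitute(A, b)
x = solve_triangular(A, b)
print(x_ref)                    # [3. 2. 1.]
print(jnp.allclose(x_ref, x))   # True
```

The loop costs O(N^2) operations, which is why a triangular solve is much cheaper than factorizing a general matrix.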

Parameters:
  • a (jax.typing.ArrayLike) – array of shape (..., N, N). Only part of the array will be accessed, depending on the lower and unit_diagonal arguments.

  • b (jax.typing.ArrayLike) – array of shape (..., N) or (..., N, M).

  • lower (bool) – If True, only use the lower triangle of the input. If False (default), only use the upper triangle.

  • unit_diagonal (bool) – If True, ignore diagonal elements of a and assume they are 1 (default: False).

  • trans (int | str) –

    Specifies which system to solve. Options are:

    • 0 or 'N': solve \(Ax=b\)

    • 1 or 'T': solve \(A^Tx=b\)

    • 2 or 'C': solve \(A^Hx=b\)

  • overwrite_b (bool) – unused by JAX

  • debug (Any | None) – unused by JAX

  • check_finite (bool) – unused by JAX
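To illustrate that only part of a is read, the sketch below fills the strictly upper triangle and the diagonal of a matrix with values that should be ignored, then solves with lower=True and unit_diagonal=True. The matrix M and right-hand side b are made-up illustration values. The check rebuilds the matrix that is actually used (the strict lower triangle plus an implicit unit diagonal) and confirms the solution satisfies it.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# The strictly upper part (99.0) and the diagonal (7.0) hold junk values
# that should not be read when lower=True and unit_diagonal=True.
M = jnp.array([[7., 99., 99.],
               [2., 7., 99.],
               [3., 4., 7.]])
b = jnp.array([1., 4., 12.])

x = solve_triangular(M, b, lower=True, unit_diagonal=True)

# The matrix that was effectively used: strict lower triangle of M
# plus ones on the diagonal.
L = jnp.tril(M, k=-1) + jnp.eye(3)
print(x)                          # [1. 2. 1.]
print(jnp.allclose(L @ x, b))     # True
```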

Returns:

An array of the same shape as b containing the solution to the linear system.

Return type:

Array
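The returned array takes the shape of b, including any leading batch dimensions. The sketch below, using arbitrary illustration values, solves a batch of four 3x3 upper-triangular systems, first with a matrix right-hand side of shape (4, 3, 2) and then with a vector right-hand side of shape (4, 3). It assumes a and b share the same batch dimensions; broadcasting between differing batch shapes is not exercised here.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# One well-conditioned upper-triangular matrix: [[4, 2, 3], [0, 8, 6], [0, 0, 12]].
base = jnp.triu(jnp.arange(1., 10.).reshape(3, 3)) + 3. * jnp.eye(3)
# Stack scaled copies into a batch of shape (4, 3, 3).
A = jnp.stack([base * (k + 1) for k in range(4)])

# Matrix right-hand side: shape (..., N, M) = (4, 3, 2).
B = jnp.ones((4, 3, 2))
X = solve_triangular(A, B)
print(X.shape)                               # (4, 3, 2)
print(jnp.allclose(A @ X, B, atol=1e-5))     # True

# Vector right-hand side: shape (..., N) = (4, 3).
b = jnp.ones((4, 3))
x = solve_triangular(A, b)
print(x.shape)                               # (4, 3)
# Add a trailing axis so matmul treats each x as a column vector.
print(jnp.allclose((A @ x[..., None])[..., 0], b, atol=1e-5))  # True
```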

See also

jax.scipy.linalg.solve(): Solve a general linear system.
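jax.scipy.linalg.solve treats a as a general matrix, whereas solve_triangular reads only one triangle. The sketch below reuses the 3x3 system from the Example section to show that the two agree when a really is triangular, and that they diverge once nonzero entries appear below the diagonal: solve_triangular (with the default lower=False) ignores them, while solve does not. The perturbed matrix A_full is an arbitrary illustration value.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve, solve_triangular

A = jnp.array([[1., 2., 3.],
               [0., 3., 2.],
               [0., 0., 5.]])
b = jnp.array([10., 8., 5.])

# On a genuinely triangular matrix the two solvers agree.
print(jnp.allclose(solve(A, b), solve_triangular(A, b)))        # True

# Add entries below the diagonal (A_full is still invertible).
A_full = A + jnp.array([[0., 0., 0.],
                        [1., 0., 0.],
                        [1., 1., 0.]])
# solve_triangular only reads the upper triangle, so its answer is unchanged...
print(jnp.allclose(solve_triangular(A_full, b), solve_triangular(A, b)))  # True
# ...while solve uses the full matrix and solves a different system.
print(jnp.allclose(solve(A_full, b), solve_triangular(A, b)))   # False
```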

Example

A simple 3x3 triangular linear system:

>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> A = jnp.array([[1., 2., 3.],
...                [0., 3., 2.],
...                [0., 0., 5.]])
>>> b = jnp.array([10., 8., 5.])
>>> x = jax.scipy.linalg.solve_triangular(A, b)
>>> x
Array([3., 2., 1.], dtype=float32)

Confirming that the result solves the system:

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)

Computing the transposed problem:

>>> x = jax.scipy.linalg.solve_triangular(A, b, trans='T')
>>> x
Array([10. , -4. , -3.4], dtype=float32)

Confirming that the result solves the transposed system:

>>> jnp.allclose(A.T @ x, b)
Array(True, dtype=bool)
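The examples above use a real matrix, for which trans='T' and trans='C' solve the same system. For complex input they differ, because \(A^H\) also conjugates the entries. The sketch below, with arbitrary illustration values for A and b, checks each solution against its own system and shows that the two answers are not the same.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# A complex upper-triangular matrix (dtype complex64 with default settings).
A = jnp.array([[1. + 1.j, 2.],
               [0., 3. - 1.j]])
b = jnp.array([1. + 0.j, 2. + 1.j])

x_c = solve_triangular(A, b, trans='C')   # solves A^H x = b
x_t = solve_triangular(A, b, trans='T')   # solves A^T x = b

print(jnp.allclose(A.conj().T @ x_c, b))  # True
print(jnp.allclose(A.T @ x_t, b))         # True
print(jnp.allclose(x_c, x_t))             # False: conjugation changes the answer
```

Worked by hand, x_c = [(1+i)/2, (3-i)/10] and x_t = [(1-i)/2, (1+7i)/10].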