# jax.scipy.linalg.solve_triangular

jax.scipy.linalg.solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True)

Solve a triangular linear system of equations.

JAX implementation of scipy.linalg.solve_triangular().
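Because the function mirrors the SciPy API, the same call can be made with both libraries. The sketch below checks that they agree on the upper-triangular system used in the examples on this page; it assumes SciPy is installed (it is a JAX dependency). Differences beyond float32 rounding are not expected, since JAX computes in float32 by default.

```python
import numpy as np
import scipy.linalg
import jax.scipy.linalg

# Upper-triangular system taken from the examples on this page.
A = np.array([[1., 2., 3.],
              [0., 3., 2.],
              [0., 0., 5.]])
b = np.array([10., 8., 5.])

# Identical call signatures; JAX accepts NumPy arrays as input.
x_scipy = scipy.linalg.solve_triangular(A, b)
x_jax = jax.scipy.linalg.solve_triangular(A, b)

# JAX returns float32 by default, so compare with a float32-level tolerance.
print(np.allclose(x_scipy, np.asarray(x_jax), rtol=1e-5))  # True
```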

This solves a (batched) linear system of equations a @ x = b for x given a triangular matrix a and a vector or matrix b.
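Batching works through leading dimensions: every axis before the last two of `a` is treated as a batch axis. The following is a minimal sketch with a hypothetical batch of four random upper-triangular 3x3 matrices; the `3 * I` shift is an assumption added only to keep each system well conditioned.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

k_a, k_b = jax.random.split(jax.random.PRNGKey(0))

# Batch of 4 upper-triangular 3x3 matrices, shape (4, 3, 3).
# jnp.triu acts on the last two axes; adding 3*I keeps the diagonal away from zero.
a = jnp.triu(jax.random.uniform(k_a, (4, 3, 3))) + 3.0 * jnp.eye(3)
# One right-hand-side vector per matrix, shape (4, 3).
b = jax.random.normal(k_b, (4, 3))

x = jax.scipy.linalg.solve_triangular(a, b)
print(x.shape)  # (4, 3)

# Each batch entry satisfies a[i] @ x[i] == b[i].
print(jnp.allclose(jnp.einsum('bij,bj->bi', a, x), b, atol=1e-5))  # True
```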

Parameters:
• a (jax.typing.ArrayLike) – array of shape (..., N, N). Only part of the array will be accessed, depending on the lower and unit_diagonal arguments.

• b (jax.typing.ArrayLike) – array of shape (..., N) or (..., N, M).

• lower (bool) – If True, only use the lower triangle of the input. If False (default), only use the upper triangle.

• unit_diagonal (bool) – If True, ignore the diagonal elements of a and assume they are 1 (default: False).

• trans (int | str) – specify which form of the system to solve. Options are:

  • 0 or 'N': solve $Ax = b$

  • 1 or 'T': solve $A^T x = b$

  • 2 or 'C': solve $A^H x = b$, where $A^H$ is the conjugate transpose of $A$

• overwrite_b (bool) – unused by JAX

• debug (Any | None) – unused by JAX

• check_finite (bool) – unused by JAX
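To see which entries of a are actually read, the sketch below fills the parts that should be ignored with obviously wrong values. The matrix `M` and the comparison matrix `L` are hypothetical names chosen for illustration: with `lower=True` the entries above the diagonal (99) are skipped, and with `unit_diagonal=True` the diagonal (9) is replaced by ones.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Entries above the diagonal (99) and on the diagonal (9) should both be ignored.
M = jnp.array([[9., 99., 99.],
               [2., 9., 99.],
               [4., 3., 9.]])
b = jnp.array([1., 4., 15.])

x = jax.scipy.linalg.solve_triangular(M, b, lower=True, unit_diagonal=True)

# The matrix that is effectively used: strictly lower part of M plus ones on the diagonal.
L = jnp.tril(M, k=-1) + jnp.eye(3)
print(x)                       # [1. 2. 5.]
print(jnp.allclose(L @ x, b))  # True
```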

Returns:

An array of the same shape as b containing the solution to the linear system.

Return type:

Array
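When b has shape (N, M), each of its M columns is an independent right-hand side, and the result has the same (N, M) shape. A minimal sketch, reusing the matrix from the examples and a hypothetical second right-hand side chosen so that its solution is all ones:

```python
import jax.numpy as jnp
import jax.scipy.linalg

A = jnp.array([[1., 2., 3.],
               [0., 3., 2.],
               [0., 0., 5.]])
# Two right-hand sides stored as columns: shape (N, M) = (3, 2).
B = jnp.array([[10., 6.],
               [ 8., 5.],
               [ 5., 5.]])

X = jax.scipy.linalg.solve_triangular(A, B)
print(X.shape)  # (3, 2), the same shape as B

# Column j of X is the solution for column j of B alone.
x0 = jax.scipy.linalg.solve_triangular(A, B[:, 0])
print(jnp.allclose(X[:, 0], x0))  # True
```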

See also:

jax.scipy.linalg.solve(): Solve a general linear system.
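For a matrix that really is triangular, the general solver and the triangular solver should return the same solution; the triangular solver simply exploits the known structure instead of factorizing a general matrix. The comparison below is a sketch on the example system from this page, not a benchmark.

```python
import jax.numpy as jnp
import jax.scipy.linalg

A = jnp.array([[1., 2., 3.],
               [0., 3., 2.],
               [0., 0., 5.]])
b = jnp.array([10., 8., 5.])

# General solver: does not assume any structure in A.
x_general = jax.scipy.linalg.solve(A, b)
# Triangular solver: reads only the upper triangle of A (lower=False by default).
x_tri = jax.scipy.linalg.solve_triangular(A, b)

print(jnp.allclose(x_general, x_tri))  # True
```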

Example

A simple 3x3 triangular linear system:

>>> import jax
>>> import jax.numpy as jnp
>>> A = jnp.array([[1., 2., 3.],
...                [0., 3., 2.],
...                [0., 0., 5.]])
>>> b = jnp.array([10., 8., 5.])
>>> x = jax.scipy.linalg.solve_triangular(A, b)
>>> x
Array([3., 2., 1.], dtype=float32)


Confirming that the result solves the system:

>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)
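For an upper-triangular system, the solution can be found by back substitution: solve the last equation first, then move upward, subtracting the contributions of the unknowns already found. The helper `back_substitution` below is a hypothetical, illustrative Python loop; it is not how JAX implements solve_triangular, which hands the work to a backend triangular-solve routine.

```python
import jax.numpy as jnp
import jax.scipy.linalg

def back_substitution(a, b):
    """Hypothetical teaching helper: solve upper-triangular a @ x = b for 1-D b."""
    n = a.shape[0]
    x = jnp.zeros_like(b)
    # Row i only involves x[i:], so iterate from the last row upward.
    for i in range(n - 1, -1, -1):
        # Remove the contributions of the already-solved unknowns x[i+1:].
        rhs = b[i] - jnp.dot(a[i, i + 1:], x[i + 1:])
        x = x.at[i].set(rhs / a[i, i])
    return x

A = jnp.array([[1., 2., 3.],
               [0., 3., 2.],
               [0., 0., 5.]])
b = jnp.array([10., 8., 5.])

x_loop = back_substitution(A, b)
print(x_loop)  # [3. 2. 1.]
print(jnp.allclose(x_loop, jax.scipy.linalg.solve_triangular(A, b)))  # True
```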


Computing the transposed problem:

>>> x = jax.scipy.linalg.solve_triangular(A, b, trans='T')
>>> x
Array([10. , -4. , -3.4], dtype=float32)


Confirming that the result solves the transposed system:

>>> jnp.allclose(A.T @ x, b)
Array(True, dtype=bool)
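For complex inputs, trans='C' solves the conjugate-transposed system $A^H x = b$. Since $A^H$ of an upper-triangular $A$ is lower triangular, the same solution can also be obtained by forming $A^H$ explicitly and passing lower=True. The complex matrix and right-hand side below are arbitrary values chosen for illustration.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# An arbitrary complex upper-triangular matrix (complex64 by default).
A = jnp.array([[1. + 1.j, 2.,       0.5j],
               [0.,       2. - 1.j, 1.  ],
               [0.,       0.,       3.  ]])
b = jnp.array([1. + 0.j, 2. + 1.j, 3. - 2.j])

# trans='C' solves A^H x = b without forming A^H.
x = jax.scipy.linalg.solve_triangular(A, b, trans='C')
print(jnp.allclose(A.conj().T @ x, b, atol=1e-5))  # True

# Equivalent: A^H is lower triangular, so pass it explicitly with lower=True.
x_explicit = jax.scipy.linalg.solve_triangular(A.conj().T, b, lower=True)
print(jnp.allclose(x, x_explicit, atol=1e-5))  # True
```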