jax.lax.linalg.tridiagonal_solve


jax.lax.linalg.tridiagonal_solve(dl, d, du, b)

Computes the solution of a tridiagonal linear system.

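This function computes the solution X of the tridiagonal linear system

\[A X = B\]

where A is the m × m tridiagonal matrix described by the diagonals dl, d and du, and B is the right-hand side b.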
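Written out with the diagonal names from the parameter list, A has the banded form sketched below. It also shows why dl[0] and du[m - 1] do not correspond to any entry of A: they would index A[0, -1] and A[m - 1, m], which lie outside the matrix.

```latex
\[
A = \begin{pmatrix}
\mathrm{d}_0 & \mathrm{du}_0 & & & \\
\mathrm{dl}_1 & \mathrm{d}_1 & \mathrm{du}_1 & & \\
 & \mathrm{dl}_2 & \mathrm{d}_2 & \ddots & \\
 & & \ddots & \ddots & \mathrm{du}_{m-2} \\
 & & & \mathrm{dl}_{m-1} & \mathrm{d}_{m-1}
\end{pmatrix}
\]
```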
Parameters:
  • dl (Array) – A batch of vectors with shape [..., m]. The lower diagonal of A: dl[i] := A[i, i-1] for i in [0,m). Note that dl[0] = 0.

  • d (Array) – A batch of vectors with shape [..., m]. The middle diagonal of A: d[i] := A[i, i] for i in [0,m).

  • du (Array) – A batch of vectors with shape [..., m]. The upper diagonal of A: du[i] := A[i, i+1] for i in [0,m). Note that du[m - 1] = 0.

  • b (Array) – Right-hand side matrix with shape [..., m, n].

Return type:

Array

Returns:

Solution X of the tridiagonal system, with the same shape as b.
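The sketch below is a minimal illustration, not part of the official documentation. It builds a small, diagonally dominant 4 × 4 system with two right-hand sides, solves it with tridiagonal_solve, and checks the result against a dense solve with jax.numpy.linalg.solve. The matrix values and the dense reconstruction of A are illustrative assumptions chosen only to make the layout of dl, d and du concrete.

```python
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve

m = 4

# Illustrative diagonals of a diagonally dominant tridiagonal matrix A.
# dl[0] and du[m - 1] fall outside A, so they are set to 0.
dl = jnp.array([0.0, 1.0, 1.0, 1.0])  # dl[i] = A[i, i-1]
d = jnp.array([4.0, 4.0, 4.0, 4.0])   # d[i]  = A[i, i]
du = jnp.array([1.0, 1.0, 1.0, 0.0])  # du[i] = A[i, i+1]

# Right-hand side with shape [m, n]; here n = 2 columns.
b = jnp.arange(8.0).reshape(m, 2)

# Solve A X = B using only the three diagonals.
x = tridiagonal_solve(dl, d, du, b)

# Rebuild the dense A from its diagonals to cross-check the result.
A = jnp.diag(d) + jnp.diag(dl[1:], k=-1) + jnp.diag(du[:-1], k=1)
x_dense = jnp.linalg.solve(A, b)

print(jnp.allclose(x, x_dense, atol=1e-4))  # True
print(jnp.allclose(A @ x, b, atol=1e-4))    # True
```

Because this A is strictly diagonally dominant, the system is well conditioned and the float32 results agree within the chosen tolerance; an ill-conditioned system may need a looser comparison.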