jax.lax.linalg.lu

jax.lax.linalg.lu(x)

LU decomposition with partial pivoting.

Computes the matrix decomposition:

\[P A = L U\]

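where \(A\) is the input x, \(P\) is a permutation matrix that reorders the rows of \(A\), \(L\) is a lower-triangular matrix with a unit diagonal, and \(U\) is an upper-triangular matrix. For a non-square \(A\) of shape \(m \times n\), \(L\) has shape \(m \times \min(m, n)\) and \(U\) has shape \(\min(m, n) \times n\), so the factors are trapezoidal rather than strictly triangular.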
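As a small worked example (computed by hand with the standard largest-magnitude pivot rule, not copied from JAX output), take a 2-by-2 matrix whose largest-magnitude entry in the first column sits in the second row. Partial pivoting swaps the two rows before eliminating:

\[
A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \quad
P = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}, \quad
P A = \begin{pmatrix} 3 & 4 \\ 1 & 2 \end{pmatrix}
    = \underbrace{\begin{pmatrix} 1 & 0 \\ \tfrac{1}{3} & 1 \end{pmatrix}}_{L}
      \underbrace{\begin{pmatrix} 3 & 4 \\ 0 & \tfrac{2}{3} \end{pmatrix}}_{U}
\]

In terms of the return values documented under Returns, and assuming 0-based row indices, this corresponds to lu = ((3, 4), (1/3, 2/3)), pivots = (1, 1) and permutation = (1, 0).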

Parameters:

x (jax.typing.ArrayLike) – A batch of matrices with shape [..., m, n].

Returns:

A tuple (lu, pivots, permutation).

lu is a batch of matrices with the same shape and dtype as x. Its strictly lower triangle holds \(L\) and its upper triangle, including the diagonal, holds \(U\); the unit diagonal of \(L\) is implicit and not stored.

pivots is an int32 array with shape [..., min(m, n)] encoding the row swaps applied to \(A\), in order: at step i, row i is swapped with row pivots[i] (0-based).

permutation encodes the same row swaps as a single permutation: an int32 array with shape [..., m] such that row i of \(P A\) is row permutation[i] of \(A\). For a single (unbatched) matrix, x[permutation] equals \(L U\).

Return type:

tuple[Array, Array, Array]
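The sketch below shows one way to turn the three outputs back into explicit factors and check them on a batch of rectangular matrices. The helpers unpack_lu and pivots_to_permutation are hypothetical, not part of JAX, and the pivot replay assumes the 0-based convention where step i swaps row i with row pivots[i].

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import linalg


def unpack_lu(lu):
    """Hypothetical helper: split packed `lu` of shape [..., m, n] into L and U.

    With r = min(m, n), L has shape [..., m, r] (strictly lower part of `lu`
    plus the implicit unit diagonal) and U has shape [..., r, n] (upper
    triangle of `lu`, including its diagonal).
    """
    m, n = lu.shape[-2:]
    r = min(m, n)
    l = jnp.tril(lu[..., :, :r], k=-1) + jnp.eye(m, r, dtype=lu.dtype)
    u = jnp.triu(lu[..., :r, :])
    return l, u


def pivots_to_permutation(pivots, m):
    """Hypothetical helper: replay 0-based swaps (row i <-> row pivots[i]) on 0..m-1."""
    perm = np.arange(m)
    for i, p in enumerate(np.asarray(pivots)):
        perm[i], perm[p] = perm[p], perm[i]
    return perm


# A batch of two 4 x 3 matrices (m = 4, n = 3).
x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 3))
lu, pivots, permutation = linalg.lu(x)
print(lu.shape, pivots.shape, permutation.shape)  # (2, 4, 3) (2, 3) (2, 4)

l, u = unpack_lu(lu)
# Reorder the rows of each matrix in the batch: x_perm[b] = x[b][permutation[b]].
x_perm = jax.vmap(lambda a, p: a[p])(x, permutation)
print(jnp.allclose(x_perm, l @ u, atol=1e-5))

# The pivots describe the same reordering as a sequence of row swaps.
for b in range(x.shape[0]):
    replayed = pivots_to_permutation(pivots[b], x.shape[-2])
    print(np.array_equal(replayed, np.asarray(permutation[b])))
```

For solving linear systems, jax.scipy.linalg.lu_factor and jax.scipy.linalg.lu_solve already build on this decomposition, so helpers like these are mainly useful for inspecting the factors.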