# jax.scipy.linalg.lu

jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True)
jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True)
jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True)

Compute the LU decomposition.

JAX implementation of scipy.linalg.lu().

The LU decomposition of a matrix A is:

$A = P L U$

where P is a permutation matrix, L is lower-triangular with ones on its diagonal, and U is upper-triangular.
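
One standard way to compute such a factorization is Gaussian elimination with partial pivoting: for each column, swap in the row with the largest-magnitude candidate pivot, store the elimination multipliers in L, and zero out the entries below the pivot in U. The following is a minimal, unoptimized sketch for square matrices only. The helper name `lu_partial_pivot` is hypothetical, and the sketch illustrates the definition above; it is not a description of how `jax.scipy.linalg.lu` is implemented internally. For the 3x3 matrix used in the Examples, it should yield the same P, L and U.

```python
import jax.numpy as jnp


def lu_partial_pivot(a):
    """Hypothetical helper: LU decomposition with partial pivoting.

    Square matrices only; returns (P, L, U) with a ~= P @ L @ U.
    Runs an eager Python loop (not jit-friendly) and does not guard
    against zero pivots, so singular inputs are not handled.
    """
    U = jnp.asarray(a, dtype=jnp.float32)
    n = U.shape[0]
    # Strictly-lower multipliers; the unit diagonal is added at the end.
    L = jnp.zeros((n, n), dtype=U.dtype)
    perm = jnp.arange(n)  # perm[i] = which row of `a` currently sits in row i

    for k in range(n - 1):
        # Partial pivoting: pick the row at or below k with the largest |U[i, k]|.
        p = k + int(jnp.argmax(jnp.abs(U[k:, k])))
        if p != k:
            # Swap rows k and p in U, in the multipliers stored so far, and in perm.
            swap, rev = jnp.array([k, p]), jnp.array([p, k])
            U = U.at[swap].set(U[rev])
            L = L.at[swap].set(L[rev])
            perm = perm.at[swap].set(perm[rev])
        # Multipliers that eliminate the entries below the pivot U[k, k].
        factors = U[k + 1:, k] / U[k, k]
        L = L.at[k + 1:, k].set(factors)
        U = U.at[k + 1:].add(-factors[:, None] * U[k])

    # Row i of L @ U is row perm[i] of `a`, so P[perm[i], i] = 1 gives a = P @ L @ U.
    P = jnp.zeros((n, n), dtype=U.dtype).at[perm, jnp.arange(n)].set(1.0)
    # triu() clears the rounding residue that elimination leaves below the diagonal.
    return P, L + jnp.eye(n, dtype=U.dtype), jnp.triu(U)


a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
P, L, U = lu_partial_pivot(a)
print(jnp.allclose(P @ L @ U, a))  # True
```

Swapping whole rows of L is safe here because, at step k, columns k and beyond of L are still zero in every row at or below k.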

Parameters:
• a – array of shape (..., M, N) to decompose.

• permute_l – if True, then permute L and return (P @ L, U) (default: False)

• overwrite_a – not used by JAX

• check_finite – not used by JAX

Returns:

• P is a permutation matrix of shape (..., M, M)

• L is a lower-triangular matrix of shape (..., M, K)

• U is an upper-triangular matrix of shape (..., K, N)

with K = min(M, N)

Return type:

A tuple of arrays (P @ L, U) if permute_l is True, else (P, L, U)
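
The two return conventions can be compared directly. This sketch reuses the 3x3 matrix from the Examples: with `permute_l=True` the first returned array equals `P @ L`, and `U` is the same as in the default case. Note that `P @ L` is a row permutation of a lower-triangular matrix and is generally not lower-triangular itself.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])

# permute_l=False (default): three separate factors.
P, L, U = jax.scipy.linalg.lu(a)

# permute_l=True: two factors, with P already applied to L.
PL, U2 = jax.scipy.linalg.lu(a, permute_l=True)

print(jnp.allclose(PL, P @ L))   # True
print(jnp.allclose(U2, U))       # True
print(jnp.allclose(PL @ U2, a))  # True
# P swaps rows 0 and 1 of L here, so PL has a nonzero entry above its diagonal.
print(jnp.array_equal(PL, jnp.tril(PL)))  # False
```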

See also

Examples

An LU decomposition of a 3x3 matrix:

>>> a = jnp.array([[1., 2., 3.],
...                [5., 4., 2.],
...                [3., 2., 1.]])
>>> P, L, U = jax.scipy.linalg.lu(a)


P is a permutation matrix, i.e. each row and each column contains exactly one 1:

>>> P
Array([[0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.]], dtype=float32)
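
The permutation property can also be verified programmatically. This sketch, using the same matrix `a`, checks that every entry of `P` is 0 or 1 and that every row and column sums to 1. A consequence is that `P` is orthogonal, so its transpose is its inverse and `a = P @ L @ U` can be rearranged as `P.T @ a = L @ U`.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
P, L, U = jax.scipy.linalg.lu(a)

# Every entry is exactly 0 or 1 ...
print(jnp.all((P == 0) | (P == 1)))  # True
# ... and each row and each column contains exactly one 1.
print(jnp.all(P.sum(axis=0) == 1) & jnp.all(P.sum(axis=1) == 1))  # True
# Hence P is orthogonal: P.T undoes the row permutation.
print(jnp.allclose(P.T @ P, jnp.eye(3)))  # True
print(jnp.allclose(P.T @ a, L @ U))       # True
```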


L and U are lower-triangular and upper-triangular matrices, respectively:

>>> with jnp.printoptions(precision=3):
...   print(L)
...   print(U)
[[ 1.     0.     0.   ]
 [ 0.2    1.     0.   ]
 [ 0.6   -0.333  1.   ]]
[[5.    4.    2.   ]
 [0.    1.2   2.6  ]
 [0.    0.    0.667]]
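
The triangular structure can be checked without reading the printout. This sketch compares each factor with `jnp.tril` / `jnp.triu` of itself and confirms the unit diagonal of `L` visible in the output above, which follows the `scipy.linalg.lu` convention.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
P, L, U = jax.scipy.linalg.lu(a)

print(jnp.array_equal(L, jnp.tril(L)))  # True: no entries above the diagonal
print(jnp.array_equal(U, jnp.triu(U)))  # True: no entries below the diagonal
print(jnp.all(jnp.diag(L) == 1))        # True: L has a unit diagonal
```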


The original matrix can be reconstructed by multiplying the three together:

>>> a_reconstructed = P @ L @ U
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)
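
A common use of the factors is solving a linear system `a @ x = b`. Because `P` is orthogonal, `a @ x = b` is equivalent to `L @ (U @ x) = P.T @ b`, which takes one forward substitution and one back substitution. This sketch does both with `jax.scipy.linalg.solve_triangular`; the right-hand side `b` is an arbitrary vector chosen for illustration.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
b = jnp.array([1., 2., 3.])  # illustrative right-hand side

P, L, U = jax.scipy.linalg.lu(a)

# Forward substitution: solve L @ y = P.T @ b (L has a unit diagonal).
y = jax.scipy.linalg.solve_triangular(L, P.T @ b, lower=True, unit_diagonal=True)
# Back substitution: solve U @ x = y.
x = jax.scipy.linalg.solve_triangular(U, y, lower=False)

print(jnp.allclose(a @ x, b, atol=1e-4))  # True
```

For this `a` and `b`, the exact solution is x = (4, -6, 3), which can be confirmed by substituting it back into each row of `a @ x = b`.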