jax.scipy.linalg.lu

jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array, Array]
jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]
jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]

Compute the LU decomposition.

JAX implementation of scipy.linalg.lu().

The LU decomposition of a matrix A is:

\[A = P L U\]

where P is a permutation matrix, L is lower-triangular with a unit diagonal, and U is upper-triangular.
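The properties in this definition can be checked numerically. The following is a minimal sketch, assuming a random 4x4 float32 input; the PRNG key and the tolerance `atol=1e-5` are arbitrary illustrative choices, not values prescribed by JAX.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

# Arbitrary illustrative input: a random 4x4 float32 matrix.
a = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
P, L, U = jax.scipy.linalg.lu(a)

# P contains only 0s and 1s, with exactly one 1 in each row and each column.
is_perm = (jnp.all((P == 0) | (P == 1))
           & jnp.all(jnp.sum(P, axis=0) == 1)
           & jnp.all(jnp.sum(P, axis=1) == 1))

# L is lower-triangular and its diagonal entries are all 1.
is_unit_lower = jnp.allclose(L, jnp.tril(L)) & jnp.allclose(jnp.diag(L), 1.0)

# U is upper-triangular.
is_upper = jnp.allclose(U, jnp.triu(U))

# P @ L @ U reproduces a up to float32 round-off (atol=1e-5 is an assumed tolerance).
reconstructs = jnp.allclose(P @ L @ U, a, atol=1e-5)

print(bool(is_perm), bool(is_unit_lower), bool(is_upper), bool(reconstructs))
```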

Parameters:
  • a – array of shape (..., M, N) to decompose.

  • permute_l – if True, then permute L and return (P @ L, U) (default: False)

  • overwrite_a – not used by JAX

  • check_finite – not used by JAX
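As a sketch of what `permute_l` changes, the snippet below factors the same matrix both ways and compares the results. It relies only on the behavior described above: with `permute_l=True` the first factor is `P @ L` and `U` is unchanged. Note that `P @ L` is in general no longer lower-triangular.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])

# Default (permute_l=False): three separate factors.
P, L, U = jax.scipy.linalg.lu(a)

# permute_l=True: the permutation is folded into the first factor.
PL, U2 = jax.scipy.linalg.lu(a, permute_l=True)

print(jnp.allclose(PL, P @ L))   # first factor matches P @ L
print(jnp.allclose(U2, U))       # U is the same either way
print(jnp.allclose(PL @ U2, a))  # the two factors still reproduce a
```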

Returns:

  • P is a permutation matrix of shape (..., M, M)

  • L is a lower-triangular matrix of shape (..., M, K)

  • U is an upper-triangular matrix of shape (..., K, N)

with K = min(M, N)

Return type:

A tuple of arrays (P @ L, U) if permute_l is True, else (P, L, U)

Example

An LU decomposition of a 3x3 matrix:

>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> a = jnp.array([[1., 2., 3.],
...                [5., 4., 2.],
...                [3., 2., 1.]])
>>> P, L, U = jax.scipy.linalg.lu(a)

P is a permutation matrix, i.e., each row and each column contains exactly one 1 and all other entries are 0:

>>> P
Array([[0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.]], dtype=float32)

L is lower-triangular (with a unit diagonal) and U is upper-triangular:

>>> with jnp.printoptions(precision=3):
...   print(L)
...   print(U)
[[ 1.     0.     0.   ]
 [ 0.2    1.     0.   ]
 [ 0.6   -0.333  1.   ]]
[[5.    4.    2.   ]
 [0.    1.2   2.6  ]
 [0.    0.    0.667]]

The original matrix can be reconstructed by multiplying the three together:

>>> a_reconstructed = P @ L @ U
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)
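A common use of the factors is solving a linear system a @ x = b. The sketch below assumes a square, nonsingular `a` and an illustrative right-hand side `b`. It rewrites the system as L U x = P^T b (valid because P is orthogonal) and applies forward and back substitution with `jax.scipy.linalg.solve_triangular`. This is one possible approach rather than the only one; for repeated solves with the same matrix, `jax.scipy.linalg.lu_factor` and `jax.scipy.linalg.lu_solve` keep the factorization in a more compact form.

```python
import jax.numpy as jnp
import jax.scipy.linalg
from jax.scipy.linalg import solve_triangular

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
b = jnp.array([1., 2., 3.])  # illustrative right-hand side

P, L, U = jax.scipy.linalg.lu(a)

# a x = b  <=>  P L U x = b  <=>  L U x = P^T b, since P^T P = I.
# Forward substitution: solve L y = P^T b (L has ones on its diagonal).
y = solve_triangular(L, P.T @ b, lower=True, unit_diagonal=True)
# Back substitution: solve U x = y.
x = solve_triangular(U, y, lower=False)

print(jnp.allclose(a @ x, b, atol=1e-5))
```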