jax.scipy.linalg.lu

jax.scipy.linalg.lu(a, permute_l=False, overwrite_a=False, check_finite=True)

Compute the pivoted LU decomposition of a matrix.

LAX-backend implementation of scipy.linalg._decomp_lu.lu().

Does not support the SciPy argument check_finite=True, because compiled JAX code cannot check array values at runtime.

Does not support the SciPy argument overwrite_*=True.
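Because the JAX version does not scan its input for infinities or NaNs, a caller who wants SciPy's check_finite=True behavior has to validate the input before calling. The helper `checked_lu` below is hypothetical (it is not part of JAX): a minimal sketch that checks the input eagerly and then delegates to `lu`. The last part illustrates why such a check cannot live inside compiled code: under `jax.jit` the input is an abstract tracer, so turning `jnp.all(...)` into a Python `bool` raises `jax.errors.ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu


def checked_lu(a):
    """Hypothetical helper: mimic SciPy's check_finite=True outside of jit.

    The finiteness test needs a concrete boolean, so it only works on
    concrete (non-traced) arrays, i.e. when called eagerly.
    """
    a = jnp.asarray(a)
    if not bool(jnp.all(jnp.isfinite(a))):
        raise ValueError("array must not contain infs or NaNs")
    return lu(a)


# Finite input: the check passes and the three factors are returned.
p, l, u = checked_lu(jnp.eye(3))

# Non-finite input is rejected before any factorization happens.
try:
    checked_lu(jnp.array([[1.0, jnp.nan], [0.0, 1.0]]))
except ValueError as err:
    print("rejected:", err)

# Under jax.jit the array is a tracer with no concrete values, so the same
# Python-level check cannot run and bool() raises a concretization error.
try:
    jax.jit(checked_lu)(jnp.eye(3))
except jax.errors.ConcretizationTypeError:
    print("value check unavailable under jit")
```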

Original docstring below.

The decomposition is:

A = P L U

where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular.
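A minimal sketch of the factorization in use, assuming JAX's default float32 precision on CPU: it factors a small 3×3 matrix and checks numerically that `p @ l @ u` reproduces the input and that each factor has the structure described above. The matrix values and the tolerance `atol=1e-5` are illustrative choices, not part of the API.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

# A small, well-conditioned 3x3 matrix (float32 by default in JAX).
a = jnp.array([[2.0, 1.0, 1.0],
               [4.0, 3.0, 3.0],
               [8.0, 7.0, 9.0]])

p, l, u = lu(a)

# The factors reconstruct the input, A = P @ L @ U, up to float32 rounding.
print(jnp.allclose(p @ l @ u, a, atol=1e-5))

# L is lower triangular with ones on its diagonal.
print(jnp.allclose(l, jnp.tril(l)), jnp.allclose(jnp.diag(l), 1.0))

# U is upper triangular.
print(jnp.allclose(u, jnp.triu(u)))

# P is a permutation matrix: exactly one 1 in every row and every column.
print(jnp.allclose(p.sum(axis=0), 1.0), jnp.allclose(p.sum(axis=1), 1.0))
```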

Parameters:
  • a ((M, N) array_like) – Array to decompose

  • permute_l (bool, optional) – If True, perform the matrix multiplication P @ L and return the permuted L in place of P and L (default: False, do not permute)

  • overwrite_a (bool, optional) – Whether to overwrite data in a (may improve performance)

  • check_finite (bool, optional) – Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
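The permute_l flag only changes how the result is packaged. This sketch (arbitrary example matrix, float32 tolerances assumed) calls `lu` both ways: with `permute_l=True` the first output should match `p @ l` from the default call, `u` should be unchanged, and the input is recovered as `pl @ u` without a separate permutation matrix.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jnp.array([[1.0, 2.0, 0.0],
               [3.0, 1.0, 4.0],
               [0.0, 5.0, 1.0]])

# Default: three factors P, L and U.
p, l, u = lu(a)

# permute_l=True: two factors, with the permutation folded into L.
pl, u2 = lu(a, permute_l=True)

# PL is the product P @ L, and U is the same in both calls.
print(jnp.allclose(pl, p @ l, atol=1e-6))
print(jnp.allclose(u2, u))

# The input is recovered directly as PL @ U.
print(jnp.allclose(pl @ u2, a, atol=1e-5))
```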

Return type:

Union[Tuple[Array, Array], Tuple[Array, Array, Array]]

Returns:

  If permute_l == False:

  • p ((M, M) ndarray) – Permutation matrix

  • l ((M, K) ndarray) – Lower triangular or trapezoidal matrix with unit diagonal, where K = min(M, N)

  • u ((K, N) ndarray) – Upper triangular or trapezoidal matrix

  If permute_l == True:

  • pl ((M, K) ndarray) – Permuted L matrix (the product P @ L), where K = min(M, N)

  • u ((K, N) ndarray) – Upper triangular or trapezoidal matrix
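For non-square input, the factor shapes follow K = min(M, N) as listed in the Returns section. The sketch below uses two arbitrary example matrices, one tall (4×3) and one wide (3×5); the shape comments follow from those definitions, and the reconstruction tolerance `atol=1e-4` is an illustrative float32 choice.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

# Tall input: M = 4 rows, N = 3 columns, so K = min(M, N) = 3.
tall = jnp.arange(12.0).reshape(4, 3) + jnp.eye(4, 3)
p, l, u_tall = lu(tall)
print(p.shape, l.shape, u_tall.shape)  # (4, 4) (4, 3) (3, 3)

# Wide input: M = 3, N = 5, so K = 3 and U is trapezoidal (3 x 5).
wide = jnp.arange(15.0).reshape(3, 5) + jnp.eye(3, 5)
pl, u_wide = lu(wide, permute_l=True)
print(pl.shape, u_wide.shape)  # (3, 3) (3, 5)

# Both factorizations still reconstruct their inputs.
print(jnp.allclose(p @ l @ u_tall, tall, atol=1e-4))
print(jnp.allclose(pl @ u_wide, wide, atol=1e-4))
```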