jax.scipy.linalg.lu

jax.scipy.linalg.lu(a, permute_l=False, overwrite_a=False, check_finite=True)

Compute the LU decomposition of a matrix with partial pivoting.

LAX-backend implementation of scipy.linalg._decomp_lu.lu().
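Because the function is built from LAX operations, it can be traced like other JAX code. A minimal sketch, assuming a recent JAX release, that maps `lu` over a batch of matrices with `jax.vmap` and calls it inside a `jax.jit`-compiled function. The helper name `lu_reconstruct` and the example matrices are hypothetical, chosen only for illustration:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu


@jax.jit
def lu_reconstruct(a):
    # Hypothetical helper: factor `a`, then multiply the factors back together.
    p, l, u = lu(a)
    return p @ l @ u


# A batch of four 3x3 matrices, shape (4, 3, 3).
stack = jnp.stack([jnp.eye(3) + i * jnp.ones((3, 3)) for i in range(4)])

# vmap maps lu over the leading batch axis; each returned factor gains that axis.
ps, ls, us = jax.vmap(lu)(stack)
print(ps.shape, ls.shape, us.shape)

# The jit-compiled helper composes with vmap as well.
recon = jax.vmap(lu_reconstruct)(stack)
print(jnp.allclose(recon, stack, atol=1e-5))
```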

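Does not support the SciPy argument check_finite=True, because compiled JAX code cannot perform checks of array values at runtime; the argument is accepted but has no effect.

Does not support the SciPy argument overwrite_*=True; JAX arrays are immutable, so the input is never overwritten.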
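If SciPy-style validation is still wanted, one option is to check the input eagerly, in ordinary Python, before calling `lu`. A minimal sketch, assuming the array is concrete (that is, the wrapper is called outside `jax.jit`, where the values are available); `lu_checked` is a hypothetical helper, not part of JAX:

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu


def lu_checked(a):
    """Hypothetical helper: emulate SciPy's check_finite=True eagerly.

    The test converts a JAX boolean to a Python bool, which needs concrete
    values, so call this outside jax.jit rather than inside compiled code.
    """
    a = jnp.asarray(a)
    if not bool(jnp.all(jnp.isfinite(a))):
        raise ValueError("array must not contain infs or NaNs")
    return lu(a)


p, l, u = lu_checked(jnp.array([[4.0, 3.0], [6.0, 3.0]]))

try:
    lu_checked(jnp.array([[1.0, jnp.nan], [0.0, 1.0]]))
except ValueError as err:
    print("rejected:", err)
```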

Original docstring below.

The decomposition satisfies:

A = P @ L @ U

where P is a permutation matrix, L is lower triangular with unit diagonal elements, and U is upper triangular. If permute_l is set to True, L is returned already permuted and hence satisfies A = L @ U.
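Both identities can be checked numerically. A short sketch on an arbitrary 4x4 matrix (the values are illustrative only), comparing with a tolerance because the factors are computed in floating point (float32 by default in JAX):

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jnp.array([[2.0, 5.0, 8.0, 7.0],
               [5.0, 2.0, 2.0, 8.0],
               [7.0, 5.0, 6.0, 6.0],
               [5.0, 4.0, 4.0, 8.0]])

# Default (permute_l=False): three factors with A = P @ L @ U.
p, l, u = lu(a)
print(jnp.allclose(a, p @ l @ u, atol=1e-5))

# L is lower triangular with ones on the diagonal; U is upper triangular.
print(jnp.allclose(jnp.diag(l), 1.0))
print(jnp.allclose(l, jnp.tril(l)), jnp.allclose(u, jnp.triu(u)))

# permute_l=True: two factors, the first already equal to P @ L, so A = PL @ U.
pl, u2 = lu(a, permute_l=True)
print(jnp.allclose(a, pl @ u2, atol=1e-5))
print(jnp.allclose(pl, p @ l, atol=1e-6))
```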

Parameters:
  • a ((M, N) array_like) – Array to decompose

  • permute_l (bool, optional) – Perform the matrix multiplication P @ L and return the product in place of P and L (Default: do not permute)

  • overwrite_a (bool, optional) – Whether to overwrite data in a (may improve performance)

  • check_finite (bool, optional) – Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.

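  • p_indices (bool, optional) – If True, the permutation information is returned as row indices. The default is False for backwards-compatibility reasons. This parameter does not appear in the JAX signature above, so p is returned as a permutation matrix.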
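The JAX signature accepts only a, permute_l, overwrite_a and check_finite, so there is no flag for row-index output. If index form is needed, one way to derive it is to locate the single 1 in each row or column of the permutation matrix. A minimal sketch on an arbitrary nonsingular 3x3 matrix; the names `idx` and `perm` are illustrative, and whether `idx` matches SciPy's own p_indices=True convention is an assumption not verified here:

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [7.0, 8.0, 10.0]])
p, l, u = lu(a)  # p is an (M, M) permutation matrix

# Row i of P has its single 1 in column idx[i], so P @ L equals L[idx].
idx = jnp.argmax(p, axis=1)
print(jnp.allclose(a, l[idx] @ u, atol=1e-5))

# Column j of P has its single 1 in row perm[j], so P.T @ A equals A[perm] = L @ U.
perm = jnp.argmax(p, axis=0)
print(jnp.allclose(a[perm], l @ u, atol=1e-5))
```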

Return type:

tuple[Array, Array] | tuple[Array, Array, Array]

Returns:

  • (If permute_l is False)

  • p ((…, M, M) ndarray) – Permutation arrays or vectors depending on p_indices

  • l ((…, M, K) ndarray) – Lower triangular or trapezoidal array with unit diagonal. K = min(M, N)

  • u ((…, K, N) ndarray) – Upper triangular or trapezoidal array

  • (If permute_l is True)

  • pl ((…, M, K) ndarray) – Permuted L matrix. K = min(M, N)

  • u ((…, K, N) ndarray) – Upper triangular or trapezoidal array
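For non-square input the factors are trapezoidal rather than square. A short sketch with one tall (4x3) and one wide (3x5) matrix, both arbitrary, that checks the returned shapes against K = min(M, N):

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

tall = jnp.arange(12.0).reshape(4, 3) + jnp.eye(4, 3)  # M=4, N=3, so K=3
wide = jnp.arange(15.0).reshape(3, 5) + jnp.eye(3, 5)  # M=3, N=5, so K=3

for a in (tall, wide):
    m, n = a.shape
    k = min(m, n)
    p, l, u = lu(a)
    pl, u2 = lu(a, permute_l=True)
    # Expected shapes: p is (M, M), l and pl are (M, K), u is (K, N).
    assert p.shape == (m, m) and l.shape == (m, k) and u.shape == (k, n)
    assert pl.shape == (m, k) and u2.shape == (k, n)
    print(a.shape, "->", p.shape, l.shape, u.shape)
```

The printed lines are (4, 3) -> (4, 4) (4, 3) (3, 3) for the tall matrix and (3, 5) -> (3, 3) (3, 3) (3, 5) for the wide one.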