jax.scipy.linalg.qr

jax.scipy.linalg.qr(a, overwrite_a=False, lwork=None, mode='full', pivoting=False, check_finite=True)

Compute QR decomposition of a matrix.

LAX-backend implementation of scipy.linalg._decomp_qr.qr().

The SciPy argument check_finite=True is not supported, because compiled JAX code cannot check array values at runtime.

The SciPy argument overwrite_a=True is not supported.
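As a minimal sketch of what these restrictions mean in practice, the example below calls the function from compiled code with overwrite_a and check_finite left at their defaults. The wrapper name economic_qr and the input matrix are illustrative assumptions, not part of the JAX API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import qr


@jax.jit
def economic_qr(a):
    # Hypothetical helper for illustration only.
    # overwrite_a and check_finite keep their default values; JAX arrays
    # are immutable, so the input is never modified in place.
    return qr(a, mode='economic')


a = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
q, r = economic_qr(a)

# The input still holds its original values after the call.
a_unchanged = bool(
    jnp.array_equal(a, jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
)
```

The mode string is closed over rather than passed as a traced argument, since it selects the computation and is not array data.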

Original docstring below.

Calculate the decomposition A = Q R, where Q is unitary/orthogonal and R is upper triangular.
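The defining properties can be checked numerically. The following is a small sketch, assuming an arbitrary 3x3 real matrix chosen for illustration and tolerances suited to JAX's default float32 precision.

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr

# A small 3x3 matrix chosen only for illustration.
a = jnp.array([[2.0, -1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])

q, r = qr(a)  # mode='full' is the default

# Q @ R reproduces A up to floating-point error.
recon_ok = bool(jnp.allclose(q @ r, a, atol=1e-5))
# Q is orthogonal: Q^T Q is the identity.
orth_ok = bool(jnp.allclose(q.T @ q, jnp.eye(3), atol=1e-5))
# R is upper triangular: entries below the diagonal vanish.
upper_ok = bool(jnp.allclose(r, jnp.triu(r), atol=1e-6))
```

For complex input, the orthogonality check would use the conjugate transpose of Q instead of the plain transpose.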

Parameters:
  • a ((M, N) array_like) – Matrix to be decomposed

  • mode ({'full', 'r', 'economic', 'raw'}, optional) – Determines what information is to be returned: either both Q and R ('full', default), only R ('r'), or both Q and R but computed in economy-size ('economic', see Notes). The final option 'raw' (added in SciPy 0.11) makes the function return two matrices (Q, TAU) in the internal format used by LAPACK.

  • pivoting (bool, optional) – Whether or not the factorization should include pivoting for a rank-revealing QR decomposition. If pivoting is True, compute the decomposition A P = Q R as above, but where P is chosen such that the diagonal of R is non-increasing.

  • overwrite_a (bool) – Accepted for SciPy compatibility; overwrite_a=True is not supported (see above).

  • lwork (Optional[Any]) –

  • check_finite (bool) – Accepted for SciPy compatibility; the input is not checked for non-finite values (see above).
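The effect of the mode parameter can be sketched as follows, assuming a tall 4x2 input built only for illustration. Only the 'full', 'economic', and 'r' modes are exercised here; the sketch makes no claim about 'raw' or pivoting=True.

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr

# A tall 4x2 matrix (M=4, N=2) chosen only for illustration.
a = jnp.arange(8.0).reshape(4, 2) + jnp.eye(4, 2)

# 'full' and 'economic' both return a (Q, R) pair.
q_full, r_full = qr(a, mode='full')
q_econ, r_econ = qr(a, mode='economic')

# 'r' returns a one-element tuple, so it is unpacked with a trailing comma.
(r_only,) = qr(a, mode='r')
```

Unpacking the mode='r' result as a one-element tuple matches the Tuple[Array] alternative in the documented return type.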

Return type:

Union[Tuple[Array], Tuple[Array, Array]]

Returns:

  • Q (float or complex ndarray) – Of shape (M, M), or (M, K) for mode='economic'. Not returned if mode='r'.

  • R (float or complex ndarray) – Of shape (M, N), or (K, N) for mode='economic'. K = min(M, N).

  • P (int ndarray) – Of shape (N,) for pivoting=True. Not returned if pivoting=False.
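The shape rules for Q and R, with K = min(M, N), can be illustrated for both tall and wide inputs. This is a sketch only: the helper qr_shapes and the test matrices are assumptions made for the example, and pivoting is left at its default of False.

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr


def qr_shapes(m, n, mode):
    """Hypothetical helper: return (Q.shape, R.shape) for an m-by-n input."""
    a = jnp.ones((m, n)) + jnp.eye(m, n)
    q, r = qr(a, mode=mode)
    return q.shape, r.shape


# Tall input: M=5, N=3, so K=3.
tall_full = qr_shapes(5, 3, 'full')      # Q is (M, M), R is (M, N)
tall_econ = qr_shapes(5, 3, 'economic')  # Q is (M, K), R is (K, N)

# Wide input: M=3, N=5, so K=3 and both modes give the same shapes.
wide_full = qr_shapes(3, 5, 'full')
wide_econ = qr_shapes(3, 5, 'economic')
```

The economy-size mode only saves memory when M > N; for wide or square inputs K equals M and the two modes return arrays of identical shape.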