jax.numpy.linalg.lstsq

jax.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)

Return the least-squares solution to a linear matrix equation.

LAX-backend implementation of numpy.linalg.lstsq().

It differs from the NumPy function in two important ways:

  1. In numpy.linalg.lstsq, the default rcond is -1, and NumPy warns that the default will change to None in the future. Here, the default rcond is None.

  2. In np.linalg.lstsq, the returned residuals are empty when a is rank-deficient or when M <= N (square or under-determined systems). Here, residuals are returned in all cases so that the function is compatible with jit. The NumPy behavior, which is not jit-compatible, can be recovered by passing numpy_resid=True.
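A minimal sketch of the residuals difference, assuming JAX's default float32 precision. The 2x2 system below is square and full rank, so M <= N and NumPy would return empty residuals. With the default numpy_resid=False, JAX returns a (1,) residual array and the call can be wrapped in jax.jit; the numpy_resid=True call is made eagerly because, in general, its output shape depends on the computed rank.

```python
import jax
import jax.numpy as jnp

# Square, full-rank system (M == N): NumPy would return empty residuals here.
a = jnp.array([[2.0, 1.0], [1.0, 3.0]])
b = jnp.array([3.0, 5.0])

# Default numpy_resid=False: residuals are always computed, so output shapes
# depend only on input shapes and the whole call can be jit-compiled.
x, resid, rank, s = jax.jit(jnp.linalg.lstsq)(a, b)
print(x)             # approximately [0.8, 1.4]
print(resid.shape)   # (1,) because b is 1-D

# numpy_resid=True: NumPy-style empty residuals for this M <= N case.
# Called without jit, since the residual shape is value-dependent in general.
_, resid_np_style, _, _ = jnp.linalg.lstsq(a, b, numpy_resid=True)
print(resid_np_style.shape)  # (0,)
```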

The lstsq function does not currently have a custom JVP rule, so the gradient is poorly behaved for some inputs, particularly for low-rank a.
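As a quick consistency check of the well-behaved case only, the sketch below compares jax.grad through lstsq with the gradient through the normal equations solve(a.T @ a, a.T @ b), which give the same solution when a has full column rank. The matrix is an illustrative assumption: tall, full column rank, with distinct singular values. The rank-deficient case the note above warns about is not exercised here.

```python
import jax
import jax.numpy as jnp

# Tall (3 x 2), full-column-rank matrix with distinct singular values;
# b is not in the column space of a, so the residual is nonzero.
a = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
b = jnp.array([1.0, 2.0, 4.0])

def loss_lstsq(a):
    # Scalar function of the least-squares solution x = lstsq(a, b)[0].
    x = jnp.linalg.lstsq(a, b)[0]
    return jnp.sum(x)

def loss_normal(a):
    # Same solution via the normal equations (valid for full column rank a).
    x = jnp.linalg.solve(a.T @ a, a.T @ b)
    return jnp.sum(x)

g_lstsq = jax.grad(loss_lstsq)(a)
g_normal = jax.grad(loss_normal)(a)
# Expected to be small for this full-rank input.
print(jnp.max(jnp.abs(g_lstsq - g_normal)))
```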

Original docstring below.

Computes the vector x that approximately solves the equation a @ x = b. The equation may be under-, well-, or over-determined (i.e., the number of linearly independent rows of a can be less than, equal to, or greater than its number of linearly independent columns). If a is square and of full rank, then x (up to round-off error) is the “exact” solution of the equation. Otherwise, x minimizes the Euclidean 2-norm \(|| b - a x ||\).
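The sketch below fits a straight line y ≈ m t + c to four illustrative points. This is an over-determined system with four equations and two unknowns, so x minimizes the 2-norm of the misfit rather than solving the equations exactly.

```python
import jax.numpy as jnp

# Four observations, two unknowns (slope m and intercept c): over-determined.
t = jnp.array([0.0, 1.0, 2.0, 3.0])
y = jnp.array([-1.0, 0.2, 0.9, 2.1])

# Coefficient matrix with columns [t, 1], so that a @ [m, c] ≈ y.
a = jnp.stack([t, jnp.ones_like(t)], axis=1)

(m, c), resid, rank, s = jnp.linalg.lstsq(a, y)
print(m, c)    # approximately 1.0 and -0.95
print(resid)   # squared 2-norm of y - a @ [m, c], approximately [0.05]
print(rank)    # 2
```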

Parameters
  • a ((M, N) array_like) – “Coefficient” matrix.

  • b ({(M,), (M, K)} array_like) – Ordinate or “dependent variable” values. If b is two-dimensional, the least-squares solution is calculated for each of the K columns of b.

  • rcond (float, optional) –

    Cut-off ratio for small singular values of a. For the purposes of rank determination, singular values are treated as zero if they are smaller than rcond times the largest singular value of a.

    Changed in version 1.14.0: If not set, a FutureWarning is given. The previous default of -1 uses the machine precision as the rcond parameter; the new default uses the machine precision times max(M, N). To silence the warning and use the new default, use rcond=None; to keep using the old behavior, use rcond=-1.
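A small illustration of the rcond cutoff, assuming float32 inputs (JAX's default when 64-bit mode is disabled). With the default rcond=None, the cutoff is machine epsilon times max(M, N) times the largest singular value, so both singular values of the diagonal matrix below are kept. With rcond=1e-2, the singular value 1e-3 falls below the cutoff and is treated as zero, which lowers the rank and drops the corresponding component of x.

```python
import jax.numpy as jnp

# Diagonal 2x2 matrix whose singular values are 1.0 and 1e-3.
a = jnp.array([[1.0, 0.0], [0.0, 1e-3]])
b = jnp.array([1.0, 1.0])
M, N = a.shape

# rcond=None (the JAX default): cutoff = eps * max(M, N) * s_max.
x_full, _, rank_full, s = jnp.linalg.lstsq(a, b)
default_cutoff = jnp.finfo(a.dtype).eps * max(M, N) * s[0]

# rcond=1e-2: singular values below 1e-2 * s_max are treated as zero.
x_trunc, _, rank_trunc, _ = jnp.linalg.lstsq(a, b, rcond=1e-2)

print(int(rank_full), x_full)    # rank 2; x is approximately [1, 1000]
print(int(rank_trunc), x_trunc)  # rank 1; x is approximately [1, 0]

# Both ranks can be recomputed by counting singular values at or above the cutoff.
print(int(jnp.sum(s >= default_cutoff)), int(jnp.sum(s >= 1e-2 * s[0])))  # 2 1
```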

Returns

  • x ({(N,), (N, K)} ndarray) – Least-squares solution. If b is two-dimensional, the solutions are in the K columns of x.

  • residuals ({(1,), (K,), (0,)} ndarray) – Sums of squared residuals: the squared Euclidean 2-norm of each column of b - a @ x. If the rank of a is < N or M <= N, this is an empty array (in JAX, only when numpy_resid=True). If b is 1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).

  • rank (int) – Rank of matrix a.

  • s ((min(M, N),) ndarray) – Singular values of a.
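To check the return shapes listed above when b is two-dimensional, the sketch below solves a random 5x3 system with K = 4 right-hand sides and compares one column against a separate 1-D solve. The matrix is assumed to have full rank, which holds with probability one for Gaussian entries; the seed and sizes are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(1)
M, N, K = 5, 3, 4
a = jnp.asarray(rng.standard_normal((M, N)), dtype=jnp.float32)
b = jnp.asarray(rng.standard_normal((M, K)), dtype=jnp.float32)

x, resid, rank, s = jnp.linalg.lstsq(a, b)
print(x.shape, resid.shape, int(rank), s.shape)  # (3, 4) (4,) 3 (3,)

# Each column of x is the least-squares solution for the matching column of b.
x0, resid0, _, _ = jnp.linalg.lstsq(a, b[:, 0])
print(jnp.allclose(x[:, 0], x0, atol=1e-4))

# residuals[k] is the squared 2-norm of column k of b - a @ x.
print(jnp.allclose(resid, jnp.sum((b - a @ x) ** 2, axis=0), rtol=1e-4, atol=1e-4))
```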