jax.numpy.linalg.lstsq
- jax.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)
Return the least-squares solution to a linear equation.
JAX implementation of numpy.linalg.lstsq().
- Parameters:
a (ArrayLike) – array of shape (M, N) representing the coefficient matrix.
b (ArrayLike) – array of shape (M,) or (M, K) representing the right-hand side.
rcond (float | None) – Cut-off ratio for small singular values. Singular values smaller than rcond * largest_singular_value are treated as zero. If None (default), the optimal value will be used to reduce floating point errors.
numpy_resid (bool) – If True, compute and return residuals in the same way as NumPy’s linalg.lstsq. This is necessary if you want to precisely replicate NumPy’s behavior. If False (default), a more efficient method is used to compute residuals.
- Returns:
Tuple of arrays (x, resid, rank, s) where
x is a shape (N,) or (N, K) array containing the least-squares solution.
resid is the sum of squared residuals, of shape () or (K,).
rank is the rank of the matrix a.
s contains the singular values of the matrix a.
- Return type:
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([5, 6])
>>> x, _, _, _ = jnp.linalg.lstsq(a, b)
>>> with jnp.printoptions(precision=3):
...   print(x)
[-4.   4.5]
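The example above uses a square system, so it does not exercise the residuals, the (M, K) right-hand side, or rcond. The following is a minimal sketch of those cases, not part of the original documentation. The data and the variable names (t, a_small, b_small, and so on) are made up for illustration, and default float32 precision is assumed.

```python
import jax.numpy as jnp

# Overdetermined fit: model y = c0 + c1 * t at four points (M = 4, N = 2),
# solved for two right-hand sides at once (K = 2).
t = jnp.array([0.0, 1.0, 2.0, 3.0])
a = jnp.stack([jnp.ones_like(t), t], axis=1)      # shape (4, 2)
b = jnp.stack(
    [1.0 + 2.0 * t,                               # column 0 lies exactly on a line
     jnp.array([0.0, 1.0, 1.0, 2.0])],            # column 1 does not
    axis=1,
)                                                 # shape (4, 2)

x, resid, rank, s = jnp.linalg.lstsq(a, b)
# x has shape (N, K) = (2, 2), resid has shape (K,) = (2,),
# and s holds min(M, N) = 2 singular values.

# Effect of rcond: singular values smaller than rcond * largest singular value
# are treated as zero. This matrix has singular values 1 and 1e-3.
a_small = jnp.array([[1.0, 0.0],
                     [0.0, 1e-3],
                     [0.0, 0.0]])
b_small = jnp.array([1.0, 1.0, 0.0])

# Threshold 1e-6: both singular values are kept.
x_full, _, rank_full, _ = jnp.linalg.lstsq(a_small, b_small, rcond=1e-6)
# Threshold 1e-2: the singular value 1e-3 is dropped.
x_cut, _, rank_cut, _ = jnp.linalg.lstsq(a_small, b_small, rcond=1e-2)

print(x, resid, rank, s)
print(rank_full, x_full)
print(rank_cut, x_cut)
```

In this sketch the second column of b should fit to roughly c0 = 0.1, c1 = 0.6 with a squared residual near 0.2, while the first column fits exactly. Raising rcond from 1e-6 to 1e-2 lowers the reported rank from 2 to 1 and sets the poorly determined component of the solution to zero instead of amplifying it.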