jax.numpy.linalg.lstsq

jax.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)

Return the least-squares solution to a linear equation.

JAX implementation of numpy.linalg.lstsq().

Parameters:
  • a (jax.typing.ArrayLike) – array of shape (M, N) representing the coefficient matrix.

  • b (jax.typing.ArrayLike) – array of shape (M,) or (M, K) representing the right-hand side.

  • rcond (float | None) – Cut-off ratio for small singular values. Singular values smaller than rcond * largest_singular_value are treated as zero. If None (default), the cut-off ratio is set to the machine epsilon of the input dtype multiplied by max(M, N), which limits the effect of floating point round-off.

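  • numpy_resid (bool) – If True, compute and return residuals in the same way as NumPy’s linalg.lstsq, which returns an empty residual array when the rank of a is less than N or when M <= N. This is necessary if you want to precisely replicate NumPy’s behavior. If False (default), the residuals are always computed, so the shape of resid does not depend on the rank of a.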
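The effect of rcond can be illustrated by comparing jnp.linalg.lstsq with a small truncated-SVD solver. This is a minimal sketch assuming real-valued 2-D input; svd_lstsq is a hypothetical helper written for illustration and is not necessarily how JAX computes the solution internally. It applies the rule stated above: singular values smaller than rcond times the largest singular value are discarded before the pseudo-inverse is applied to b.

```python
import jax.numpy as jnp


def svd_lstsq(a, b, rcond):
    """Hypothetical helper: minimum-norm least squares via a truncated SVD.

    Assumes real-valued a of shape (M, N) and b of shape (M,) or (M, K).
    """
    # Thin SVD: a = u @ diag(s) @ vt, with s sorted largest first.
    u, s, vt = jnp.linalg.svd(a, full_matrices=False)
    # Keep singular values that are at least rcond * (largest singular value).
    keep = s >= rcond * s[0]
    # Invert the kept singular values; discarded ones contribute zero.
    s_inv = jnp.where(keep, 1.0 / jnp.where(keep, s, 1.0), 0.0)
    # Project b onto the left singular vectors.
    utb = u.T @ b
    if b.ndim == 2:
        s_inv = s_inv[:, None]  # broadcast over the K right-hand sides
    # x = V diag(s_inv) U^T b, the truncated pseudo-inverse applied to b.
    return vt.T @ (s_inv * utb)


# Full-rank matrix whose singular values are 1.0 and 1e-3.
a = jnp.array([[1.0, 0.0],
               [0.0, 1e-3]])
b = jnp.array([1.0, 1.0])

results = {}
for rcond in (1e-6, 1e-2):
    x_ref, _, rank, s = jnp.linalg.lstsq(a, b, rcond=rcond)
    x_svd = svd_lstsq(a, b, rcond)
    results[rcond] = (int(rank), x_ref, x_svd)
    print(f"rcond={rcond}: rank={int(rank)}, lstsq x={x_ref}, sketch x={x_svd}")
```

With rcond=1e-6 both singular values survive, so the rank should be 2 and x should be close to [1, 1000]. With rcond=1e-2 the singular value 1e-3 falls below the cut-off, the rank drops to 1, and x should be close to [1, 0].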

Returns:

Tuple of arrays (x, resid, rank, s) where

  • x is a shape (N,) or (N, K) array containing the least-squares solution.

  • resid is the sum of squared residuals, with shape () or (K,).

  • rank is the rank of the matrix a.

  • s is an array containing the singular values of a, in descending order.

Return type:

tuple[Array, Array, Array, Array]
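The return values can be checked on a small overdetermined system with a 2-D right-hand side (M=4, N=2, K=2). This sketch uses only the public jnp.linalg.lstsq call; the manual residual computation is added for comparison and follows the description of resid as the per-column sum of squared residuals.

```python
import jax.numpy as jnp

# Overdetermined system: M=4 equations, N=2 unknowns, K=2 right-hand sides.
a = jnp.array([[1.0, 0.0],
               [1.0, 1.0],
               [1.0, 2.0],
               [1.0, 3.0]])
b = jnp.array([[1.0, 2.0],
               [2.0, 1.0],
               [2.0, 4.0],
               [4.0, 3.0]])

x, resid, rank, s = jnp.linalg.lstsq(a, b)

# Expected shapes: x is (N, K), resid is (K,), rank is a scalar,
# and s holds min(M, N) singular values.
print(x.shape, resid.shape, rank.shape, s.shape)

# resid should match the per-column sum of squared residuals b - a @ x.
manual_resid = jnp.sum((b - a @ x) ** 2, axis=0)
print(resid, manual_resid)
```

Each column of b is fitted independently: the first column of x should be close to [0.9, 0.9] and the second to [1.6, 0.6], giving residuals of about 0.7 and 3.2.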

Example

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([5, 6])
>>> x, _, _, _ = jnp.linalg.lstsq(a, b)
>>> with jnp.printoptions(precision=3):
...   print(x)
[-4.   4.5]
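With the default numpy_resid=False, the residual shape does not depend on the computed rank, so the call can be placed inside a jax.jit-compiled function. In this sketch solve_ls is a hypothetical wrapper name. For the square, nonsingular matrix used in the example above, the least-squares solution should agree with jnp.linalg.solve up to floating point error.

```python
import jax
import jax.numpy as jnp


@jax.jit
def solve_ls(a, b):
    """Hypothetical wrapper: jit-compiled least squares returning only x."""
    # Default rcond=None and numpy_resid=False, so output shapes depend
    # only on the input shapes and the call can be traced.
    x, _, _, _ = jnp.linalg.lstsq(a, b)
    return x


a = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
b = jnp.array([5.0, 6.0])

x_ls = solve_ls(a, b)
# For a square, nonsingular a, least squares reduces to solving a @ x = b.
x_exact = jnp.linalg.solve(a, b)
print(x_ls, x_exact)
```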