jax.numpy.linalg.svd

jax.numpy.linalg.svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None) -> SVDResult
jax.numpy.linalg.svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[True], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None) -> SVDResult
jax.numpy.linalg.svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None) -> Array
jax.numpy.linalg.svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None) -> Array
jax.numpy.linalg.svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, hermitian: bool = False, subset_by_index: tuple[int, int] | None = None) -> Array | SVDResult

Compute the singular value decomposition.

JAX implementation of numpy.linalg.svd(), implemented in terms of jax.lax.linalg.svd().

The SVD of a matrix A is given by

\[A = U\Sigma V^H\]
  • \(U\) contains the left singular vectors and satisfies \(U^HU=I\)

  • \(V\) contains the right singular vectors and satisfies \(V^HV=I\)

  • \(\Sigma\) is a diagonal matrix of the singular values, which are real and non-negative.
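These relations can be checked numerically. The sketch below uses an arbitrary complex 3x2 matrix (illustrative data, not taken from this page) and the reduced factorization (full_matrices=False), so \(U\) is 3x2 and \(V^H\) is 2x2. Note that the function returns vh, which is already \(V^H\), so \(V\) is recovered as its conjugate transpose.

```python
import jax.numpy as jnp

# An arbitrary complex 3x2 matrix (illustrative data).
a = jnp.array([[1 + 2j, 0.5 - 1j],
               [3 - 1j, 2 + 0j],
               [0 + 1j, -1 + 1j]], dtype=jnp.complex64)

# Reduced SVD: u is (3, 2), s is (2,), vh is (2, 2).
u, s, vh = jnp.linalg.svd(a, full_matrices=False)

# vh already holds V^H, so A = U @ diag(s) @ vh.
a_rebuilt = u @ jnp.diag(s) @ vh

# V is the conjugate transpose of vh.
v = vh.conj().T
eye = jnp.eye(2, dtype=a.dtype)

# Check U^H U = I, V^H V = I, and A = U Sigma V^H (up to float32 round-off).
u_orthonormal = jnp.allclose(u.conj().T @ u, eye, atol=1e-5)
v_orthonormal = jnp.allclose(v.conj().T @ v, eye, atol=1e-5)
reconstructs = jnp.allclose(a_rebuilt, a, atol=1e-5)
print(bool(u_orthonormal), bool(v_orthonormal), bool(reconstructs))
```

The singular values s come back as a real array (float32 here) even though the input is complex.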

Parameters:
  • a – input array, of shape (..., N, M)

  • full_matrices – if True (default), u and vh are the full square matrices of shape (..., N, N) and (..., M, M). If False, they are the reduced matrices of shape (..., N, K) and (..., K, M), where K = min(N, M).

  • compute_uv – if True (default), return the singular vectors together with the singular values as (u, s, vh). If False, return only the singular values s.

  • hermitian – if True, assume the matrix is Hermitian, which allows a more efficient implementation (default=False).

  • subset_by_index – (TPU-only) Optional 2-tuple [start, end] indicating the range of indices of singular values to compute. For example, if [n-2, n] then svd computes the two largest singular values and their singular vectors. Only compatible with full_matrices=False.
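To make the shape rules above concrete, the sketch below calls the function with each setting on an arbitrary 4x3 input (N = 4, M = 3, K = 3). For hermitian=True it uses a small real symmetric matrix with eigenvalues 3 and -1, whose singular values should be the absolute values of those eigenvalues, 3 and 1. Because the flag tells svd to assume a Hermitian input, it is passed here only for a matrix that actually is one. subset_by_index is left out because it is TPU-only.

```python
import jax.numpy as jnp

# An arbitrary 4x3 input: N = 4, M = 3, so K = min(N, M) = 3.
x = jnp.arange(12.0).reshape(4, 3)

# full_matrices=True (default): u is (N, N) and vh is (M, M).
u_full, s_full, vh_full = jnp.linalg.svd(x)
print(u_full.shape, s_full.shape, vh_full.shape)  # (4, 4) (3,) (3, 3)

# full_matrices=False: u is (N, K) and vh is (K, M).
u_red, s_red, vh_red = jnp.linalg.svd(x, full_matrices=False)
print(u_red.shape, s_red.shape, vh_red.shape)  # (4, 3) (3,) (3, 3)

# compute_uv=False: a single array of singular values, not a tuple.
s_only = jnp.linalg.svd(x, compute_uv=False)
print(s_only.shape)  # (3,)

# hermitian=True: valid here only because h really is symmetric (real Hermitian).
# Its eigenvalues are 3 and -1, so its singular values should be 3 and 1.
h = jnp.array([[1.0, 2.0],
               [2.0, 1.0]])
s_herm = jnp.linalg.svd(h, compute_uv=False, hermitian=True)
matches = jnp.allclose(s_herm, jnp.array([3.0, 1.0]), atol=1e-5)
```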

Returns:

A tuple of arrays (u, s, vh), returned as an SVDResult named tuple, if compute_uv is True; otherwise the array s.

  • u: left singular vectors of shape (..., N, N) if full_matrices is True or (..., N, K) otherwise.

  • s: real, non-negative singular values of shape (..., K), sorted in descending order.

  • vh: conjugate-transposed right singular vectors of shape (..., M, M) if full_matrices is True or (..., K, M) otherwise.

where K = min(N, M).
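The leading ... dimensions in these shapes are batch dimensions. The following is a minimal sketch, assuming randomly generated inputs from jax.random (illustrative data, not part of this page), that shows how the shapes propagate for a batch of five 2x3 matrices and checks that s is sorted in descending order along its last axis.

```python
import jax
import jax.numpy as jnp

# A batch of five 2x3 matrices: batch shape (5,), N = 2, M = 3, K = 2.
key = jax.random.PRNGKey(0)
batch = jax.random.normal(key, (5, 2, 3))

u, s, vh = jnp.linalg.svd(batch, full_matrices=False)
print(u.shape, s.shape, vh.shape)  # (5, 2, 2) (5, 2) (5, 2, 3)

# Within each matrix, singular values are non-negative and in descending order.
descending = jnp.all(s[..., :-1] >= s[..., 1:])
nonnegative = jnp.all(s >= 0)

# Batched reconstruction: s[..., :, None] * vh scales row i of vh by s[i],
# which matches jnp.diag(s) @ vh for each matrix in the batch.
rebuilt = u @ (s[..., :, None] * vh)
print(bool(jnp.allclose(rebuilt, batch, atol=1e-5)))
```

Broadcasting s against vh avoids building a separate diagonal matrix for every element of the batch.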

Example

Consider the SVD of a small real-valued array:

>>> import jax.numpy as jnp
>>> x = jnp.array([[1., 2., 3.],
...                [6., 5., 4.]])
>>> u, s, vt = jnp.linalg.svd(x, full_matrices=False)
>>> s
Array([9.361919 , 1.8315067], dtype=float32)

The left singular vectors are the columns of u, and the right singular vectors are the columns of v = vt.T. Each set is orthonormal, which can be demonstrated by comparing the matrix products u.T @ u and v.T @ v with the identity matrix:

>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)

Given the SVD, x can be reconstructed via matrix multiplication:

>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)
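The example above uses full_matrices=False, where diag(s) is square and the product lines up directly. With the default full_matrices=True, u is 2x2 but vh is 3x3, so u @ jnp.diag(s) @ vh would fail with a shape mismatch. The sketch below shows two equivalent ways to reconstruct x in that case: keep only the first K rows of vh, or embed the singular values in an N x M matrix Sigma. The names k and sigma are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [6., 5., 4.]])
u, s, vh = jnp.linalg.svd(x)  # full_matrices=True: u is (2, 2), vh is (3, 3)

k = s.shape[-1]  # K = min(N, M) = 2

# Option 1: only the first K rows of vh pair with a singular value.
x_reconstructed = u[:, :k] @ jnp.diag(s) @ vh[:k, :]
print(bool(jnp.allclose(x_reconstructed, x, atol=1e-5)))  # True

# Option 2: place diag(s) in the top-left corner of an N x M zero matrix Sigma.
sigma = jnp.zeros(x.shape, dtype=s.dtype).at[:k, :k].set(jnp.diag(s))
print(bool(jnp.allclose(u @ sigma @ vh, x, atol=1e-5)))  # True

# The remaining rows of vh meet zero columns of Sigma: they lie in the null space of x.
null_part = x @ vh[k:].T
```

Option 1 is usually simpler; option 2 mirrors the \(A = U\Sigma V^H\) form exactly, with \(\Sigma\) padded to the shape of \(A\).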