# jax.scipy.linalg.polar

jax.scipy.linalg.polar(a, side='right', *, method='qdwh', eps=None, max_iterations=None)

Computes the polar decomposition.

Given the $$m \times n$$ matrix $$a$$, returns the factors of the polar decomposition $$u$$ (also $$m \times n$$) and $$p$$ such that $$a = up$$ (if side is "right"; $$p$$ is $$n \times n$$) or $$a = pu$$ (if side is "left"; $$p$$ is $$m \times m$$), where $$p$$ is positive semidefinite. If $$a$$ is nonsingular, $$p$$ is positive definite and the decomposition is unique. $$u$$ has orthonormal columns unless $$n > m$$, in which case it has orthonormal rows.
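The shape rules above can be checked with a short sketch. It relies only on the polar signature shown at the top of this page. The test matrices are arbitrary, and method="svd" is used so that the sketch does not depend on which shape and side combinations the default QDWH path accepts.

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar

# Tall 4x3 input (m > n): u is 4x3 for either side; p is 3x3 ("right") or 4x4 ("left").
tall = jnp.array([[1.0, 0.0, 2.0],
                  [0.0, 1.0, 1.0],
                  [1.0, 1.0, 0.0],
                  [2.0, 0.0, 1.0]])
u_r, p_r = polar(tall, side="right", method="svd")  # tall = u_r @ p_r
u_l, p_l = polar(tall, side="left", method="svd")   # tall = p_l @ u_l
print(u_r.shape, p_r.shape)  # (4, 3) (3, 3)
print(u_l.shape, p_l.shape)  # (4, 3) (4, 4)

# Wide 2x3 input (n > m): u has orthonormal rows, so u @ u^H is the 2x2 identity.
wide = jnp.array([[1.0, 2.0, 0.0],
                  [0.0, 1.0, 3.0]])
u_w, p_w = polar(wide, method="svd")
print(jnp.allclose(u_w @ u_w.conj().T, jnp.eye(2), atol=1e-5))

# The factors reproduce the input (a = u p for side="right").
print(jnp.allclose(u_w @ p_w, wide, atol=1e-4))
```

With method="svd", both calls on the tall matrix form the unitary factor as $$u_\mathit{svd} \cdot v^h_\mathit{svd}$$, so u_r and u_l coincide here; only the positive-semidefinite factor changes shape.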

Writing the SVD of $$a$$ as $$a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}$$, we have $$u = u_\mathit{svd} \cdot v^h_\mathit{svd}$$ (and, for side "right", $$p = (v^h_\mathit{svd})^h \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}$$). Thus the unitary factor $$u$$ can be obtained by applying the sign function to the singular values of $$a$$, which maps every nonzero singular value to 1; or, if $$a$$ is Hermitian, by applying it to the eigenvalues.
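For a Hermitian input, the sign-function view can be sketched directly with an eigendecomposition: applying sign to the eigenvalues gives $$u$$, and applying the absolute value gives $$p$$. This is an illustration under the assumption that $$a$$ is real symmetric and nonsingular (so the decomposition is unique), not a description of how polar is implemented internally; the matrix is an arbitrary example.

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar

# Real symmetric (hence Hermitian), nonsingular, indefinite matrix (det = -9).
a = jnp.array([[2.0, 1.0, 0.0],
               [1.0, -3.0, 1.0],
               [0.0, 1.0, 1.0]])

# Eigendecomposition a = q @ diag(lam) @ q.T (q is real orthogonal here).
lam, q = jnp.linalg.eigh(a)

# Scaling column j of q by f(lam[j]) and multiplying by q.T forms q @ diag(f(lam)) @ q.T.
u_sign = (q * jnp.sign(lam)) @ q.T  # sign of eigenvalues -> unitary factor
p_abs = (q * jnp.abs(lam)) @ q.T    # |eigenvalues| -> positive-definite factor

u, p = polar(a)  # default method ("qdwh"), side="right"
print(jnp.allclose(u, u_sign, atol=1e-4))
print(jnp.allclose(p, p_abs, atol=1e-4))
```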

Several methods exist to compute the polar decomposition. Currently two are supported:

• method="svd":

Computes the SVD of $$a$$ and then forms $$u = u_\mathit{svd} \cdot v^h_\mathit{svd}$$.

• method="qdwh":

Applies the QDWH (QR-based Dynamically Weighted Halley) algorithm.
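The method="svd" recipe can be reproduced by hand with jnp.linalg.svd and compared with both supported methods. A minimal sketch, assuming a square nonsingular input (the matrix from the example at the end of this page) so that the decomposition is unique and all three results should agree; the tolerance atol=1e-4 is an arbitrary float32-friendly choice.

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array([[1.0, 2.0, 3.0],
               [5.0, 4.0, 2.0],
               [3.0, 2.0, 1.0]])

# Hand-rolled side="right" recipe: u = u_svd @ vh_svd, p = vh_svd^H @ diag(s_svd) @ vh_svd.
u_svd, s_svd, vh_svd = jnp.linalg.svd(a, full_matrices=False)
u_manual = u_svd @ vh_svd
p_manual = (vh_svd.conj().T * s_svd) @ vh_svd  # scale column j of vh_svd^H by s_svd[j]

u_svd_method, p_svd_method = polar(a, method="svd")
u_qdwh, p_qdwh = polar(a, method="qdwh")

# a is nonsingular, so all three decompositions should agree up to rounding.
print(jnp.allclose(u_manual, u_svd_method, atol=1e-4))
print(jnp.allclose(u_manual, u_qdwh, atol=1e-4))
print(jnp.allclose(p_manual, p_qdwh, atol=1e-4))
```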

Parameters:
• a (jax.typing.ArrayLike) – The $$m \times n$$ input matrix.

• side (str) – Determines whether a right or left polar decomposition is computed. If side is "right" then $$a = up$$. If side is "left" then $$a = pu$$. The default is "right".

• method (str) – Determines the algorithm used, as described above. The default is "qdwh".

• precision – Precision object specifying the matmul precision.

• eps (float | None) – The final result will satisfy $$\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}$$, where $$x_k$$ are the QDWH iterates and $$\epsilon$$ is eps. Ignored if method is not "qdwh".

• max_iterations (int | None) – Iterations terminate after this many steps even if the criterion above is not yet satisfied. Ignored if method is not "qdwh".

Returns:

A (unitary, posdef) tuple, where unitary is the unitary factor ($$m \times n$$), and posdef is the positive-semidefinite factor. posdef is either $$n \times n$$ or $$m \times m$$ depending on whether side is "right" or "left", respectively.

Return type:

tuple[Array, Array]

Example

Polar decomposition of a 3x3 matrix:

>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> a = jnp.array([[1., 2., 3.],
...                [5., 4., 2.],
...                [3., 2., 1.]])
>>> U, P = jax.scipy.linalg.polar(a)


U is a unitary matrix:

>>> jnp.round(U.T @ U)
Array([[ 1., -0., -0.],
       [-0.,  1.,  0.],
       [-0.,  0.,  1.]], dtype=float32)


P is a positive-semidefinite matrix:

>>> with jnp.printoptions(precision=2, suppress=True):
...     print(P)
[[4.79 3.25 1.23]
 [3.25 3.06 2.01]
 [1.23 2.01 2.91]]


The original matrix can be reconstructed by multiplying U and P:

>>> a_reconstructed = U @ P
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)
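Beyond rounding the entries, the factor properties can be checked numerically. A minimal sketch reusing the same 3x3 matrix: since it is nonsingular, P should be symmetric with strictly positive eigenvalues, and the left decomposition (computed here with method="svd", an arbitrary choice) should reproduce a as P_left @ U_left.

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])
U, P = polar(a)

# P is symmetric, and its eigenvalues are positive because a is nonsingular.
eigs = jnp.linalg.eigvalsh(P)
print(jnp.allclose(P, P.T, atol=1e-4), jnp.all(eigs > 0))

# Left polar decomposition: a = P_left @ U_left, with P_left of shape (m, m).
U_left, P_left = polar(a, side="left", method="svd")
print(jnp.allclose(P_left @ U_left, a, atol=1e-4))
```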