jax.lax.linalg.qdwh

jax.lax.linalg.qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None, dynamic_shape=None)

Computes the polar decomposition of a matrix using the QR-based dynamically weighted Halley (QDWH) iteration.
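To make the target of the iteration concrete, here is a minimal reference sketch that computes the same polar decomposition from a thin SVD using NumPy. It is not the QDWH algorithm; `polar_via_svd` is a hypothetical helper name used only for illustration. It shows the two factors: u with orthonormal columns and h Hermitian positive (semi)definite, with x = u @ h.

```python
import numpy as np

def polar_via_svd(x):
    """Reference polar decomposition x = u @ h computed from a thin SVD.

    Illustration only: this is NOT the QDWH iteration, just the quantities
    a polar decomposition produces.
    """
    w, s, vh = np.linalg.svd(x, full_matrices=False)  # x = w @ diag(s) @ vh
    u = w @ vh                                          # factor with orthonormal columns
    h = vh.conj().T @ np.diag(s) @ vh                   # Hermitian positive semidefinite factor
    return u, h

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 3))        # tall, full-rank M x N matrix
u, h = polar_via_svd(x)
print(np.allclose(u @ h, x))           # True
print(np.allclose(u.T @ u, np.eye(3))) # True
```

Since x = W S V^H, the product u @ h = (W V^H)(V S V^H) = W S V^H recovers x exactly, which is why the SVD serves as a convenient correctness reference for an iterative method.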

Parameters:
  • x – A full-rank matrix of shape M x N. The matrix may be padded to this size from a smaller true shape, given by dynamic_shape.

  • is_hermitian – True if x is Hermitian. Defaults to False.

  • eps – The final result will satisfy |x_k - x_{k-1}| < |x_k| * (4*eps)**(1/3), where x_k is the k-th iterate.

  • max_iterations – The iteration terminates after this many steps even if the convergence criterion above is not satisfied.

  • dynamic_shape (tuple[int, int] | None) – The true (unpadded) shape of x as an (m, n) tuple. Optional.
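To give a sense of scale for the eps criterion, the sketch below evaluates the relative tolerance (4*eps)**(1/3) for the machine epsilon of float32 and float64, and applies the documented test to two successive iterates. It assumes |·| is the Frobenius norm, which the description above does not specify; `stopping_tolerance` and `has_converged` are hypothetical helpers, not part of the JAX API.

```python
import numpy as np

def stopping_tolerance(eps):
    """Relative tolerance (4*eps)**(1/3) from the documented stopping rule."""
    eps = float(eps)  # avoid reduced-precision arithmetic on NumPy scalars
    return (4.0 * eps) ** (1.0 / 3.0)

def has_converged(x_k, x_prev, eps):
    """Documented test |x_k - x_{k-1}| < |x_k| * (4*eps)**(1/3).

    Assumption: |.| is taken to be the Frobenius norm.
    """
    diff = np.linalg.norm(x_k - x_prev)
    return diff < np.linalg.norm(x_k) * stopping_tolerance(eps)

tol32 = stopping_tolerance(np.finfo(np.float32).eps)  # about 2**-7 = 0.0078125
tol64 = stopping_tolerance(np.finfo(np.float64).eps)  # about 9.6e-06

# Two iterates whose relative change is about 1e-3.
x_prev = np.eye(3) * (1.0 + 1e-3)
x_k = np.eye(3)
print(has_converged(x_k, x_prev, np.finfo(np.float32).eps))  # True
print(has_converged(x_k, x_prev, np.finfo(np.float64).eps))  # False
```

Because of the cube root, the default tolerance is far looser than eps itself: a relative change of 1e-3 between iterates already passes in float32 but not in float64.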

Returns:

A four-tuple (u, h, num_iters, is_converged), where u and h are the factors of the polar decomposition x = u @ h, num_iters is the number of iterations used to compute u, and is_converged is True if convergence was achieved within the maximum number of iterations.
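A minimal usage sketch, assuming JAX is installed and the input is tall or square (M >= N). It unpacks the four-tuple, checks the relation x = u @ h and that u has orthonormal columns, and then uses is_hermitian=True on a symmetric positive definite matrix, for which polar decomposition theory says u should be close to the identity and h close to the input. The 1e-4 tolerances are an assumption chosen for float32 arithmetic.

```python
import jax.numpy as jnp
import numpy as np
from jax.lax.linalg import qdwh

# A tall (M >= N), full-rank 4 x 3 test matrix in float32.
x = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0],
               [1.0, 0.0, 1.0]], dtype=jnp.float32)

u, h, num_iters, is_converged = qdwh(x, max_iterations=10)
print("converged:", bool(is_converged), "iterations:", int(num_iters))

# x should be reconstructed by u @ h, and u should have orthonormal columns.
print(np.allclose(np.asarray(u @ h), np.asarray(x), atol=1e-4))
print(np.allclose(np.asarray(u.T @ u), np.eye(3), atol=1e-4))

# For a symmetric positive definite input, the unitary factor is the identity.
s = jnp.array([[2.0, 1.0],
               [1.0, 2.0]], dtype=jnp.float32)
u_s, h_s, _, _ = qdwh(s, is_hermitian=True)
print(np.allclose(np.asarray(u_s), np.eye(2), atol=1e-4))
print(np.allclose(np.asarray(h_s), np.asarray(s), atol=1e-4))
```

Checking is_converged before using u is a cheap safeguard, since the routine returns its current iterate even when max_iterations is reached first.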