# Source code for jax._src.lax.qdwh

"""A JIT-compatible library for QDWH-based polar decomposition.

QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
iteration implemented through QR decompositions does not require matrix
inversion. This is desirable for multicore and heterogeneous computing systems.

Reference: Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
"Optimizing Halley's iteration for computing the matrix polar decomposition."
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
https://epubs.siam.org/doi/abs/10.1137/090774999
"""

import functools
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from jax import lax
from jax._src import core
from jax._src.lax import linalg as lax_linalg

# Helpers for working with padded shapes
"""Masks `x` up to the dynamic shape `dims`.

Replaces values outside those dimensions with `alternative`. `alternative` is
"""
assert jnp.ndim(x) == len(dims)
for i, d in enumerate(dims):
if d is not None:

def _pad_in_dim(x, low=0, high=0, interior=0, fill_value=0, axis=0):
pads = [(0, 0, 0)] * x.ndim

def _dynamic_concat(a, b, m, axis=0):
"Concatenates padded arrays `a` and `b` where the true size of `a` is `m`."
if m is None:
return jnp.concatenate([a, b], axis=axis)
return lax.dynamic_update_slice_in_dim(
_pad_in_dim(a, high=b.shape[axis], axis=axis), b, m, axis)

def _use_qr(u, m, n, params):
  """QDWH iteration using QR decomposition.

  Computes one step of the QR-based update
  ``u <- (b/c) u + (1/sqrt(c)) (a - b/c) q1 @ q2^H``, where ``[q1; q2]`` is
  the thin Q factor of the stacked matrix ``[sqrt(c) u; I]``.

  Args:
    u: a matrix, with static (padded) shape M x N.
    m, n: the dynamic shape of the matrix, where m <= M and n <= N.
    params: the QDWH parameters ``(a, b, c)``.

  Returns:
    The next iterate, with the same static shape M x N as ``u``.
  """
  a, b, c = params
  M, N = u.shape

  # Stack [sqrt(c) * u; I]. The identity is written starting at row `m`, i.e.
  # directly after the valid rows of `u`; the result is (M + N) x N.
  y = _dynamic_concat(jnp.sqrt(c) * u, jnp.eye(N, dtype=jnp.dtype(u)), m)
  q, _ = lax_linalg.qr(y, full_matrices=False)
  # q1 = q[:m, :], with entries outside the valid (m, n) block zeroed.
  q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n))
  # q2 = (q[m:, :]).T.conj(): take the N rows starting at the dynamic offset
  # `m`, zero everything outside the valid (n, n) block, then take the
  # conjugate transpose so that the update below uses q1 @ q2^H.
  q2 = lax.dynamic_slice_in_dim(q, m, N, axis=0)
  q2 = _mask(q2, (n, n)).T.conj()
  e = b / c
  u = (e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2))
  return u

def _use_cholesky(u, m, n, params):
  """Performs one QDWH step via a Cholesky factorization.

  The step forms the Hermitian matrix ``gram = c * u^H u + I``, factors it as
  ``L L^H`` and applies ``inv(gram)`` to ``u`` with two triangular solves, so
  no explicit matrix inverse is formed.

  Args:
    u: the current iterate, with static (padded) shape M x N.
    m, n: the dynamic shape of the matrix, where m <= M and n <= N. Only ``n``
      is consulted here.
    params: the QDWH parameters ``(a, b, c)``.

  Returns:
    ``(b/c) * u + (a - b/c) * u @ inv(gram)``, with the same shape as ``u``.
  """
  a, b, c = params
  num_cols = u.shape[1]

  gram = c * (u.T.conj() @ u) + jnp.eye(num_cols, dtype=jnp.dtype(u))
  # Fill everything outside the valid n x n block with the identity, so the
  # Cholesky factorization does not fail on a zero-padded (non-PSD) region.
  gram = _mask(gram, (n, n), jnp.eye(num_cols, dtype=gram.dtype))

  # Lower-triangular factor: gram = chol @ chol^H.
  chol = lax_linalg.cholesky(gram, symmetrize_input=False)

  # Forward solve chol @ w = u^H, expressed as a solve against conj(chol)
  # followed by conjugating the result.
  w = lax_linalg.triangular_solve(
      chol, u.T, left_side=True, lower=True, conjugate_a=True).conj()
  # Backward solve chol^H @ v = w, hence v = inv(gram) @ u^H.
  v = lax_linalg.triangular_solve(chol, w, left_side=True, lower=True,
                                  transpose_a=True, conjugate_a=True)

  ratio = b / c
  return ratio * u + (a - ratio) * v.T.conj()

def _qdwh(x, m, n, is_hermitian, max_iterations, eps):
  """QR-based dynamically weighted Halley iteration for polar decomposition.

  Args:
    x: the input matrix, with static (padded) shape M x N.
    m, n: the dynamic shape of ``x``; forwarded unchanged to ``_use_qr`` and
      ``_use_cholesky``.
    is_hermitian: used in a Python ``if``, so it must be a concrete bool. When
      true, every iterate is symmetrized as ``(u + u^H) / 2``.
    max_iterations: the loop stops after at most this many Halley iterations.
    eps: tolerance parameter; when None, the machine epsilon of ``x.dtype`` is
      used.

  Returns:
    A tuple ``(u, h, num_iters, is_converged)``:
      * ``u``: the final iterate after one Newton-Schulz refinement step.
      * ``h``: ``u^H x``, symmetrized as ``(h + h^H) / 2``.
      * ``num_iters``: the number of Halley iterations performed.
      * ``is_converged``: true when the convergence tests of the final
        iteration passed.
  """

  # Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of
  # norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for
  # the smallest singular value of x.
  if eps is None:
    eps = float(jnp.finfo(x.dtype).eps)
  # alpha = sqrt(||x||_1 * ||x||_inf), an upper bound on the spectral norm.
  alpha = (jnp.sqrt(jnp.linalg.norm(x, ord=1)) *
           jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf))).astype(x.dtype)
  # The initial `l` is simply eps; no condition-number estimate is computed.
  l = eps

  # Scale so that the largest singular value of `u` is at most 1.
  u = x / alpha

  # Iteration tolerances.
  # tol_l = 5 * eps bounds |1 - l|; tol_norm = cbrt(5 * eps) bounds the
  # Frobenius norm of the change in `u` between iterations.
  tol_l = 10.0 * eps / 2.0
  tol_norm = jnp.cbrt(tol_l)

  # Loop carry: (u, l, iter_idx, is_unconverged, is_not_max_iteration).
  def cond_fun(state):
    _, _, _, is_unconverged, is_not_max_iteration = state
    return jnp.logical_and(is_unconverged, is_not_max_iteration)

  def body_fun(state):
    u, l, iter_idx, _, _ = state

    u_prev = u

    # Computes parameters.
    # The weights (a, b, c) of the dynamically weighted Halley step are
    # functions of the current lower bound `l` only.
    l2 = l**2
    dd = jnp.cbrt(4.0 * (1.0 / l2 - 1.0) / l2)
    sqd = jnp.sqrt(1.0 + dd)
    a = (sqd + jnp.sqrt(8.0 - 4.0 * dd + 8.0 * (2.0 - l2) / (l2 * sqd)) / 2)
    a = jnp.real(a)
    b = (a - 1.0)**2 / 4.0
    c = a + b - 1.0

    # Update the lower bound: l <- l * (a + b l^2) / (1 + c l^2).
    l = l * (a + b * l2) / (1.0 + c * l2)

    # Uses QR or Cholesky decomposition.
    # The QR-based update is selected while c > 100, the Cholesky-based one
    # otherwise. NOTE(review): presumably because the Cholesky variant is
    # only numerically reliable for moderate c — see the referenced paper.
    def true_fn(u):
      return _use_qr(u, m, n, params=(a, b, c))

    def false_fn(u):
      return _use_cholesky(u, m, n, params=(a, b, c))

    u = jax.lax.cond(c > 100, true_fn, false_fn, operand=(u))

    if is_hermitian:
      u = (u + u.T.conj()) / 2.0

    # Checks convergence.
    # Keep iterating while either `l` is not yet within tol_l of 1 or the
    # iterate still moved by more than tol_norm (Frobenius norm).
    iterating_l = jnp.abs(1.0 - l) > tol_l
    iterating_u = jnp.linalg.norm(u-u_prev) > tol_norm
    is_unconverged = jnp.logical_or(iterating_l, iterating_u)

    # `iter_idx` is the 1-based index of the iteration just performed.
    is_not_max_iteration = iter_idx < max_iterations

    return u, l, iter_idx + 1, is_unconverged, is_not_max_iteration

  # Both flags start True, so the body always runs at least once.
  iter_idx = 1
  is_unconverged = True
  is_not_max_iteration = True
  u, _, num_iters, is_unconverged, _ = jax.lax.while_loop(
      cond_fun=cond_fun, body_fun=body_fun,
      init_val=(u, l, iter_idx, is_unconverged, is_not_max_iteration))

  # Applies Newton-Schulz refinement for better accuracy.
  u = 1.5 * u - 0.5 * u @ (u.T.conj() @ u)

  h = u.T.conj() @ x
  h = (h + h.T.conj()) / 2.0

  # Converged within the maximum number of iterations.
  is_converged = jnp.logical_not(is_unconverged)

  # `iter_idx` started at 1 and was incremented once per iteration.
  return u, h, num_iters - 1, is_converged

[docs]@functools.partial(jax.jit, static_argnames=('is_hermitian',))
def qdwh(x, *, is_hermitian=False, max_iterations=None, eps=None,
dynamic_shape: Optional[Tuple[int, int]] = None):
"""QR-based dynamically weighted Halley iteration for polar decomposition.

Args:
x: A full-rank matrix, with shape `M x N`. The matrix may be
padded up to that size from a smaller true shape (``dynamic_shape``).
is_hermitian: True if `x` is Hermitian. Default to `False`.
eps: The final result will satisfy
``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate.
max_iterations: Iterations will terminate after this many steps even if the
above is unsatisfied.
dynamic_shape: the unpadded shape as an ``(m, n)`` tuple; optional.

Returns:
A four-tuple of (u, h, num_iters, is_converged) containing the
polar decomposition of `x = u * h`, the number of iterations to compute `u`,
and `is_converged`, whose value is `True` when the convergence is achieved
within the maximum number of iterations.
"""
is_hermitian = core.concrete_or_error(
bool, is_hermitian, 'The `is_hermitian` argument must be statically '
'specified to use `qdwh` within JAX transformations.')

if max_iterations is None:
max_iterations = 10

M, N = x.shape
if M < N:
raise ValueError('The input matrix of shape M x N must have M >= N.')
if dynamic_shape is not None:
m, n = dynamic_shape