"""Preconditioned MINRES-QLP on the full saddle-point KKT system."""
from __future__ import annotations
from collections.abc import Callable
import jax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float, Int
from slsqp_jax.inner.base import AbstractInnerSolver
from slsqp_jax.inner.krylov import pminres_qlp_solve
from slsqp_jax.inner.masking import make_active_subproblem
from slsqp_jax.state import InnerSolveResult
from slsqp_jax.types import Scalar, Vector
from slsqp_jax.utils import to_scalar
def _solve_kkt_minres_qlp(
    hvp_fn: Callable[[Vector], Vector],
    g: Vector,
    A: Float[Array, "m n"],
    b: Float[Array, " m"],
    active_mask: Bool[Array, " m"],
    max_iter: int,
    tol: float | Scalar,
    precond_fn: Callable[[Vector], Vector] | None = None,
    free_mask: Bool[Array, " n"] | None = None,
    d_fixed: Vector | None = None,
    proj_refine_max_iter: int = 3,
    proj_refine_rtol: float = 1e-10,
    proj_refine_atol: float = 1e-14,
) -> tuple[
    Vector,
    Float[Array, " m"],
    Bool[Array, ""],
    Scalar,
    Int[Array, ""],
]:
    """Solve equality-constrained QP via PMINRES-QLP on the full KKT system.

    The KKT system::

        [B  A^T] [d     ]   [-g]
        [A  0  ] [lambda] = [ b]

    is symmetric indefinite. PMINRES-QLP solves it directly, producing
    both ``d`` and the Lagrange multipliers ``lambda``.

    When ``precond_fn`` is given, a block-diagonal SPD preconditioner
    ``M^{-1} = diag(B_diag^{-1}, S^{-1})`` is used. Here ``B_diag^{-1}``
    is the user-supplied preconditioner (typically an L-BFGS inverse
    Hessian diagonal) and ``S = A B_diag^{-1} A^T`` is the Schur
    complement. This satisfies the SPD requirement from Choi (2006,
    Section 3.4). Without ``precond_fn`` the Krylov solve is
    unpreconditioned.

    After the Krylov solve, the primal part is projected onto
    ``A_work d = b_work`` on the active rows. The projection uses the
    ``B_diag^{-1}`` metric when preconditioned and the Euclidean metric
    otherwise. It is then iteratively refined up to
    ``proj_refine_max_iter`` times, reusing the same Cholesky factor.

    Args:
        hvp_fn: Product ``v -> B v`` with the (1,1) KKT block.
        g: Gradient vector (length ``n``).
        A: Constraint matrix (``m x n``).
        b: Constraint right-hand side (length ``m``).
        active_mask: Rows of ``A`` that participate. Inactive rows are
            zeroed in the working system and receive zero multipliers.
        max_iter: PMINRES-QLP iteration cap.
        tol: PMINRES-QLP stopping tolerance.
        precond_fn: Optional primal preconditioner ``v -> B_diag^{-1} v``.
        free_mask: Optional mask of free primal components.
        d_fixed: Optional values imposed on the non-free components.
        proj_refine_max_iter: Maximum number of refinement rounds for the
            feasibility projection (``<= 0`` disables refinement).
        proj_refine_rtol: Relative feasibility target, scaled by
            ``||b_work|| + 1``.
        proj_refine_atol: Absolute feasibility target.

    Returns:
        A tuple ``(d, multipliers, converged, proj_residual,
        n_refinements)``:

        - ``proj_residual`` is ``||A_work d - b_work||`` over the active
          rows, evaluated for the returned projected step (before the
          fixed-mask overwrite).
        - ``n_refinements`` counts the refinement rounds actually applied.
        - ``converged`` is the Krylov convergence flag AND-ed with
          finiteness of ``d`` and the multipliers.
    """
    n = g.shape[0]
    m = A.shape[0]
    tol = to_scalar(tol)

    sub = make_active_subproblem(
        hvp_fn=hvp_fn,
        g=g,
        A=A,
        b=b,
        active_mask=active_mask,
        free_mask=free_mask,
        d_fixed=d_fixed,
    )
    A_work = sub.A_work
    b_work = sub.b_work
    hvp_work = sub.hvp_work
    g_eff = sub.g_eff
    _free = sub.free_mask
    _dfixed = sub.d_fixed
    has_fixed = sub.has_fixed

    # KKT operator on (n+m)-dimensional vectors [d; lambda].
    def kkt_matvec(z: Vector) -> Vector:
        d_part = z[:n]
        lam_part = z[n:]
        top = hvp_work(d_part) + A_work.T @ lam_part
        bot = A_work @ d_part
        return jnp.concatenate([top, bot])

    kkt_rhs = jnp.concatenate([-g_eff, b_work])

    # Inactive constraint rows are zeroed in A_work / b_work. The
    # range-space and Schur factorisations need a "1" on those diagonal
    # positions to stay invertible without coupling into the active
    # block. This is shared by both branches below.
    reg_diag = jnp.where(active_mask, 0.0, 1.0)

    if precond_fn is not None:
        _raw_precond = precond_fn
        # When free_mask is active, mask the primal block so L-BFGS
        # cross-coupling does not leak non-zero values into the
        # zero-row/column dimensions.
        if has_fixed:
            _free_f = _free.astype(g.dtype)

            def _primal_precond(v: Vector) -> Vector:
                return _free_f * _raw_precond(_free_f * v)

        else:
            _primal_precond = _raw_precond

        # Schur complement S = A M^{-1} A^T (m x m, SPD after the shifts).
        M_AT = jax.vmap(_primal_precond)(A_work).T  # (n, m)
        A_M_AT = A_work @ M_AT + jnp.diag(reg_diag) + 1e-8 * jnp.eye(m)
        A_M_AT_chol = jnp.linalg.cholesky(A_M_AT)

        def _solve_schur(rhs_s: Float[Array, " m"]) -> Float[Array, " m"]:
            return jax.scipy.linalg.cho_solve((A_M_AT_chol, True), rhs_s)

        def kkt_precond(z: Vector) -> Vector:
            v1 = _primal_precond(z[:n])
            v2 = _solve_schur(z[n:])
            return jnp.concatenate([v1, v2])

        solution, converged = pminres_qlp_solve(
            kkt_matvec, kkt_rhs, tol=tol, max_iter=max_iter, precond=kkt_precond
        )

        # M-metric range-space correction: the minimum-||.||_{M^{-1}}
        # delta with A_work delta = r_dual (up to the 1e-8 shift).
        def _correct(d_in: Vector, r_dual: Float[Array, " m"]) -> Vector:
            delta_lambda = _solve_schur(r_dual)
            return d_in - _primal_precond(A_work.T @ delta_lambda)

    else:
        solution, converged = pminres_qlp_solve(
            kkt_matvec, kkt_rhs, tol=tol, max_iter=max_iter
        )
        # Small dedicated m x m Cholesky just for the posterior 2-norm
        # projection.
        A_AT = A_work @ A_work.T + jnp.diag(reg_diag) + 1e-8 * jnp.eye(m)
        A_AT_chol = jnp.linalg.cholesky(A_AT)

        # Euclidean range-space correction.
        def _correct(d_in: Vector, r_dual: Float[Array, " m"]) -> Vector:
            delta_lambda = jax.scipy.linalg.cho_solve((A_AT_chol, True), r_dual)
            return d_in - A_work.T @ delta_lambda

    def _feas_residual(d_in: Vector) -> Float[Array, " m"]:
        """Active-row feasibility residual ``A_work d - b_work``."""
        return jnp.where(active_mask, A_work @ d_in - b_work, 0.0)

    # Iterative refinement of the projection (HR 2014, Algorithm 4.18,
    # step 1(a)). Each round removes part of the feasibility residual
    # left behind by the 1e-8 Schur shift and by round-off.
    b_norm_floor = jnp.linalg.norm(b_work) + jnp.asarray(1.0, dtype=b_work.dtype)
    proj_atol = jnp.asarray(proj_refine_atol, dtype=b_work.dtype)
    proj_rtol = jnp.asarray(proj_refine_rtol, dtype=b_work.dtype)
    refine_target = proj_atol + proj_rtol * b_norm_floor

    # Single-shot projection of the Krylov iterate.
    d_krylov = solution[:n]
    d_proj = _correct(d_krylov, _feas_residual(d_krylov))
    r_proj = _feas_residual(d_proj)
    residual_post = jnp.linalg.norm(r_proj)
    n_refinements = jnp.asarray(0)

    if proj_refine_max_iter > 0:
        # The carry holds the residual of the *current* iterate, so the
        # stopping test and the reported residual always describe the
        # step actually returned. Each round costs one Schur back-solve
        # plus one A_work matvec for the new residual (the same cost as
        # recomputing the residual from scratch inside the correction).
        def _refine_body(carry, _):
            d_cur, r_cur, r_norm_cur, n_done = carry
            done = r_norm_cur <= refine_target
            d_next = _correct(d_cur, r_cur)
            r_next = _feas_residual(d_next)
            r_norm_next = jnp.linalg.norm(r_next)
            # Once converged the carry is frozen, so `done` stays True.
            new_carry = (
                jnp.where(done, d_cur, d_next),
                jnp.where(done, r_cur, r_next),
                jnp.where(done, r_norm_cur, r_norm_next),
                jnp.where(done, n_done, n_done + 1),
            )
            return new_carry, None

        (d_proj, _r_final, residual_post, n_refinements), _ = jax.lax.scan(
            _refine_body,
            (d_proj, r_proj, residual_post, n_refinements),
            None,
            length=proj_refine_max_iter,
        )

    d = d_proj
    if has_fixed:
        # Force the direction to respect the fixed mask.
        d = _free * d + _dfixed

    # Sign flip: the KKT system above solves for -lambda relative to the
    # multiplier convention returned to callers.
    multipliers = jnp.where(active_mask, -solution[n:], 0.0)
    finite = jnp.isfinite(d).all() & jnp.isfinite(multipliers).all()
    return d, multipliers, converged & finite, residual_post, n_refinements
class MinresQLPSolver(AbstractInnerSolver):
    """Preconditioned MINRES-QLP on the full saddle-point KKT system.

    Solves the KKT system directly::

        [B  A^T] [d     ]   [-g]
        [A  0  ] [lambda] = [ b]

    using PMINRES-QLP (Choi, Paige & Saunders, SISC 2011, Table 3.5).

    When a primal preconditioner is supplied, the Krylov solve uses the
    block-diagonal SPD preconditioner::

        M^{-1} = [B_diag^{-1}  0     ]
                 [0            S^{-1}]

    Here ``B_diag^{-1}`` is the supplied preconditioner (typically an
    L-BFGS inverse-Hessian diagonal) and ``S = A B_diag^{-1} A^T`` is the
    Schur complement. Without a preconditioner the solve is
    unpreconditioned.

    After PMINRES-QLP returns the iterate ``d``, a range-space projection
    drives ``A d = b`` on the active rows. It uses the ``B_diag^{-1}``
    metric when preconditioned and the Euclidean metric otherwise.

    The single shot is followed by up to ``proj_refine_max_iter`` rounds
    of iterative refinement. Each round reuses the existing Cholesky
    factor (no refactorisation) and costs one back-solve plus matvecs
    with ``A``. Refinement stops once the active-row residual falls below
    ``proj_refine_atol + proj_refine_rtol * (||b_work|| + 1)``. See HR
    (2014, Algorithm 4.18 step 1(a)) for the motivation.
    """

    # PMINRES-QLP iteration cap.
    max_iter: int = 200
    # PMINRES-QLP stopping tolerance (used as-is; see ``solve``).
    tol: float = 1e-10
    # NOTE(review): not read by this class; presumably kept for
    # configuration parity with other inner solvers. Confirm before
    # removing.
    max_cg_iter: int = 50
    # Iterative refinement of the feasibility projection: maximum number
    # of rounds, plus the relative and absolute residual targets.
    proj_refine_max_iter: int = 3
    proj_refine_rtol: float = 1e-10
    proj_refine_atol: float = 1e-14

    def solve(
        self,
        hvp_fn: Callable[[Vector], Vector],
        g: Vector,
        A: Float[Array, "m n"],
        b: Float[Array, " m"],
        active_mask: Bool[Array, " m"],
        precond_fn: Callable[[Vector], Vector] | None = None,
        free_mask: Bool[Array, " n"] | None = None,
        d_fixed: Vector | None = None,
        adaptive_tol: Scalar | float | None = None,
    ) -> InnerSolveResult:
        """Solve the equality-constrained QP subproblem.

        Args:
            hvp_fn: Product ``v -> B v`` with the QP Hessian block.
            g: Gradient vector.
            A: Constraint matrix.
            b: Constraint right-hand side.
            active_mask: Rows of ``A`` that participate in this solve.
            precond_fn: Optional primal preconditioner.
            free_mask: Optional mask of free primal components.
            d_fixed: Optional values for the non-free components.
            adaptive_tol: Accepted for interface compatibility but
                ignored; ``self.tol`` is always used.

        Returns:
            An ``InnerSolveResult`` holding the step, the multipliers, the
            convergence flag, the projection residual, and the refinement
            count. ``projected_grad_norm`` is set to ``+inf`` because this
            solver does not compute it.
        """
        # MINRES-QLP solves the full KKT system, where the constraints
        # A d = b are part of the linear system. Loosening the tolerance
        # would degrade constraint satisfaction (unlike null-space CG,
        # where the constraints are enforced by the projector). Keep
        # self.tol.
        d, multipliers, converged, proj_residual, n_refinements = _solve_kkt_minres_qlp(
            hvp_fn=hvp_fn,
            g=g,
            A=A,
            b=b,
            active_mask=active_mask,
            max_iter=self.max_iter,
            tol=self.tol,
            precond_fn=precond_fn,
            free_mask=free_mask,
            d_fixed=d_fixed,
            proj_refine_max_iter=self.proj_refine_max_iter,
            proj_refine_rtol=self.proj_refine_rtol,
            proj_refine_atol=self.proj_refine_atol,
        )
        return InnerSolveResult(
            d=d,
            multipliers=multipliers,
            converged=converged,
            proj_residual=proj_residual,
            n_proj_refinements=n_refinements,
            projected_grad_norm=jnp.asarray(jnp.inf, dtype=d.dtype),
        )
# Public API of this module: only the solver class is exported; the
# KKT helper above is private.
__all__: list[str] = ["MinresQLPSolver"]