"""Krylov primitives shared by the inner solvers.
Hosts the building blocks used by every projector-based and saddle-point
inner solver:
* (Preconditioned, optionally projected) conjugate gradient via
:func:`build_cg_step` and the unconstrained driver
:func:`solve_unconstrained_cg`.
* CRAIG's method for ``min ||x|| s.t. A x = rhs`` via :func:`craig_solve`.
* Stable symmetric Givens rotation :func:`_sym_ortho` and the full
Preconditioned MINRES-QLP iteration via :func:`pminres_qlp_solve`.
These were previously defined inline inside ``slsqp_jax/inner_solver.py``
mixed with the solver classes. Splitting them out keeps each Krylov
recurrence in a single, navigable module and makes the high-level
solver classes easier to read.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import NamedTuple
import jax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float, Int
from slsqp_jax.types import Scalar, Vector
from slsqp_jax.utils import to_scalar
# ---------------------------------------------------------------------------
# (Preconditioned, projected) conjugate gradient
# ---------------------------------------------------------------------------
class _CGState(NamedTuple):
    """Internal state for the (preconditioned) conjugate gradient solver.

    This tuple is the loop carry threaded through ``jax.lax.fori_loop`` by
    :func:`solve_unconstrained_cg` and updated by the step function that
    :func:`build_cg_step` returns.

    When a preconditioner M is used, ``rz`` stores r^T z (where z = M r)
    instead of r^T r, and ``p`` is built from z rather than r.
    """

    d: Vector  # accumulated step: d <- d + alpha * p
    r: Vector  # current residual: r <- r - alpha * P(B p)
    p: Vector  # current search direction
    rz: Scalar  # r^T z (preconditioned) or r^T r (unpreconditioned)
    iteration: Int[Array, ""]  # incremented on every executed (non-frozen) step
    # Stop flag: once True the step function returns the state unchanged.
    # It is raised both when the residual test passes and when the
    # curvature guard fires, so it is not by itself a success indicator.
    converged: Bool[Array, ""]
def build_cg_step(
    hvp_fn,
    cg_tol: Scalar | float,
    precond_fn: Callable[[Vector], Vector] | None = None,
    project: Callable[[Vector], Vector] | None = None,
    cg_regularization: float = 1e-6,
    cg_atol: float | None = None,
):
    """Create the per-iteration update for (preconditioned, projected) CG.

    The returned callable has the ``(i, state) -> state`` signature used as
    a ``jax.lax.fori_loop`` body. Once ``state.converged`` is set it returns
    the state untouched, so the loop can keep running for a fixed trip
    count without altering the result.

    Args:
        hvp_fn: Hessian-vector product function v -> B @ v.
        cg_tol: Residual-norm tolerance. The per-step test is
            ``r^T r < cg_tol**2`` when ``cg_atol`` is ``None`` and
            ``r^T r < max(cg_atol**2, cg_tol**2)`` otherwise.
        precond_fn: Optional preconditioner v -> M @ v where M ~ B^{-1}.
        project: Optional projection function v -> P(v) where P is the
            projection onto the null space of A. Defaults to the identity.
        cg_regularization: Minimum eigenvalue threshold for the curvature
            guard: a step is rejected when
            ``p^T P(B p) <= cg_regularization * p^T p``.
        cg_atol: Optional absolute residual-norm floor combined with
            ``cg_tol`` as described above. Used by the multiplier-recovery
            CG inside ``ProjectedCGCraig`` so that a near-KKT iterate does
            not chase a target below ``eps``. ``None`` keeps the pure
            ``r^T r < cg_tol**2`` test.

    Returns:
        A CG step function ``cg_step(i, state) -> state`` operating on
        :class:`_CGState`.
    """
    # Without a projector the "projection" is simply the identity map.
    apply_projection = (lambda v: v) if project is None else project

    # Coerce ``cg_tol`` to a true 0-d scalar. See the long comment in
    # the legacy implementation for why this matters under JIT tracing.
    cg_tol = to_scalar(cg_tol)
    tol_sq = cg_tol**2 if cg_atol is None else jnp.maximum(cg_atol**2, cg_tol**2)

    def _advance(state: _CGState) -> _CGState:
        direction = state.p
        curv_vec = apply_projection(hvp_fn(direction))
        curvature = jnp.dot(direction, curv_vec)

        # SNOPT-style curvature guard (see unconstrained CG for detail).
        # For projected CG, p is in the null space, so this checks the
        # effective eigenvalue of the reduced Hessian Z^T B Z along p.
        dir_norm_sq = jnp.dot(direction, direction)
        bad_curvature = curvature <= cg_regularization * dir_norm_sq

        step_len = jnp.where(
            bad_curvature,
            jnp.array(0.0),
            state.rz / jnp.maximum(curvature, 1e-30),
        )
        next_d = state.d + step_len * direction
        next_r = state.r - step_len * curv_vec
        next_rr = jnp.dot(next_r, next_r)

        if precond_fn is None:
            next_z, next_rz = next_r, next_rr
        else:
            # Fall back to the unpreconditioned residual whenever the
            # preconditioned inner product is not strictly positive.
            z_candidate = apply_projection(precond_fn(next_r))
            rz_candidate = jnp.dot(next_r, z_candidate)
            usable = rz_candidate > 0
            next_z = jnp.where(usable, z_candidate, next_r)
            next_rz = jnp.where(usable, rz_candidate, next_rr)

        ratio = next_rz / jnp.maximum(state.rz, 1e-30)
        next_p = next_z + ratio * direction
        stop = (next_rr < tol_sq) | bad_curvature
        tick = state.iteration + 1

        def _freeze() -> _CGState:
            # Bad curvature: keep the last iterate and halt the recurrence.
            return state._replace(iteration=tick, converged=jnp.array(True))

        def _accept() -> _CGState:
            return _CGState(
                d=next_d,
                r=next_r,
                p=next_p,
                rz=next_rz,
                iteration=tick,
                converged=stop,
            )

        return jax.lax.cond(bad_curvature, _freeze, _accept)

    def cg_step(i, state):
        # Defensive scalarisation of the predicate.
        already_done = jnp.reshape(state.converged, ())
        return jax.lax.cond(already_done, lambda s: s, _advance, state)

    return cg_step
def solve_unconstrained_cg(
    hvp_fn: Callable[[Vector], Vector],
    g: Vector,
    max_cg_iter: int,
    cg_tol: Scalar | float,
    precond_fn: Callable[[Vector], Vector] | None = None,
    cg_regularization: float = 1e-6,
    cg_atol: float | None = None,
) -> tuple[Vector, Bool[Array, ""]]:
    """Solve the unconstrained QP: min (1/2) d^T B d + g^T d.

    Uses (preconditioned) conjugate gradient to solve B d = -g without
    forming B. When *precond_fn* is provided, the standard PCG algorithm
    is used: z = M r, beta = r_new^T z_new / r_old^T z_old, and p is
    built from z (Nocedal & Wright, Algorithm 5.3).

    Args:
        hvp_fn: Hessian-vector product function v -> B @ v.
        g: Linear term (gradient).
        max_cg_iter: Maximum CG iterations.
        cg_tol: Convergence tolerance on residual norm.
        precond_fn: Optional preconditioner v -> M @ v where M ~ B^{-1}.
        cg_regularization: Minimum eigenvalue threshold for the curvature
            guard. CG declares "bad curvature" when the effective
            eigenvalue ``p^T B p / ||p||^2`` falls below this value.
            Based on SNOPT Section 4.5 (Gill, Murray & Saunders, 2005).
        cg_atol: Optional absolute residual-norm floor. When provided,
            the per-step convergence test becomes
            ``r^T r < max(cg_atol**2, cg_tol**2)`` so an absolute floor
            kicks in whenever the user-supplied ``cg_tol`` would target
            a residual below the floor.

    Returns:
        Tuple of (d, converged) where d is the solution vector and
        converged indicates whether CG converged (residual below
        tolerance) as opposed to hitting bad curvature or exhausting
        the iteration budget. When ``converged`` is ``False`` the returned
        ``d`` is the last iterate accepted before the stop.
    """
    # Sign convention: r = b - Ax = -g - Bd (descent residual).
    n = g.shape[0]
    cg_tol = to_scalar(cg_tol)
    if cg_atol is not None:
        tol_sq = jnp.maximum(cg_atol**2, cg_tol**2)
    else:
        tol_sq = cg_tol**2

    r0 = -g
    r0_norm_sq = jnp.dot(r0, r0)
    if precond_fn is not None:
        z0 = precond_fn(r0)
        rz0_raw = jnp.dot(r0, z0)
        # Ignore the preconditioner when r^T M r is not strictly positive.
        z0 = jnp.where(rz0_raw > 0, z0, r0)
        rz0 = jnp.where(rz0_raw > 0, rz0_raw, r0_norm_sq)
        p0 = z0
    else:
        rz0 = r0_norm_sq
        p0 = r0

    init_cg = _CGState(
        d=jnp.zeros(n),
        r=r0,
        p=p0,
        rz=rz0,
        iteration=jnp.array(0),
        # d = 0 is already acceptable when the initial residual is small.
        converged=jnp.reshape(r0_norm_sq < tol_sq, ()),
    )
    cg_step = build_cg_step(
        hvp_fn=hvp_fn,
        cg_tol=cg_tol,
        precond_fn=precond_fn,
        cg_regularization=cg_regularization,
        cg_atol=cg_atol,
    )
    final_cg = jax.lax.fori_loop(0, max_cg_iter, cg_step, init_cg)

    # ``final_cg.converged`` is only the loop's stop flag: the step function
    # also raises it when the curvature guard fires. Re-derive success from
    # the residual that belongs to the returned ``d`` so that bad curvature
    # and an exhausted budget are both reported as "not converged", as the
    # Returns section documents.
    residual_ok = jnp.dot(final_cg.r, final_cg.r) < tol_sq
    return final_cg.d, jnp.reshape(residual_ok, ())
# ---------------------------------------------------------------------------
# CRAIG (Golub-Kahan bidiagonalization)
# ---------------------------------------------------------------------------
class _CraigState(NamedTuple):
    """Internal state (loop carry) for the CRAIG iterative solver."""

    x: Vector  # primal solution (n,): x_k = A^T (A A^T)^{-1} rhs
    s: Scalar  # coefficient: s_k = (-1)^{k-1} prod(beta_i/alpha_i)
    u: Float[Array, " m"]  # left bidiag vector (m,)
    v: Vector  # right bidiag vector (n,)
    alpha: Scalar  # current alpha
    beta: Scalar  # current beta (beta_{k+1})
    residual: Scalar  # |beta_{k+1} * s_k| (||A x_k - rhs||) for the stored x
    converged: Bool[Array, ""]  # loop stop flag: tolerance met OR breakdown
    breakdown: Bool[Array, ""]  # the loop stopped because of a breakdown
    iteration: Int[Array, ""]


# ``alpha`` / ``beta`` below this absolute value are treated as a breakdown
# of the Golub-Kahan bidiagonalization.
_CRAIG_BREAKDOWN_TOL = 1e-14

# Hardcoded absolute residual floor for the CRAIG convergence test. Together
# with the user-tunable relative ``tol``, the test becomes
# ``residual < max(_CRAIG_TOL_ABS, tol * ||rhs||)``. Without this floor a
# near-KKT iterate with ``||rhs|| ~ 1e-13`` would target ``tol * 1e-13``
# (below ``eps``) and never converge, even though the actual residual is
# already at machine precision.
_CRAIG_TOL_ABS = 1e-12


def craig_solve(
    A: Float[Array, "m n"],
    rhs: Float[Array, " m"],
    tol: float | Scalar = 1e-10,
    max_iter: int = 100,
) -> tuple[Vector, Bool[Array, ""]]:
    """Solve ``min ||x|| s.t. A x = rhs`` via CRAIG's method.

    CRAIG's method (Paige & Saunders, 1982) uses Golub-Kahan
    bidiagonalization to solve the minimum-norm problem without forming
    ``A A^T``. Only matrix-vector products ``A @ v`` and ``A.T @ u``
    are needed.

    Convergence test is hybrid absolute+relative:
    ``residual < max(_CRAIG_TOL_ABS=1e-12, tol * ||rhs||)``. The
    absolute floor protects against the pathology where ``||rhs||``
    is itself near machine epsilon (e.g. when projecting at a near-KKT
    iterate where ``A v`` shrinks at the rate the SQP is converging),
    in which case the pure-relative target ``tol * ||rhs||`` would
    drop below ``eps`` and convergence could never fire.

    Args:
        A: Matrix (m x n).
        rhs: Right-hand side (m,).
        tol: Relative tolerance on ``||A x - rhs|| / ||rhs||``. The
            effective convergence threshold also has a hardcoded
            absolute floor of ``1e-12``.
        max_iter: Maximum bidiagonalization steps.

    Returns:
        Tuple ``(x, converged)``. ``converged`` is ``True`` only when
        the residual of the returned ``x`` fell below the hybrid
        threshold; it is ``False`` if CRAIG broke down (``alpha`` /
        ``beta`` below an absolute threshold, signalling rank deficiency
        or numerical collapse) or exhausted its iteration budget. An
        ``alpha`` breakdown is always a failure: the step is discarded, so
        the residual of the kept iterate is the one that counts. A ``beta``
        breakdown with a residual below the threshold is ordinary
        convergence. When ``converged`` is ``False`` the returned ``x`` is
        still the best iterate produced before the failure.
    """
    _, _ = A.shape  # fails fast when A is not 2-D

    # --- First bidiagonalization step: beta1 u1 = rhs, alpha1 v1 = A^T u1 ---
    beta1 = jnp.linalg.norm(rhs)
    beta1_safe = jnp.maximum(beta1, 1e-30)
    u1 = rhs / beta1_safe
    Atu1 = A.T @ u1
    alpha1 = jnp.linalg.norm(Atu1)
    breakdown_init = alpha1 < _CRAIG_BREAKDOWN_TOL
    alpha1_safe = jnp.maximum(alpha1, 1e-30)
    v1 = Atu1 / alpha1_safe
    s1 = beta1 / alpha1_safe
    x1_raw = s1 * v1
    # Guard against alpha1 ≈ 0 with beta1 != 0; see legacy comment.
    x1 = jnp.where(breakdown_init, jnp.zeros_like(x1_raw), x1_raw)
    Av1 = A @ v1
    u_hat = Av1 - alpha1 * u1
    beta2 = jnp.linalg.norm(u_hat)
    beta2_safe = jnp.maximum(beta2, 1e-30)
    u2 = u_hat / beta2_safe

    # The hybrid threshold is loop-invariant, so compute it once.
    threshold = jnp.maximum(_CRAIG_TOL_ABS, tol * beta1_safe)

    # Negligible right-hand side (this includes rhs == 0, for which x = 0
    # is exact): stop immediately and do not flag a breakdown.
    trivially_converged = beta1 < tol * jnp.maximum(beta1_safe, 1.0)
    residual_init = jnp.abs(beta2 * s1)
    init_converged = trivially_converged | (residual_init < threshold)
    init_breakdown = breakdown_init & ~trivially_converged
    init_state = _CraigState(
        x=x1,
        s=s1,
        u=u2,
        v=v1,
        alpha=alpha1,
        beta=beta2,
        residual=residual_init,
        converged=init_converged | init_breakdown,
        breakdown=init_breakdown,
        iteration=jnp.array(1),
    )

    def craig_step(i, state: _CraigState) -> _CraigState:
        def do_step(state: _CraigState) -> _CraigState:
            # alpha_{k+1} v_{k+1} = A^T u_{k+1} - beta_{k+1} v_k, followed by
            # one re-orthogonalization pass against v_k.
            Atu = A.T @ state.u
            v_hat = Atu - state.beta * state.v
            v_hat = v_hat - jnp.dot(state.v, v_hat) * state.v
            alpha_new = jnp.linalg.norm(v_hat)
            alpha_breakdown = alpha_new < _CRAIG_BREAKDOWN_TOL
            alpha_safe = jnp.maximum(alpha_new, 1e-30)
            v_new = v_hat / alpha_safe
            s_new = -state.beta * state.s / alpha_safe

            # Stay at the last safe iterate when alpha breaks down.
            x_candidate = state.x + s_new * v_new
            x_new = jnp.where(alpha_breakdown, state.x, x_candidate)

            # beta_{k+2} u_{k+2} = A v_{k+1} - alpha_{k+1} u_{k+1}, again
            # with one re-orthogonalization pass against u_{k+1}.
            Av = A @ v_new
            u_hat = Av - alpha_new * state.u
            u_hat = u_hat - jnp.dot(state.u, u_hat) * state.u
            beta_new = jnp.linalg.norm(u_hat)
            beta_breakdown = beta_new < _CRAIG_BREAKDOWN_TOL
            beta_safe = jnp.maximum(beta_new, 1e-30)
            u_new = u_hat / beta_safe

            # ``|beta * s|`` is the residual of ``x_candidate``. On an alpha
            # breakdown that candidate is discarded and ``s_new`` was formed
            # with a clamped divisor, so the estimate is meaningless (it is
            # exactly 0 when ``v_hat == 0``). Keep the residual of the
            # iterate that is actually retained instead.
            residual_new = jnp.where(
                alpha_breakdown, state.residual, jnp.abs(beta_new * s_new)
            )
            converged = (residual_new < threshold) & ~alpha_breakdown
            broke = alpha_breakdown | beta_breakdown
            done = converged | broke
            return _CraigState(
                x=x_new,
                s=s_new,
                u=u_new,
                v=v_new,
                alpha=alpha_new,
                beta=beta_new,
                residual=residual_new,
                converged=done,
                # A beta breakdown with a small residual is plain
                # convergence; anything else that broke is a failure.
                breakdown=state.breakdown | (broke & ~converged),
                iteration=state.iteration + 1,
            )

        return jax.lax.cond(
            jnp.reshape(state.converged, ()), lambda s: s, do_step, state
        )

    final = jax.lax.fori_loop(0, max_iter, craig_step, init_state)
    success = (final.residual < threshold) & ~final.breakdown
    return final.x, success
# ---------------------------------------------------------------------------
# Stable symmetric Givens rotation (used by MINRES-QLP)
# ---------------------------------------------------------------------------
def _sym_ortho(a: Scalar, b: Scalar) -> tuple[Scalar, Scalar, Scalar]:
"""Numerically stable symmetric Givens rotation (SymOrtho).
Computes (c, s, r) such that r = sqrt(a^2 + b^2) >= 0, c = a/r,
s = b/r. Handles a=0, b=0, and |a|>|b| vs |b|>|a| separately
to avoid overflow/underflow.
Reference: Choi (2006), Table 2.9 / Algorithm 937 (TOMS 2014).
"""
abs_a = jnp.abs(a)
abs_b = jnp.abs(b)
def _b_zero(_: None) -> tuple[Scalar, Scalar, Scalar]:
c = jnp.where(a == 0.0, 1.0, jnp.sign(a))
return c, jnp.array(0.0), abs_a
def _a_zero(_: None) -> tuple[Scalar, Scalar, Scalar]:
return jnp.array(0.0), jnp.sign(b), abs_b
def _both_nonzero(_: None) -> tuple[Scalar, Scalar, Scalar]:
def _a_ge_b(_: None) -> tuple[Scalar, Scalar, Scalar]:
t = b / a
r_local = abs_a * jnp.sqrt(1.0 + t * t)
c_local = jnp.sign(a) / jnp.sqrt(1.0 + t * t)
s_local = c_local * t
return c_local, s_local, r_local
def _b_gt_a(_: None) -> tuple[Scalar, Scalar, Scalar]:
t = a / b
r_local = abs_b * jnp.sqrt(1.0 + t * t)
s_local = jnp.sign(b) / jnp.sqrt(1.0 + t * t)
c_local = s_local * t
return c_local, s_local, r_local
return jax.lax.cond(abs_a >= abs_b, _a_ge_b, _b_gt_a, None)
return jax.lax.cond(
b == 0.0,
_b_zero,
lambda _: jax.lax.cond(a == 0.0, _a_zero, _both_nonzero, None),
None,
)
# ---------------------------------------------------------------------------
# Preconditioned MINRES-QLP (Choi, Paige & Saunders, SISC 2011)
# ---------------------------------------------------------------------------
class _PMinresQLPState(NamedTuple):
    """Internal state for the Preconditioned MINRES-QLP iteration.

    Follows the reference implementation by Choi, Paige & Saunders.
    Variables use the same names as the reference code for traceability.

    This tuple is the ``jax.lax.fori_loop`` carry of
    :func:`pminres_qlp_solve`. A trailing ``l`` / ``l2`` / ``l3`` in a name
    denotes the value from one / two / three iterations earlier.

    Several fields (``u``, ``gamal2``, ``etal2``, ``veplnl2``, ``eplnn``)
    are written every step but never read back by the iteration in this
    module; ``gmin`` / ``gminl`` are updated but do not feed the returned
    solution or its convergence test.
    """

    # Lanczos vectors (raw, NOT normalized by beta)
    r1: Vector
    r2: Vector
    r3: Vector
    # Betas: previous (betal) and current (betan)
    betal: Scalar
    betan: Scalar
    # Left rotation (previous)
    cs: Scalar
    sn: Scalar
    # Right rotation P_{k-2,k}
    cr2: Scalar
    sr2: Scalar
    # QR/QLP intermediates
    dltan: Scalar
    eplnn: Scalar
    gama: Scalar
    gamal: Scalar
    gamal2: Scalar
    # Eta / vepln (for mu recurrence)
    eta: Scalar
    etal: Scalar
    etal2: Scalar
    vepln: Scalar
    veplnl: Scalar
    veplnl2: Scalar
    # Tau (for mu recurrence)
    tau: Scalar
    taul: Scalar
    # Mu / u coefficients
    u: Scalar
    ul: Scalar
    ul2: Scalar
    ul3: Scalar
    # w-vectors and solution
    w: Vector
    wl: Vector
    x: Vector
    xl2: Vector
    # Residual and norms
    phi: Scalar  # initialised to beta1, then multiplied by sn each step
    xl2norm: Scalar
    Anorm: Scalar  # running maximum used as the operator-norm estimate
    gmin: Scalar
    gminl: Scalar
    iteration: Int[Array, ""]
    # Loop stop flag: set on convergence OR on Lanczos breakdown.
    converged: Bool[Array, ""]
def pminres_qlp_solve(
    matvec: Callable[[Vector], Vector],
    rhs: Vector,
    tol: float | Scalar = 1e-10,
    max_iter: int = 200,
    precond: Callable[[Vector], Vector] | None = None,
) -> tuple[Vector, Bool[Array, ""]]:
    """Solve a symmetric (possibly indefinite/singular) system Ax = b.

    Implements the full Preconditioned MINRES-QLP algorithm (Table 3.5
    of Choi, Paige & Saunders, SIAM J. Sci. Comput. 33(4), 2011).
    All iterations use QLP mode (equivalent to TranCond=1 in the
    reference implementation).

    The loop always runs ``max_iter`` trips of ``jax.lax.fori_loop``; once
    the state's stop flag is set, the remaining trips return the state
    unchanged. The stop flag is raised when
    ``|phi| / (Anorm * ||x|| + beta1) < tol`` or when the Lanczos process
    breaks down (next ``beta`` essentially zero).

    Args:
        matvec: Symmetric operator v -> A @ v.
        rhs: Right-hand side vector.
        tol: Convergence tolerance on relative residual.
        max_iter: Maximum Lanczos iterations.
        precond: Optional SPD preconditioner v -> M^{-1} @ v.

    Returns:
        Tuple of (x, converged). ``converged`` is recomputed from the final
        state: it is ``True`` when the relative residual estimate (using
        ``||x||`` of the returned vector) is below ``tol`` and every entry
        of ``x`` is finite.
    """
    n = rhs.shape[0]
    r2 = rhs
    # beta1 is ||rhs|| in the norm induced by the preconditioner (plain
    # 2-norm without one); the clamp guards against a slightly negative
    # r2^T M^{-1} r2 caused by round-off.
    if precond is None:
        r3 = r2
        beta1 = jnp.linalg.norm(r2)
    else:
        r3 = precond(r2)
        beta1 = jnp.sqrt(jnp.maximum(jnp.dot(r2, r3), 0.0))
    beta1_safe = jnp.maximum(beta1, 1e-30)
    zeros = jnp.zeros(n)
    init_state = _PMinresQLPState(
        r1=zeros,
        r2=r2,
        r3=r3,
        betal=jnp.array(0.0),
        betan=beta1,
        cs=jnp.array(-1.0),
        sn=jnp.array(0.0),
        cr2=jnp.array(-1.0),
        sr2=jnp.array(0.0),
        dltan=jnp.array(0.0),
        eplnn=jnp.array(0.0),
        gama=jnp.array(0.0),
        gamal=jnp.array(0.0),
        gamal2=jnp.array(0.0),
        eta=jnp.array(0.0),
        etal=jnp.array(0.0),
        etal2=jnp.array(0.0),
        vepln=jnp.array(0.0),
        veplnl=jnp.array(0.0),
        veplnl2=jnp.array(0.0),
        tau=jnp.array(0.0),
        taul=jnp.array(0.0),
        u=jnp.array(0.0),
        ul=jnp.array(0.0),
        ul2=jnp.array(0.0),
        ul3=jnp.array(0.0),
        w=zeros,
        wl=zeros,
        x=zeros,
        xl2=zeros,
        phi=beta1,
        xl2norm=jnp.array(0.0),
        Anorm=jnp.array(0.0),
        gmin=jnp.array(0.0),
        gminl=jnp.array(0.0),
        iteration=jnp.array(0),
        # A (numerically) zero right-hand side is solved by x = 0 at once.
        converged=beta1 < 1e-30,
    )

    def step_fn(_i: int, state: _PMinresQLPState) -> _PMinresQLPState:
        def do_step(state: _PMinresQLPState) -> _PMinresQLPState:
            k = state.iteration + 1  # 1-based iteration count
            # --- Lanczos step ---
            betal = state.betal
            beta = state.betan
            beta_safe = jnp.maximum(beta, 1e-30)
            betal_safe = jnp.maximum(betal, 1e-30)
            v = state.r3 / beta_safe
            r3_new = matvec(v)
            # The three-term recurrence has no r1 contribution on the
            # first iteration.
            r3_new = r3_new - jnp.where(
                k > 1,
                state.r1 * (beta / betal_safe),
                zeros,
            )
            alfa = jnp.dot(r3_new, v)
            r3_new = r3_new - state.r2 * (alfa / beta_safe)
            r1_new = state.r2
            r2_new = r3_new
            if precond is None:
                betan_new = jnp.linalg.norm(r3_new)
            else:
                r3_new = precond(r2_new)
                betan_new = jnp.sqrt(jnp.maximum(jnp.dot(r2_new, r3_new), 0.0))
            # Norm of the new tridiagonal column; feeds the Anorm estimate.
            pnorm = jnp.sqrt(betal**2 + alfa**2 + betan_new**2)
            # --- Previous left rotation Q_{k-1} ---
            dbar = state.dltan
            dlta = state.cs * dbar + state.sn * alfa
            gbar = state.sn * dbar - state.cs * alfa
            eplnn_new = state.sn * betan_new
            dltan_new = -state.cs * betan_new
            # --- Current left rotation Q_k ---
            gamal2 = state.gamal
            gamal = state.gama
            cs_new, sn_new, gama_new = _sym_ortho(gbar, betan_new)
            taul2 = state.taul
            taul_new = state.tau
            tau_new = cs_new * state.phi
            phi_new = sn_new * state.phi
            # --- Previous right rotation P_{k-2,k} (active when k > 2) ---
            # Both outcomes are computed and ``jnp.where`` selects on k, so
            # the traced program is the same for every iteration.
            veplnl2 = state.veplnl
            etal2 = state.etal
            etal_new = state.eta
            dlta_tmp = state.sr2 * state.vepln - state.cr2 * dlta
            veplnl_new = state.cr2 * state.vepln + state.sr2 * dlta
            dlta_k2 = jnp.where(k > 2, dlta_tmp, dlta)
            veplnl_new = jnp.where(k > 2, veplnl_new, state.veplnl)
            etal_new = jnp.where(k > 2, etal_new, state.etal)
            eta_new = jnp.where(k > 2, state.sr2 * gama_new, jnp.array(0.0))
            gama_k2 = jnp.where(k > 2, -state.cr2 * gama_new, gama_new)
            # --- Current right rotation P_{k-1,k} (active when k > 1) ---
            cr1_new, sr1_new, gamal_new = _sym_ortho(gamal, dlta_k2)
            cr1_new = jnp.where(k > 1, cr1_new, jnp.array(-1.0))
            sr1_new = jnp.where(k > 1, sr1_new, jnp.array(0.0))
            gamal_new = jnp.where(k > 1, gamal_new, gamal)
            vepln_new = jnp.where(k > 1, sr1_new * gama_k2, jnp.array(0.0))
            gama_final = jnp.where(k > 1, -cr1_new * gama_k2, gama_k2)
            # --- Update mu coefficients ---
            # Each divisor is replaced by 1e-30 when its magnitude is at
            # most 1e-30, so the divisions below cannot hit an exact zero.
            ul4 = state.ul3
            ul3_new = state.ul2
            gamal2_safe = jnp.where(jnp.abs(gamal2) > 1e-30, gamal2, 1e-30)
            ul2_new = jnp.where(
                k > 2,
                (taul2 - etal2 * ul4 - veplnl2 * ul3_new) / gamal2_safe,
                state.ul2,
            )
            gamal_safe = jnp.where(jnp.abs(gamal_new) > 1e-30, gamal_new, 1e-30)
            ul_new = jnp.where(
                k > 1,
                (taul_new - etal_new * ul3_new - veplnl_new * ul2_new) / gamal_safe,
                state.ul,
            )
            gama_safe = jnp.where(jnp.abs(gama_final) > 1e-30, gama_final, 1e-30)
            # A vanishing gama contributes nothing along the newest w.
            u_new = jnp.where(
                jnp.abs(gama_final) > 1e-30,
                (tau_new - eta_new * ul2_new - vepln_new * ul_new) / gama_safe,
                jnp.array(0.0),
            )
            xl2norm_new = jnp.sqrt(state.xl2norm**2 + ul2_new**2)
            # --- Update w-vectors and solution (QLP mode) ---
            # The three start-up regimes (k == 1, k == 2, k > 2) are all
            # evaluated; the matching one is selected further down.
            w_old = state.w
            wl_old = state.wl
            # k > 2 path (general case)
            wl2_g = wl_old
            wl_g = w_old
            w_g = wl2_g * state.sr2 - v * state.cr2
            wl2_g = wl2_g * state.cr2 + v * state.sr2
            v_tmp = wl_g * cr1_new + w_g * sr1_new
            w_g = wl_g * sr1_new - w_g * cr1_new
            wl_g = v_tmp
            # k == 2 path
            wl2_2 = wl_old
            wl_2 = w_old * cr1_new + v * sr1_new
            w_2 = w_old * sr1_new - v * cr1_new
            # k == 1 path
            wl2_1 = wl_old
            wl_1 = v * sr1_new
            w_1 = -v * cr1_new
            wl2_out = jnp.where(k > 2, wl2_g, jnp.where(k == 2, wl2_2, wl2_1))
            wl_out = jnp.where(k > 2, wl_g, jnp.where(k == 2, wl_2, wl_1))
            w_out = jnp.where(k > 2, w_g, jnp.where(k == 2, w_2, w_1))
            xl2_new = state.xl2 + wl2_out * ul2_new
            x_new = xl2_new + wl_out * ul_new + w_out * u_new
            # --- Next right rotation P_{k-1,k+1} (for next iter) ---
            cr2_new, sr2_new, gamal_store = _sym_ortho(gamal_new, eplnn_new)
            # --- Update norms and condition estimate ---
            abs_gama = jnp.abs(gama_final)
            Anorm_new = jnp.maximum(state.Anorm, pnorm)
            Anorm_new = jnp.maximum(Anorm_new, gamal_new)
            Anorm_new = jnp.maximum(Anorm_new, abs_gama)
            gminl_new = jnp.where(k == 1, gama_final, state.gmin)
            gmin_new = jnp.where(
                k == 1,
                gama_final,
                jnp.minimum(
                    jnp.minimum(state.gminl, gamal_new),
                    abs_gama,
                ),
            )
            # --- Convergence check ---
            # ||x|| is assembled from the recurrence coefficients here; the
            # final check after the loop uses the norm of x itself.
            xnorm = jnp.sqrt(xl2norm_new**2 + ul_new**2 + u_new**2)
            relres = jnp.abs(phi_new) / (
                Anorm_new * jnp.maximum(xnorm, 1e-30) + beta1_safe
            )
            # A Lanczos breakdown always stops the loop, but only counts as
            # convergence here when |phi| is small relative to beta1.
            lanczos_breakdown = betan_new < 1e-30 * jnp.maximum(beta1_safe, 1.0)
            residual_small = jnp.abs(phi_new) < tol * beta1_safe
            converged = (relres < tol) | (lanczos_breakdown & residual_small)
            stop_now = converged | lanczos_breakdown
            return _PMinresQLPState(
                r1=r1_new,
                r2=r2_new,
                r3=r3_new,
                betal=beta,
                betan=betan_new,
                cs=cs_new,
                sn=sn_new,
                cr2=cr2_new,
                sr2=sr2_new,
                dltan=dltan_new,
                eplnn=eplnn_new,
                gama=gama_final,
                gamal=gamal_store,
                gamal2=gamal_new,
                eta=eta_new,
                etal=etal_new,
                etal2=etal2,
                vepln=vepln_new,
                veplnl=veplnl_new,
                veplnl2=veplnl2,
                tau=tau_new,
                taul=taul_new,
                u=u_new,
                ul=ul_new,
                ul2=ul2_new,
                ul3=ul3_new,
                w=w_out,
                wl=wl_out,
                x=x_new,
                xl2=xl2_new,
                phi=phi_new,
                xl2norm=xl2norm_new,
                Anorm=Anorm_new,
                gmin=gmin_new,
                gminl=gminl_new,
                iteration=state.iteration + 1,
                converged=stop_now,
            )

        # Frozen once the stop flag is set: later trips are no-ops.
        return jax.lax.cond(
            jnp.reshape(state.converged, ()), lambda s: s, do_step, state
        )

    final = jax.lax.fori_loop(0, max_iter, step_fn, init_state)
    # The state's ``converged`` field is only a stop flag, so success is
    # re-derived from the final residual estimate and a finiteness check.
    final_relres = jnp.abs(final.phi) / (
        jnp.maximum(final.Anorm, 1e-30) * jnp.maximum(jnp.linalg.norm(final.x), 1e-30)
        + beta1_safe
    )
    success = (final_relres < tol) & jnp.isfinite(final.x).all()
    return final.x, success
# Names re-exported by ``from <this module> import *``. Underscore-prefixed
# helpers are listed explicitly because a star-import would otherwise skip
# them.
__all__ = [
    "_CGState",
    "_CRAIG_BREAKDOWN_TOL",
    "_CRAIG_TOL_ABS",
    "_CraigState",
    "_PMinresQLPState",
    "_sym_ortho",
    "build_cg_step",
    "craig_solve",
    "pminres_qlp_solve",
    "solve_unconstrained_cg",
]