# Source code for slsqp_jax.inner.cholesky

"""Projected CG with Cholesky-based null-space projection."""

from __future__ import annotations

from collections.abc import Callable

import jax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float

from slsqp_jax.inner.base import AbstractInnerSolver
from slsqp_jax.inner.masking import make_active_subproblem
from slsqp_jax.inner.projected_cg import run_projected_pcg
from slsqp_jax.state import InnerSolveResult, ProjectionContext
from slsqp_jax.types import Scalar, Vector


def _make_cholesky_projection_ctx(
    hvp_fn: Callable[[Vector], Vector],
    g: Vector,
    A: Float[Array, "m n"],
    b: Float[Array, " m"],
    active_mask: Bool[Array, " m"],
    free_mask: Bool[Array, " n"] | None = None,
    d_fixed: Vector | None = None,
) -> ProjectionContext:
    """Assemble a Cholesky-backed ``ProjectionContext`` for the active rows.

    ``make_active_subproblem`` first produces the working data. That means
    ``A``/``b`` masked to the active rows, plus optional bound-fixing driven
    by ``free_mask``/``d_fixed``. The Gram matrix ``A_work A_workᵀ`` is then
    regularised and Cholesky-factorised once. Inactive rows get an added unit
    diagonal, and every row gets an extra ``1e-8`` ridge.

    That single factor is shared by:

    * the null-space projector ``v -> v - A_workᵀ (G)⁻¹ A_work v``
      (applied to the free components only when bounds are fixed);
    * the particular solution ``A_workᵀ (G)⁻¹ b_work`` (plus ``d_fixed``
      when bounds are fixed);
    * the multiplier-recovery closure, which applies one step of
      iterative refinement.

    Here ``G`` denotes the regularised Gram matrix. This is the
    projector-construction stage used by ``_solve_projected_cg_cholesky``
    (and hence ``ProjectedCGCholesky.solve``).
    """
    sub = make_active_subproblem(
        hvp_fn=hvp_fn,
        g=g,
        A=A,
        b=b,
        active_mask=active_mask,
        free_mask=free_mask,
        d_fixed=d_fixed,
    )
    A_w = sub.A_work
    n_rows = A.shape[0]

    # The unit diagonal covers inactive (masked-out) rows, and the 1e-8
    # ridge covers every row. Together they keep the Cholesky
    # factorisation well-posed.
    inactive_diag = jnp.where(active_mask, 0.0, 1.0)
    gram = A_w @ A_w.T + jnp.diag(inactive_diag) + 1e-8 * jnp.eye(n_rows)
    chol_lower = jnp.linalg.cholesky(gram)

    def gram_solve(rhs: Float[Array, " m"]) -> Float[Array, " m"]:
        # Solve G x = rhs using the lower-triangular factor.
        return jax.scipy.linalg.cho_solve((chol_lower, True), rhs)

    def keep_active(lam: Float[Array, " m"]) -> Float[Array, " m"]:
        # Zero the multipliers of inactive rows.
        return jnp.where(active_mask, lam, 0.0)

    particular = A_w.T @ gram_solve(sub.b_work)
    if sub.has_fixed:
        particular = particular + sub.d_fixed

    def project(v: Vector) -> Vector:
        if sub.has_fixed:
            # Only free variables take part in the null-space step.
            v = sub.free_mask * v
        return v - A_w.T @ gram_solve(A_w @ v)

    def recover_multipliers(Bd_plus_g: Vector) -> Float[Array, " m"]:
        # Least-squares multipliers first. Then one refinement step on the
        # stationarity residual absorbs the O(eps * cond(G)) error that the
        # 1e-8 ridge introduces.
        lam = keep_active(gram_solve(A_w @ Bd_plus_g))
        stationarity = Bd_plus_g - A_w.T @ lam
        lam = lam + gram_solve(A_w @ stationarity)
        return keep_active(lam)

    return ProjectionContext(
        project=project,
        d_p=particular,
        recover_multipliers=recover_multipliers,
        hvp_work=sub.hvp_work,
        g_eff=sub.g_eff,
        A_work=A_w,
        free_mask=sub.free_mask,
        d_fixed=sub.d_fixed,
        has_fixed=sub.has_fixed,
        converged=jnp.asarray(True),
    )


def _build_constraint_preconditioner(
    precond_fn: Callable[[Vector], Vector],
    A_work: Float[Array, "m n"],
    active_mask: Bool[Array, " m"],
    num_rows: int,
) -> Callable[[Vector], Vector]:
    """Wrap ``precond_fn`` (``M``) in the Gould–Hribar–Nocedal (2001)
    constraint preconditioner ``z = M r - M Aᵀ (A M Aᵀ)⁻¹ A M r``.

    ``M Aᵀ`` is formed by ``jax.vmap``-ing ``precond_fn`` over the rows of
    ``A_work``, so ``precond_fn`` must be vmap-compatible. ``A M Aᵀ`` gets
    the same regularisation as the projector's Gram matrix: a unit diagonal
    on inactive rows plus a ``1e-8`` ridge.
    """
    MAt = jax.vmap(precond_fn)(A_work).T  # (n, m)
    inactive_diag = jnp.where(active_mask, 0.0, 1.0)
    AMAt = A_work @ MAt + jnp.diag(inactive_diag) + 1e-8 * jnp.eye(num_rows)
    AMAt_lower = jnp.linalg.cholesky(AMAt)

    def apply(r: Vector) -> Vector:
        Mr = precond_fn(r)
        w = jax.scipy.linalg.cho_solve((AMAt_lower, True), A_work @ Mr)
        return Mr - MAt @ w

    return apply


def _solve_projected_cg_cholesky(
    hvp_fn: Callable[[Vector], Vector],
    g: Vector,
    A: Float[Array, "m n"],
    b: Float[Array, " m"],
    active_mask: Bool[Array, " m"],
    max_cg_iter: int,
    cg_tol: Scalar | float,
    precond_fn: Callable[[Vector], Vector] | None = None,
    cg_regularization: float = 1e-6,
    free_mask: Bool[Array, " n"] | None = None,
    d_fixed: Vector | None = None,
    use_constraint_preconditioner: bool = False,
) -> tuple[Vector, Float[Array, " m"], Bool[Array, ""]]:
    """Solve the equality-constrained QP with projected (preconditioned) CG.

    This is the worker behind :class:`ProjectedCGCholesky`. It builds the
    Cholesky projection context and chooses the preconditioner. The
    iteration itself is delegated to ``run_projected_pcg``, whose
    ``(d, multipliers, converged)`` triple is returned unchanged.

    The preconditioner is chosen as follows:

    * no ``precond_fn``: none is used;
    * ``precond_fn`` with ``use_constraint_preconditioner=False``: it is
      passed through as-is;
    * ``precond_fn`` with ``use_constraint_preconditioner=True``: it is
      wrapped in the Gould, Hribar & Nocedal (2001) constraint
      preconditioner ``z = M r - M Aᵀ (A M Aᵀ)⁻¹ A M r``.
    """
    ctx = _make_cholesky_projection_ctx(
        hvp_fn=hvp_fn,
        g=g,
        A=A,
        b=b,
        active_mask=active_mask,
        free_mask=free_mask,
        d_fixed=d_fixed,
    )

    effective_precond: Callable[[Vector], Vector] | None = (
        _build_constraint_preconditioner(
            precond_fn, ctx.A_work, active_mask, A.shape[0]
        )
        if precond_fn is not None and use_constraint_preconditioner
        else precond_fn
    )

    return run_projected_pcg(
        ctx=ctx,
        hvp_fn=hvp_fn,
        g=g,
        max_cg_iter=max_cg_iter,
        cg_tol=cg_tol,
        effective_precond=effective_precond,
        cg_regularization=cg_regularization,
    )


class ProjectedCGCholesky(AbstractInnerSolver):
    """Projected CG with a Cholesky-based null-space projection.

    The solver proceeds in four steps:

    1. Cholesky-factor a regularised ``A Aᵀ`` for the active rows.
    2. Use that factor for both the null-space projector and the particular
       solution.
    3. Run CG in the null space.
    4. Recover the multipliers with one step of iterative refinement.

    If a preconditioner is supplied and ``use_constraint_preconditioner`` is
    ``True``, the Gould, Hribar & Nocedal (2001) constraint preconditioner
    replaces the plain ``P(M(r))`` composition.
    """

    # Iteration cap for the inner CG loop.
    max_cg_iter: int
    # Default CG tolerance; overridden per call by ``adaptive_tol``.
    cg_tol: Scalar | float
    # Forwarded to the PCG driver as ``cg_regularization``.
    cg_regularization: float = 1e-6
    # Wrap ``precond_fn`` in the constraint preconditioner when True.
    use_constraint_preconditioner: bool = False

    def solve(
        self,
        hvp_fn: Callable[[Vector], Vector],
        g: Vector,
        A: Float[Array, "m n"],
        b: Float[Array, " m"],
        active_mask: Bool[Array, " m"],
        precond_fn: Callable[[Vector], Vector] | None = None,
        free_mask: Bool[Array, " n"] | None = None,
        d_fixed: Vector | None = None,
        adaptive_tol: Scalar | float | None = None,
    ) -> InnerSolveResult:
        """Solve the QP restricted to the rows selected by ``active_mask``.

        Args:
            precond_fn: Optional preconditioner. It is wrapped in the
                constraint preconditioner when
                ``use_constraint_preconditioner`` is set.
            free_mask, d_fixed: Optional bound-fixing data, forwarded to
                the projection-context builder.
            adaptive_tol: When not ``None``, replaces ``cg_tol`` for this
                call.

        Returns:
            An ``InnerSolveResult`` carrying the step, multipliers and
            convergence flag from the PCG driver. The projection
            diagnostics are constants rather than measurements.
        """
        tol = self.cg_tol if adaptive_tol is None else adaptive_tol
        d, multipliers, converged = _solve_projected_cg_cholesky(
            hvp_fn=hvp_fn,
            g=g,
            A=A,
            b=b,
            active_mask=active_mask,
            max_cg_iter=self.max_cg_iter,
            cg_tol=tol,
            precond_fn=precond_fn,
            cg_regularization=self.cg_regularization,
            free_mask=free_mask,
            d_fixed=d_fixed,
            use_constraint_preconditioner=self.use_constraint_preconditioner,
        )
        dtype = d.dtype
        return InnerSolveResult(
            d=d,
            multipliers=multipliers,
            converged=converged,
            # Feasibility is imposed by the null-space construction, so no
            # residual is measured here. NOTE: the projector carries a 1e-8
            # ridge, so ``A d = b`` holds only up to that regularisation.
            proj_residual=jnp.asarray(0.0, dtype=dtype),
            n_proj_refinements=jnp.asarray(0),
            # This solver does not compute a projected-gradient norm.
            # Reporting ``inf`` keeps the inexact-stationarity test from
            # tripping on it.
            projected_grad_norm=jnp.asarray(jnp.inf, dtype=dtype),
        )

    def build_projection_context(
        self,
        hvp_fn: Callable[[Vector], Vector],
        g: Vector,
        A: Float[Array, "m n"],
        b: Float[Array, " m"],
        active_mask: Bool[Array, " m"],
        precond_fn: Callable[[Vector], Vector] | None = None,
        free_mask: Bool[Array, " n"] | None = None,
        d_fixed: Vector | None = None,
    ) -> ProjectionContext:
        """Return the Cholesky-backed projection context for reuse.

        ``precond_fn`` is not used here; presumably it is accepted to match
        a shared inner-solver interface.
        """
        return _make_cholesky_projection_ctx(
            hvp_fn=hvp_fn,
            g=g,
            A=A,
            b=b,
            active_mask=active_mask,
            free_mask=free_mask,
            d_fixed=d_fixed,
        )


# Public API of this module.
__all__ = [
    "ProjectedCGCholesky",
]