Source code for slsqp_jax.inner.base
"""Abstract base class for pluggable inner equality-constrained QP solvers."""
from __future__ import annotations
import abc
from collections.abc import Callable
import equinox as eqx
from jaxtyping import Array, Bool, Float
from slsqp_jax.state import InnerSolveResult, ProjectionContext
from slsqp_jax.types import Scalar, Vector
[docs]
class AbstractInnerSolver(eqx.Module):
    """Pluggable strategy for the equality-constrained QP subproblem.

    A concrete subclass supplies :meth:`solve`, which produces the search
    direction ``d`` together with Lagrange multipliers for the constraints
    flagged as active. Subclasses that have a separate null-space
    projection step may additionally override
    :meth:`build_projection_context`; the base version refuses.
    """

    @abc.abstractmethod
    def solve(
        self,
        hvp_fn: Callable[[Vector], Vector],
        g: Vector,
        A: Float[Array, "m n"],
        b: Float[Array, " m"],
        active_mask: Bool[Array, " m"],
        precond_fn: Callable[[Vector], Vector] | None = None,
        free_mask: Bool[Array, " n"] | None = None,
        d_fixed: Vector | None = None,
        adaptive_tol: Scalar | float | None = None,
    ) -> InnerSolveResult:
        """Compute a step for the equality-constrained QP subproblem.

        The problem being solved is::

            minimize    (1/2) d^T B d + g^T d
            subject to  A[active] d = b[active]
                        d[i] = d_fixed[i]   wherever free_mask[i] is False

        ``B`` is never formed explicitly; it is only accessed through the
        product ``hvp_fn(v) == B @ v``.

        Args:
            hvp_fn: Hessian-vector product, mapping ``v`` to ``B @ v``.
            g: Linear term of the objective (its gradient).
            A: Stacked constraint matrix of shape ``(m, n)``.
            b: Stacked right-hand side of shape ``(m,)``.
            active_mask: Boolean ``(m,)`` mask selecting the active rows of
                ``A`` / ``b``.
            precond_fn: Optional preconditioner ``v -> M @ v`` with
                ``M`` approximating ``B^{-1}``.
            free_mask: Optional boolean ``(n,)`` mask. If given, only the
                entries with ``free_mask[i] = True`` are optimized.
            d_fixed: ``(n,)`` values imposed on the fixed (non-free)
                entries. Must be supplied together with ``free_mask``.
            adaptive_tol: Optional Eisenstat-Walker tolerance that replaces
                the solver's default convergence tolerance for this call
                only.

        Returns:
            An ``InnerSolveResult`` carrying the direction, the
            multipliers and a convergence flag.
        """
        ...  # pragma: no cover

    def build_projection_context(
        self,
        hvp_fn: Callable[[Vector], Vector],
        g: Vector,
        A: Float[Array, "m n"],
        b: Float[Array, " m"],
        active_mask: Bool[Array, " m"],
        precond_fn: Callable[[Vector], Vector] | None = None,
        free_mask: Bool[Array, " n"] | None = None,
        d_fixed: Vector | None = None,
    ) -> ProjectionContext:
        """Return a reusable projector and multiplier-recovery context.

        Composite strategies such as ``HRInexactSTCG`` invoke this on the
        wrapped inner solver. It yields that solver's null-space projector,
        its particular solution and a multiplier-recovery closure, without
        running the solver's own CG loop.

        This base implementation always refuses. That lets full-KKT
        solvers (``MinresQLPSolver``) opt out explicitly: they have no
        standalone projection step, so they cannot provide the
        inexact projector ``W̃_k`` needed by HR Algorithm 4.5.

        Args:
            hvp_fn: Hessian-vector product, mapping ``v`` to ``B @ v``.
            g: Linear term of the objective (its gradient).
            A: Stacked constraint matrix of shape ``(m, n)``.
            b: Stacked right-hand side of shape ``(m,)``.
            active_mask: Boolean ``(m,)`` mask selecting the active rows.
            precond_fn: Optional preconditioner ``v -> M @ v``.
            free_mask: Optional boolean ``(n,)`` mask of optimized entries.
            d_fixed: ``(n,)`` values for the fixed entries.

        Returns:
            A ``ProjectionContext`` (only from overriding subclasses).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        solver_name = type(self).__name__
        message = (
            f"{solver_name} does not expose a projection context; "
            "it cannot be used as the inner projector for HRInexactSTCG."
        )
        raise NotImplementedError(message)
# Public surface of this module: only the abstract inner-solver strategy.
__all__ = [
    "AbstractInnerSolver",
]