# Source code for slsqp_jax.qp.api

"""Thin :func:`solve_qp` router dispatching to the three QP strategies.

The legacy ``solve_qp`` did three things at once: routed to the proximal
or direct equality strategy, inlined the inequality-only strategy, and
constructed a default ``ProjectedCGCholesky`` when none was provided.
This module reduces it to the routing table and the default-inner-solver
construction.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import cast

import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Bool, Float, jaxtyped

from slsqp_jax.inner.base import AbstractInnerSolver
from slsqp_jax.inner.cholesky import ProjectedCGCholesky
from slsqp_jax.inner.krylov import solve_unconstrained_cg
from slsqp_jax.qp.direct import solve_qp_direct
from slsqp_jax.qp.inequality import solve_qp_inequality
from slsqp_jax.qp.proximal import solve_qp_proximal
from slsqp_jax.state import QPSolverResult
from slsqp_jax.types import Scalar, Vector


def _squash_to_scalar(value: Scalar | float) -> Scalar | float:
    """Return ``value`` as a 0-d quantity, squashing a size-1 array leak.

    Python scalars and 0-d arrays pass through unchanged. Any array with
    ``ndim > 0`` (JAX *or* NumPy, e.g. ``jnp.array([1e-6])``) is reshaped to
    shape ``()``; an array with more than one element makes ``jnp.reshape``
    raise, exactly as the previous inline ``jnp.reshape(tol, ())`` did.
    """
    # ``jnp.ndim`` works on Python floats, NumPy arrays, JAX arrays and
    # tracers alike, unlike ``isinstance(value, jax.Array)`` which misses
    # NumPy arrays.
    if jnp.ndim(value) != 0:
        return jnp.reshape(value, ())
    return value


@jaxtyped(typechecker=beartype)
def solve_qp(
    hvp_fn: Callable,
    g: Vector,
    A_eq: Float[Array, "m_eq n"],
    b_eq: Float[Array, " m_eq"],
    A_ineq: Float[Array, "m_ineq n"],
    b_ineq: Float[Array, " m_ineq"],
    max_iter: int = 100,
    max_cg_iter: int = 50,
    tol: Scalar | float = 1e-8,
    expand_factor: float = 1.0,
    initial_active_set: Bool[Array, " m_ineq"] | None = None,
    kkt_residual: Scalar | float = 0.0,
    proximal_mu: Scalar | float = 0.0,
    prev_multipliers_eq: Float[Array, " m_eq"] | None = None,
    precond_fn: Callable | None = None,
    cg_tol: Scalar | float | None = None,
    cg_regularization: float = 1e-6,
    use_proximal: bool = True,
    predicted_active_set: Bool[Array, " m_ineq"] | None = None,
    active_set_method: str = "expand",
    use_constraint_preconditioner: bool = False,
    inner_solver: AbstractInnerSolver | None = None,
    mult_drop_floor: float = 1e-6,
    ping_pong_threshold: int = 2**31 - 1,
) -> QPSolverResult:
    """Solve a QP with equality and inequality constraints.

    Solves::

        minimize    (1/2) d^T H d + g^T d
        subject to  A_eq d = b_eq
                    A_ineq d >= b_ineq

    where H is provided implicitly via ``hvp_fn(v) = H @ v``.

    Routing:

    * **Unconstrained** (``m_eq == 0`` and ``m_ineq == 0``): a single
      truncated-CG solve via ``solve_unconstrained_cg``; no active-set loop.
    * **Proximal sSQP** (``m_eq > 0`` and ``use_proximal=True``): equality
      constraints absorbed into the objective via augmented-Lagrangian
      penalty. See :mod:`slsqp_jax.qp.proximal`.
    * **Direct projection** (``m_eq > 0`` and ``use_proximal=False``):
      equality constraints enforced via null-space projection in the inner
      solver. See :mod:`slsqp_jax.qp.direct`.
    * **Inequality-only** (``m_eq == 0``, ``m_ineq > 0``): no equality
      block; just an active-set loop on inequality constraints. See
      :mod:`slsqp_jax.qp.inequality`.

    The three constrained strategies share the same EXPAND / ping-pong
    active-set loop body (see :mod:`slsqp_jax.qp.active_set`).

    Args:
        hvp_fn: Hessian-vector product function v -> H @ v.
        g: Linear term of the objective (gradient).
        A_eq: Equality constraint matrix (m_eq x n).
        b_eq: Equality constraint RHS (m_eq,).
        A_ineq: Inequality constraint matrix (m_ineq x n).
        b_ineq: Inequality constraint RHS (m_ineq,).
        max_iter: Maximum active-set iterations.
        max_cg_iter: Maximum CG iterations per active-set step.
        tol: Feasibility and optimality tolerance. A size-1 array is
            squashed to a 0-d scalar.
        expand_factor: EXPAND tolerance growth rate.
        initial_active_set: Optional warm-start active set from a previous
            QP solve.
        kkt_residual: Norm of the KKT residual from the outer solver.
        proximal_mu: Adaptive proximal parameter for sSQP.
        prev_multipliers_eq: Equality multipliers from the previous outer
            iteration (proximal centre).
        precond_fn: Optional preconditioner v -> M @ v.
        cg_tol: Optional CG convergence tolerance overriding ``tol`` for the
            inner solver only. A size-1 array is squashed to a 0-d scalar.
        cg_regularization: Curvature-guard threshold for CG.
        use_proximal: When True, equality constraints go through the sSQP
            proximal path. When False, direct projection.
        predicted_active_set: Optional LPEC-A predicted active set for
            warm-start. Ignored when ``active_set_method == "expand"``.
        active_set_method: ``"expand"``, ``"lpeca_init"``, or ``"lpeca"``.
        use_constraint_preconditioner: Used only when constructing a
            default ``inner_solver``.
        inner_solver: Pluggable strategy for the inner equality-
            constrained QP solve. Defaults to ``ProjectedCGCholesky``.
        mult_drop_floor: Floor on the negative-multiplier drop test.
        ping_pong_threshold: Threshold for the explicit ping-pong
            short-circuit. Defaults to ``2**31 - 1`` (effectively
            disabled).

    Returns:
        ``QPSolverResult`` containing the solution, multipliers, active
        set, and convergence info.

    Raises:
        ValueError: If ``active_set_method`` is not one of the three
            supported values.
    """
    if active_set_method not in ("expand", "lpeca_init", "lpeca"):
        raise ValueError(
            f"active_set_method must be 'expand', 'lpeca_init', or 'lpeca', "
            f"got {active_set_method!r}"
        )

    m_eq = A_eq.shape[0]
    m_ineq = A_ineq.shape[0]
    m_total = m_eq + m_ineq

    # Squash any size-1 leak from the caller (e.g. an SLSQP solver
    # whose ``atol`` was constructed as ``jnp.array([1e-6])``).
    # Downstream the active-set loop combines ``tol`` with EXPAND
    # constants and uses it inside JAX comparisons; a ``(1,)``-shape
    # value broadcasts the comparison output to ``(1,)`` and triggers
    # ``TypeError: Pred must be a scalar`` deep inside ``jax.lax.cond``.
    # ``cg_tol`` gets the same treatment because, when supplied, it
    # replaces ``tol`` for the inner solver and would otherwise bypass
    # this guard. The matching scalarisation for ``mu`` lives in
    # ``SLSQP._solve_qp_subproblem`` (the
    # ``test_size_one_atol_does_not_leak_to_proximal_mu`` regression).
    tol = _squash_to_scalar(tol)
    if cg_tol is not None:
        cg_tol = _squash_to_scalar(cg_tol)

    # "lpeca" disables EXPAND entirely; the predicted active set is only
    # consulted by the two LPEC-A based methods.
    use_expand = active_set_method != "lpeca"
    effective_predicted = (
        predicted_active_set
        if active_set_method in ("lpeca_init", "lpeca")
        else None
    )
    inner_cg_tol: Scalar | float = cg_tol if cg_tol is not None else tol

    if inner_solver is None:
        inner_solver = cast(
            AbstractInnerSolver,
            ProjectedCGCholesky(
                max_cg_iter=max_cg_iter,
                cg_tol=inner_cg_tol,
                cg_regularization=cg_regularization,
                use_constraint_preconditioner=use_constraint_preconditioner,
            ),
        )

    # Case 1: No constraints at all — truncated CG is always valid.
    if m_total == 0:
        d, _cg_converged = solve_unconstrained_cg(
            hvp_fn,
            g,
            max_cg_iter,
            inner_cg_tol,
            precond_fn=precond_fn,
            cg_regularization=cg_regularization,
        )
        # Convergence is reported as "d is finite", not as the CG
        # convergence flag: a truncated CG step is still a usable direction.
        finite_d = jnp.isfinite(d).all()
        return QPSolverResult(
            d=d,
            multipliers_eq=jnp.zeros((0,)),
            multipliers_ineq=jnp.zeros((0,)),
            active_set=jnp.zeros((0,), dtype=bool),
            converged=finite_d,
            iterations=jnp.array(1),
            ping_ponged=jnp.array(False),
            reached_max_iter=jnp.array(False),
            final_working_tol=jnp.asarray(0.0, dtype=jnp.float64),
            proj_residual=jnp.asarray(0.0, dtype=jnp.float64),
            n_proj_refinements=jnp.asarray(0),
            projected_grad_norm=jnp.asarray(jnp.inf, dtype=jnp.float64),
        )

    # Case 2: equality + (any) ineq, sSQP enabled.
    if m_eq > 0 and use_proximal:
        return solve_qp_proximal(
            hvp_fn=hvp_fn,
            g=g,
            A_eq=A_eq,
            b_eq=b_eq,
            A_ineq=A_ineq,
            b_ineq=b_ineq,
            m_eq=m_eq,
            m_ineq=m_ineq,
            max_iter=max_iter,
            max_cg_iter=max_cg_iter,
            tol=tol,
            expand_factor=expand_factor,
            initial_active_set=initial_active_set,
            kkt_residual=kkt_residual,
            proximal_mu=proximal_mu,
            prev_multipliers_eq=prev_multipliers_eq,
            inner_solver=inner_solver,
            precond_fn=precond_fn,
            cg_tol=inner_cg_tol,
            cg_regularization=cg_regularization,
            predicted_active_set=effective_predicted,
            use_expand=use_expand,
            mult_drop_floor=mult_drop_floor,
            ping_pong_threshold=ping_pong_threshold,
        )

    # Case 3: equality + (any) ineq, sSQP disabled — direct projection.
    # (``use_proximal`` is necessarily False here: Case 2 returned otherwise.)
    # NOTE(review): unlike the proximal and inequality paths, ``max_cg_iter``
    # is not forwarded here — presumably the inner solver's own setting
    # governs; confirm against ``solve_qp_direct``'s signature.
    if m_eq > 0:
        return solve_qp_direct(
            hvp_fn=hvp_fn,
            g=g,
            A_eq=A_eq,
            b_eq=b_eq,
            A_ineq=A_ineq,
            b_ineq=b_ineq,
            m_eq=m_eq,
            m_ineq=m_ineq,
            max_iter=max_iter,
            tol=tol,
            expand_factor=expand_factor,
            initial_active_set=initial_active_set,
            kkt_residual=kkt_residual,
            inner_solver=inner_solver,
            precond_fn=precond_fn,
            cg_tol=inner_cg_tol,
            cg_regularization=cg_regularization,
            predicted_active_set=effective_predicted,
            use_expand=use_expand,
            mult_drop_floor=mult_drop_floor,
            ping_pong_threshold=ping_pong_threshold,
        )

    # Case 4: inequality only.
    return solve_qp_inequality(
        hvp_fn=hvp_fn,
        g=g,
        A_ineq=A_ineq,
        b_ineq=b_ineq,
        m_ineq=m_ineq,
        max_iter=max_iter,
        max_cg_iter=max_cg_iter,
        tol=tol,
        expand_factor=expand_factor,
        initial_active_set=initial_active_set,
        kkt_residual=kkt_residual,
        inner_solver=inner_solver,
        precond_fn=precond_fn,
        cg_tol=inner_cg_tol,
        cg_regularization=cg_regularization,
        predicted_active_set=effective_predicted,
        use_expand=use_expand,
        mult_drop_floor=mult_drop_floor,
        ping_pong_threshold=ping_pong_threshold,
    )
# Public API: only the router is exported. The per-strategy solvers it
# dispatches to are imported from their own modules
# (slsqp_jax.qp.proximal / direct / inequality).
__all__ = ["solve_qp"]