# Source code for slsqp_jax.qp.active_set

"""Shared add / drop / EXPAND / ping-pong active-set loop.

The legacy ``slsqp_jax/qp_solver.py`` carried three near-identical copies
of this loop body — one in ``_solve_qp_proximal``, one in
``_solve_qp_direct``, and one inlined in ``solve_qp``'s ineq-only path.
This module hosts the *single* implementation; the three QP-strategy
modules now build their own initial state and inner-solve closure and
delegate to :func:`run_active_set_loop`.

The loop is parameterised by the **inner solve closure**:

    inner_solve_fn(state, working_tol) -> ActiveSetInnerResult

which receives the current active-set state together with the current
EXPAND working tolerance, and packages the new direction, equality /
inequality multipliers, projection diagnostics, and the inner-failure flag.
Callers control how the inner solver is wired (proximal-stabilised HVP vs.
direct projection) by what they put in this closure.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float

from slsqp_jax.state import QPState
from slsqp_jax.types import Scalar, Vector


class ActiveSetInnerResult(NamedTuple):
    """Output of one call to the caller-provided inner-solve closure.

    :func:`run_active_set_loop` invokes the closure once per active-set
    iteration and consumes every field below.

    Attributes:
        d: Search direction produced by this inner solve.
        multipliers_eq: Equality multipliers for this iteration (recovered
            or absorbed).  Supply an empty ``(0,)`` vector when the problem
            has no equality constraints.
        multipliers_ineq: Inequality multipliers, ``m_ineq`` entries, set to
            zero on inactive constraints.
        inner_failed: Whether the inner solve yielded a non-finite
            direction.
        proj_residual: Most recent M-metric projection residual reported by
            the inner solver (``MinresQLPSolver``); always ``0`` for
            null-space CG / CRAIG.
        n_proj_refinements: Projection refinement rounds performed by the
            inner solver during this call; the loop adds it to the running
            count stored in ``QPState``.
        projected_grad_norm: Most recent projected-gradient norm from the
            inner solver (``HRInexactSTCG`` only; ``inf`` for the others).
    """

    # New search direction.
    d: Vector
    # Multipliers for this iteration.
    multipliers_eq: Float[Array, " m_eq"]
    multipliers_ineq: Float[Array, " m_ineq"]
    # Inner-solver health and diagnostics.
    inner_failed: Bool[Array, ""]
    proj_residual: Scalar
    n_proj_refinements: Array
    projected_grad_norm: Scalar
def _masked_argmax(mask: Bool[Array, " m"], scores: Float[Array, " m"]) -> Array:
    """Return the index of the largest ``scores`` entry among ``mask`` rows.

    Unselected rows are replaced by ``-inf``.  When no row is selected the
    returned index is arbitrary; callers gate its use on ``jnp.any(mask)``.
    ``jnp.argmax`` raises on zero-length input at trace time, so for an empty
    ``scores`` (no inequality constraints) a placeholder index ``0`` with the
    same integer dtype is returned instead.
    """
    if scores.shape[0] == 0:
        return jnp.argmax(jnp.zeros((1,), dtype=scores.dtype))
    return jnp.argmax(jnp.where(mask, scores, -jnp.inf))


def _masked_argmin(mask: Bool[Array, " m"], scores: Float[Array, " m"]) -> Array:
    """Return the index of the smallest ``scores`` entry among ``mask`` rows.

    Mirror image of :func:`_masked_argmax` (unselected rows become ``+inf``),
    with the same empty-input placeholder behaviour.
    """
    if scores.shape[0] == 0:
        return jnp.argmin(jnp.zeros((1,), dtype=scores.dtype))
    return jnp.argmin(jnp.where(mask, scores, jnp.inf))


def run_active_set_loop(
    init_state: QPState,
    inner_solve_fn: Callable[[QPState, Scalar], ActiveSetInnerResult],
    A_ineq: Float[Array, "m_ineq n"],
    b_ineq: Float[Array, " m_ineq"],
    max_iter: int,
    base_tol: Scalar,
    effective_tau: Scalar | float,
    drop_floor: Scalar,
    ping_pong_threshold: int,
) -> QPState:
    """Run the shared add / drop / EXPAND / ping-pong active-set loop.

    The body, identical across all three QP strategies:

    1. Compute the working tolerance
       ``working_tol = base_tol + iteration * effective_tau`` (EXPAND ramp).
    2. Call the caller-supplied :func:`inner_solve_fn` with the current
       active set; receive a new direction, multipliers, and projection
       diagnostics.
    3. Flag inactive rows with ``A_ineq d - b_ineq < -working_tol`` as
       violated, and active rows with
       ``mult_ineq < -max(working_tol, drop_floor)`` as drop candidates.
    4. Branch: add the most-violated row if any is violated; otherwise drop
       the most-negative drop candidate if any exists; otherwise mark the
       loop converged.  Add / drop both apply the ping-pong short-circuit:
       re-selecting the index last touched by the opposite action
       ``ping_pong_threshold`` times in a row stops the loop with
       ``ping_ponged=True`` and leaves the active set unchanged.

    Only rows that actually pass the drop test compete for the drop index,
    so a NaN multiplier on an otherwise-active row cannot be dropped in
    place of a genuinely negative one.  An empty inequality block
    (``m_ineq == 0``) is accepted and converges after one inner solve.

    Args:
        init_state: Initial active-set state.
        inner_solve_fn: Closure performing the per-iteration inner
            equality-constrained QP solve.  Receives the current
            ``QPState`` and the current ``working_tol`` (some callers
            ignore it).
        A_ineq: Inequality constraint matrix used for the violation check.
            Must match what ``inner_solve_fn`` consumes.
        b_ineq: Inequality RHS for the violation check.
        max_iter: Iteration budget (matches ``while_loop`` cond).
        base_tol: Outer SQP-level base tolerance.
        effective_tau: Per-iteration EXPAND increment.  ``0.0`` disables
            the ramp.
        drop_floor: Floor on the drop test so multiplier-recovery noise
            does not flip a negligible negative multiplier into a drop.
        ping_pong_threshold: Threshold for the explicit add/drop ping-pong
            short-circuit (``2**31 - 1`` effectively disables it).

    Returns:
        The final ``QPState`` after the active-set loop terminates.
    """

    def cond_fn(state: QPState) -> Bool[Array, ""]:
        # Stop once a branch marks convergence or the budget is exhausted.
        return ~state.converged & (state.iteration < max_iter)

    def body_fn(state: QPState) -> QPState:
        # EXPAND ramp: tolerance grows linearly with the iteration count.
        working_tol = base_tol + state.iteration * effective_tau
        inner = inner_solve_fn(state, working_tol)
        d_new = inner.d
        mult_ineq_new = inner.multipliers_ineq
        any_inner_failure = state.any_inner_failure | inner.inner_failed
        n_proj_refinements = state.n_proj_refinements + inner.n_proj_refinements

        # Feasibility check with EXPAND-relaxed tolerance; only inactive rows
        # are candidates for addition.
        residuals = A_ineq @ d_new - b_ineq
        violated = (residuals < -working_tol) & ~state.active_set
        any_violated = jnp.any(violated)
        most_violated_idx = _masked_argmax(violated, -residuals)

        # Noise-aware drop test; only active rows that pass it compete.
        drop_tol = jnp.maximum(working_tol, drop_floor)
        negative_mult = (mult_ineq_new < -drop_tol) & state.active_set
        any_negative = jnp.any(negative_mult)
        most_negative_idx = _masked_argmin(negative_mult, mult_ineq_new)

        def next_state(
            *,
            active_set,
            converged,
            last_add_idx,
            last_drop_idx,
            ping_pong_count,
            ping_ponged,
        ) -> QPState:
            # Fields shared by every branch come from this iteration's solve.
            return QPState(  # ty: ignore[invalid-return-type]
                d=d_new,
                active_set=active_set,
                multipliers_eq=inner.multipliers_eq,
                multipliers_ineq=mult_ineq_new,
                iteration=state.iteration + 1,
                converged=converged,
                any_inner_failure=any_inner_failure,
                last_add_idx=last_add_idx,
                last_drop_idx=last_drop_idx,
                ping_pong_count=ping_pong_count,
                ping_ponged=ping_ponged,
                proj_residual=inner.proj_residual,
                n_proj_refinements=n_proj_refinements,
                projected_grad_norm=inner.projected_grad_norm,
            )

        def ping_pong_step(idx, opposite_last_idx):
            # A ping-pong is re-selecting the index the opposite action last
            # touched; consecutive occurrences accumulate, anything else
            # resets the counter.
            is_pp = (opposite_last_idx >= 0) & (idx == opposite_last_idx)
            new_count = jnp.where(is_pp, state.ping_pong_count + 1, 0)
            return new_count >= ping_pong_threshold, new_count

        def add_constraint() -> QPState:
            triggered, new_count = ping_pong_step(
                most_violated_idx, state.last_drop_idx
            )
            return next_state(
                active_set=jnp.where(
                    triggered,
                    state.active_set,
                    state.active_set.at[most_violated_idx].set(True),
                ),
                converged=triggered,
                last_add_idx=jnp.where(
                    triggered, state.last_add_idx, most_violated_idx
                ),
                last_drop_idx=state.last_drop_idx,
                ping_pong_count=jnp.where(
                    triggered, state.ping_pong_count, new_count
                ),
                ping_ponged=state.ping_ponged | triggered,
            )

        def drop_constraint() -> QPState:
            triggered, new_count = ping_pong_step(
                most_negative_idx, state.last_add_idx
            )
            return next_state(
                active_set=jnp.where(
                    triggered,
                    state.active_set,
                    state.active_set.at[most_negative_idx].set(False),
                ),
                converged=triggered,
                last_add_idx=state.last_add_idx,
                last_drop_idx=jnp.where(
                    triggered, state.last_drop_idx, most_negative_idx
                ),
                ping_pong_count=jnp.where(
                    triggered, state.ping_pong_count, new_count
                ),
                ping_ponged=state.ping_ponged | triggered,
            )

        def mark_converged() -> QPState:
            return next_state(
                active_set=state.active_set,
                converged=jnp.array(True),
                last_add_idx=state.last_add_idx,
                last_drop_idx=state.last_drop_idx,
                ping_pong_count=state.ping_pong_count,
                ping_ponged=state.ping_ponged,
            )

        # Priority: add a violated row, else drop a negative multiplier,
        # else the current active set is optimal.
        return jax.lax.cond(
            any_violated,
            add_constraint,
            lambda: jax.lax.cond(any_negative, drop_constraint, mark_converged),
        )

    return jax.lax.while_loop(cond_fn, body_fn, init_state)
# Public API of this module: the inner-solve payload type and the shared loop.
__all__ = ["ActiveSetInnerResult", "run_active_set_loop"]