# Source code for slsqp_jax.lpeca

"""LPEC-A Active Set Identification for SLSQP.

Implements the LPEC-A (Linear Program with Equilibrium Constraints
Approximation) method from Oberlin & Wright (2005, Section 3.3) for
identifying active inequality constraints in nonlinear programming.

The method computes a proximity measure ``rho_bar`` from primal
constraint values and dual multiplier estimates, then applies a
threshold test to predict the active set.  Under Mangasarian-Fromovitz
constraint qualification (MFCQ) and second-order sufficiency conditions,
the prediction is asymptotically exact as the iterate converges to
the solution.
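
As a minimal sketch of these two quantities, the NumPy transcription
below mirrors the sum computed by ``compute_rho_bar`` and then applies
the threshold test with the default ``beta = 1 / (m_ineq + n + m_eq)``
and ``sigma = 0.9``. It is for illustration only and is not the
package's JAX code path. The helper ``rho_bar_sketch`` is
hypothetical. The toy problem is also an assumption: minimise
``x0**2 + x1**2`` subject to ``x0 - 1 >= 0``, whose solution is
``x* = (1, 0)`` with multiplier ``2``.

```python
import numpy as np


def rho_bar_sketch(c_ineq, c_eq, grad, A_ineq, A_eq, lam, mu):
    # Hypothetical helper mirroring compute_rho_bar (c_ineq >= 0 is feasible).
    ineq = np.where(
        c_ineq > 0,
        np.sqrt(np.maximum(c_ineq * lam, 0.0)),  # complementarity term
        -c_ineq,                                  # violation magnitude
    ).sum()
    eq = np.abs(c_eq).sum()
    # 1-norm of the stationarity residual grad - A_ineq^T lam - A_eq^T mu.
    stationarity = np.abs(grad - A_ineq.T @ lam - A_eq.T @ mu).sum()
    return ineq + eq + stationarity


# Iterate x = (1.01, 0), close to the solution x* = (1, 0) of the toy problem.
x = np.array([1.01, 0.0])
c_ineq = np.array([x[0] - 1.0])
c_eq = np.zeros(0)
grad = 2.0 * x
A_ineq = np.array([[1.0, 0.0]])
A_eq = np.zeros((0, 2))
lam, mu = np.array([2.0]), np.zeros(0)

rho = rho_bar_sketch(c_ineq, c_eq, grad, A_ineq, A_eq, lam, mu)
beta = 1.0 / (c_ineq.size + x.size + c_eq.size)  # default beta = 1/3 here
threshold = (beta * rho) ** 0.9                   # sigma = 0.9
predicted = c_ineq <= threshold
```

Here ``rho_bar = sqrt(0.02) + 0.02`` (about 0.161) and the threshold
is about 0.072. The constraint value ``c_ineq = 0.01`` falls below it,
so the constraint is predicted active.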

An optional LP refinement step (via ``mpax.r2HPDHG``) can tighten
the multiplier estimates used in the threshold computation, but is
not required for asymptotic correctness.
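
If ``mpax`` is unavailable, the same LP can be sketched with
``scipy.optimize.linprog`` (HiGHS backend). This swaps in a different
solver for illustration; it is not what this module calls. The helper
``lpeca_lp_sketch`` is hypothetical. It lays out the variables
``z = [lambda, mu, u, v]`` the same way ``solve_lpeca_lp`` does. The
toy data is made up.

```python
import numpy as np
from scipy.optimize import linprog


def lpeca_lp_sketch(c_ineq, grad, A_ineq, A_eq, lambda_bound=1e6):
    # Hypothetical helper; variables z = [lambda, mu, u, v].
    m_ineq, m_eq, n = A_ineq.shape[0], A_eq.shape[0], grad.shape[0]
    cost = np.concatenate([
        np.where(c_ineq > 0, c_ineq, 0.0),  # c_ineq_i * lambda_i, feasible i
        np.zeros(m_eq),                     # mu carries no cost
        np.ones(2 * n),                     # e^T u + e^T v
    ])
    # grad - A_ineq^T lambda - A_eq^T mu - u + v = 0, i.e. A z = -grad.
    A = np.hstack([-A_ineq.T, -A_eq.T, -np.eye(n), np.eye(n)])
    bounds = ([(0.0, lambda_bound)] * m_ineq
              + [(None, None)] * m_eq
              + [(0.0, None)] * (2 * n))
    res = linprog(cost, A_eq=A, b_eq=-grad, bounds=bounds, method='highs')
    return res.x[:m_ineq], res.x[m_ineq:m_ineq + m_eq]


# Made-up data: constraint x0 - 1 >= 0 at x = (1.01, 0), gradient (2.02, 0).
lam, mu = lpeca_lp_sketch(
    c_ineq=np.array([0.01]),
    grad=np.array([2.02, 0.0]),
    A_ineq=np.array([[1.0, 0.0]]),
    A_eq=np.zeros((0, 2)),
)
```

Raising ``lambda`` by one unit costs ``0.01`` in the objective but
removes one unit of stationarity residual. The LP optimum is therefore
``lambda = 2.02``, which makes the stationarity residual zero.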

**Far-from-solution trust gate.**  Theorem 5 of Oberlin & Wright only
guarantees asymptotic exactness when the iterate is *close* to a
solution (so ``rho_bar`` is small).  Far from the solution, the raw
``threshold = (beta * rho_bar) ** sigma`` can grow large enough that
the test ``c_ineq_i <= threshold`` includes nearly every constraint,
producing an over-saturated working set that destroys QP convergence.
The implementation guards against this with a configurable
``trust_threshold`` on ``rho_bar``: when ``rho_bar > trust_threshold``,
LPEC-A returns an *empty* prediction so the QP active-set loop falls
back to its warm-start / cold-start path.
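
The gate can be illustrated with a small NumPy sketch. The helper
``gated_threshold_test`` is hypothetical, and the ``rho_bar`` and
constraint values are made up for illustration. With ``beta = 1/3``,
a large ``rho_bar`` pushes the raw threshold past every ``c_ineq_i``;
the gate then empties the prediction.

```python
import numpy as np


def gated_threshold_test(c_ineq, rho_bar, beta, sigma=0.9, trust_threshold=1.0):
    # Hypothetical helper mirroring the gate in identify_active_set_lpeca.
    raw = c_ineq <= (beta * rho_bar) ** sigma  # ungated threshold test
    valid = rho_bar <= trust_threshold         # trust gate
    return raw, raw & valid, valid


beta = 1.0 / 3.0

# Far from the solution: threshold is 2**0.9 (about 1.87), so every row passes.
c_far = np.array([0.5, 1.0, 1.5])
raw_far, gated_far, valid_far = gated_threshold_test(c_far, 6.0, beta)

# Near the solution: threshold is (1/60)**0.9 (about 0.025).
c_near = np.array([0.001, 0.5, 1.0])
raw_near, gated_near, valid_near = gated_threshold_test(c_near, 0.05, beta)
```

With ``rho_bar = 6`` the raw test marks all three rows active, but
``valid`` is ``False``, so nothing is predicted. With
``rho_bar = 0.05`` the gate passes, and only the nearly active first
row survives.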

**Rank-aware size cap.**  As a secondary safety net, even when the
trust gate passes, the prediction is truncated so that at most
``max(n - m_eq - 1, 1)`` constraints are predicted active.  When
``m_eq <= n - 2`` this leaves the working-set Jacobian
``[A_eq; A_active]`` with at most ``n - 1`` rows against ``n``
columns, preserving a LICQ-like rank margin.  Selection
prioritises the most-violated / most-confident constraints
(smallest ``c_ineq_i``).
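
A NumPy sketch of the cap's selection rule follows. The helper
``cap_prediction`` is hypothetical. It uses a single ``np.argsort``
for readability. The module itself uses ``jax.lax.top_k`` to avoid
the nested-sort lowering discussed in its comments.

```python
import numpy as np


def cap_prediction(c_ineq, raw_predicted, n, m_eq):
    # Hypothetical helper: same static cap as the module, never below one row.
    n_dof = min(max(n - m_eq - 1, 1), c_ineq.shape[0])
    # Larger score = smaller c_ineq_i; rows that were not predicted score -inf.
    scores = np.where(raw_predicted, -c_ineq, -np.inf)
    # Indices of the n_dof largest scores (one NumPy sort, not lax.top_k).
    top = np.argsort(-scores, kind='stable')[:n_dof]
    selected = np.zeros_like(raw_predicted)
    selected[top] = True
    capped = bool(raw_predicted.sum() > n_dof)
    # AND with raw_predicted so unpredicted rows can never survive the cap.
    return raw_predicted & selected, capped


# n = 3 variables, no equalities: at most 3 - 0 - 1 = 2 rows may be kept.
c_ineq = np.array([0.02, -0.1, 0.05, 0.5])
raw = np.array([True, True, True, False])
kept, capped = cap_prediction(c_ineq, raw, n=3, m_eq=0)
```

With ``n = 3`` and no equalities, the cap is two rows. The violated
row (``c = -0.1``) and the next-smallest predicted row (``c = 0.02``)
are kept, the row with ``c = 0.05`` is dropped, and ``capped`` is
``True``.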

Sign convention
---------------
This module uses the ``slsqp-jax`` convention where
``c_ineq(x) >= 0`` means feasible.  The Oberlin & Wright paper uses
``c(x) <= 0`` for feasible constraints, so all formulas are adapted
by mapping ``c_paper_i = -c_ineq_i``.
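
The mapping can be checked on a hypothetical single constraint
written paper-style as ``g(x) = 1 - x0 <= 0``. The function names,
sample points and threshold value below are illustrative assumptions.
Negating ``g`` gives the ``c_ineq`` used here. After substitution,
feasibility ``c_ineq_i >= 0`` becomes ``g_i <= 0``, and the test
``c_ineq_i <= threshold`` becomes ``g_i >= -threshold``.

```python
import numpy as np


def g_paper(x):
    # Hypothetical paper-style constraint: feasible when g(x) <= 0.
    return np.array([1.0 - x[0]])


def c_ineq(x):
    # The same constraint in the slsqp-jax convention: feasible when >= 0.
    return -g_paper(x)


threshold = 0.1  # illustrative threshold value
points = [np.array([0.5, 2.0]), np.array([1.05, 0.0]), np.array([3.0, 0.0])]
ours, paper = [], []
for x in points:
    c, g = c_ineq(x)[0], g_paper(x)[0]
    # (feasible?, predicted active?) evaluated in each convention
    ours.append((bool(c >= 0), bool(c <= threshold)))
    paper.append((bool(g <= 0), bool(g >= -threshold)))
```

Both lists agree point by point: violated, nearly active and clearly
inactive points are classified identically in either convention.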

References
----------
Oberlin, C. & Wright, S. J. (2005). "An accelerated Newton method
for equations with semismooth Jacobians and nonlinear complementarity
problems." *Mathematical Programming*, 117(1-2), 355-386.
"""

from typing import NamedTuple

import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Bool, Float, jaxtyped

from slsqp_jax.types import Scalar, Vector


class LPECAResult(NamedTuple):
    """LPEC-A active-set identification result.

    Attributes:
        predicted: Boolean mask of shape ``(m_ineq,)``.  ``True`` means
            the constraint is predicted active.  When ``valid=False``
            this is the all-``False`` mask (the trust gate fired).
        valid: ``True`` when ``rho_bar`` is at or below the trust
            threshold, so the prediction is theoretically meaningful.
            ``False`` indicates the iterate is too far from the solution
            for LPEC-A to be reliable.
        capped: ``True`` when the rank-aware size cap truncated the
            prediction (the raw threshold predicted more than
            ``n - m_eq - 1`` constraints active, so the most confident
            entries were kept).
        rho_bar: The scalar proximity measure used for the trust gate.
    """

    predicted: Bool[Array, " m_ineq"]
    valid: Bool[Array, ""]
    capped: Bool[Array, ""]
    rho_bar: Scalar


@jaxtyped(typechecker=beartype)
def compute_rho_bar(
    c_ineq: Float[Array, " m_ineq"],
    c_eq: Float[Array, " m_eq"],
    grad: Vector,
    A_ineq: Float[Array, "m_ineq n"],
    A_eq: Float[Array, "m_eq n"],
    lambda_ineq: Float[Array, " m_ineq"],
    mu_eq: Float[Array, " m_eq"],
) -> Scalar:
    """Compute the LPEC-A proximity measure rho_bar.

    This implements Eq. 36 of Oberlin & Wright (2005), adapted to the
    ``c_ineq >= 0`` (feasible) sign convention.

    For strictly feasible inequality constraints (``c_ineq_i > 0``), the
    contribution is ``sqrt(c_ineq_i * lambda_ineq_i)``; for the remaining
    constraints (``c_ineq_i <= 0``, i.e. violated or exactly active), it
    is ``-c_ineq_i`` (the violation magnitude).  The measure also adds
    the 1-norms of the equality residual ``c_eq`` and of the
    stationarity residual ``grad - A_ineq^T lambda_ineq - A_eq^T mu_eq``.

    Args:
        c_ineq: Inequality constraint values at current point.
        c_eq: Equality constraint values at current point.
        grad: Objective gradient at current point.
        A_ineq: Inequality constraint Jacobian.
        A_eq: Equality constraint Jacobian.
        lambda_ineq: Multiplier estimates for inequality constraints.
        mu_eq: Multiplier estimates for equality constraints.

    Returns:
        The scalar proximity measure rho_bar >= 0.
    """
    feasible = c_ineq > 0
    ineq_contribution = jnp.sum(
        jnp.where(
            feasible,
            jnp.sqrt(jnp.maximum(c_ineq * lambda_ineq, 0.0)),
            -c_ineq,
        )
    )

    eq_contribution = jnp.sum(jnp.abs(c_eq))

    stationarity_residual = grad
    if A_ineq.shape[0] > 0:
        stationarity_residual = stationarity_residual - A_ineq.T @ lambda_ineq
    if A_eq.shape[0] > 0:
        stationarity_residual = stationarity_residual - A_eq.T @ mu_eq
    stationarity_contribution = jnp.sum(jnp.abs(stationarity_residual))

    return ineq_contribution + eq_contribution + stationarity_contribution


@jaxtyped(typechecker=beartype)
def identify_active_set_lpeca(
    c_ineq: Float[Array, " m_ineq"],
    c_eq: Float[Array, " m_eq"],
    grad: Vector,
    A_ineq: Float[Array, "m_ineq n"],
    A_eq: Float[Array, "m_eq n"],
    lambda_ineq: Float[Array, " m_ineq"],
    mu_eq: Float[Array, " m_eq"],
    sigma: float = 0.9,
    beta: float | None = None,
    trust_threshold: float = 1.0,
) -> LPECAResult:
    """Predict the active inequality set using the LPEC-A threshold test.

    Applies Eq. 43 of Oberlin & Wright (2005), adapted to our sign
    convention.  An inequality constraint ``i`` is predicted active
    when ``c_ineq_i <= (beta * rho_bar) ** sigma`` *and* ``rho_bar``
    does not exceed the trust threshold (otherwise the prediction is
    empty); the prediction is then truncated by the rank-aware size
    cap.  The result is wrapped in an :class:`LPECAResult` so the
    caller can distinguish "no constraints predicted active" (a valid
    prediction) from "LPEC-A bypassed" (``valid=False``) or "size cap
    fired" (``capped=True``).

    Args:
        c_ineq: Inequality constraint values at current point.
        c_eq: Equality constraint values at current point.
        grad: Objective gradient at current point.
        A_ineq: Inequality constraint Jacobian.
        A_eq: Equality constraint Jacobian.
        lambda_ineq: Multiplier estimates for inequality constraints.
        mu_eq: Multiplier estimates for equality constraints.
        sigma: Threshold exponent (``sigma_bar`` in the paper).  Must be
            in (0, 1).  Default 0.9 per paper recommendation.
        beta: Threshold scaling factor.  Default ``None`` uses the
            paper's recommendation ``1 / (m_ineq + n + m_eq)``.
        trust_threshold: Maximum ``rho_bar`` for which the LPEC-A
            prediction is trusted.  When ``rho_bar > trust_threshold``,
            the predicted set is empty (the iterate is considered too
            far from the solution for the asymptotic guarantees to
            apply).  Default ``1.0``.

    Returns:
        :class:`LPECAResult` containing the boolean prediction mask and
        the diagnostic flags ``valid`` / ``capped`` / ``rho_bar``.
    """
    m_ineq = c_ineq.shape[0]
    n = grad.shape[0]
    m_eq = c_eq.shape[0]

    if beta is None:
        beta = 1.0 / max(m_ineq + n + m_eq, 1)

    rho_bar = compute_rho_bar(c_ineq, c_eq, grad, A_ineq, A_eq, lambda_ineq, mu_eq)

    # LPEC-A threshold (Oberlin & Wright 2005, Eq. 43).  Far from the
    # solution ``rho_bar`` can be large, which inflates the threshold
    # so much that nearly every inequality is predicted active; the
    # resulting over-saturated working set typically causes the QP
    # equality solve to fail.  Two layers of protection:
    #
    # 1. Trust gate: when ``rho_bar > trust_threshold`` (default 1.0)
    #    the asymptotic correctness conditions of Theorem 5 are not
    #    even approximately satisfied, so we return an empty prediction
    #    and let the QP active-set loop fall back to warm-start.  This
    #    replaces the previous ``min(threshold, max|c_ineq|)`` clamp,
    #    which trivially evaluated to "all constraints active" because
    #    every ``c_ineq_i <= max|c_ineq_i|`` by construction.
    # 2. Size cap: even when the trust gate passes, the prediction is
    #    truncated so at most ``max(n - m_eq - 1, 1)`` constraints are
    #    predicted active.  When ``m_eq <= n - 2`` this bounds the
    #    working-set Jacobian ``[A_eq; A_active]`` to at most ``n - 1``
    #    rows against ``n`` columns, preserving a LICQ-like rank
    #    margin.  Truncation keeps the most-violated constraints
    #    (smallest ``c_ineq_i``).
    threshold_raw = (beta * rho_bar) ** sigma
    valid = rho_bar <= trust_threshold
    threshold = jnp.where(valid, threshold_raw, 0.0)
    raw_predicted = (c_ineq <= threshold) & valid

    # Rank-aware size cap.  ``n_dof`` is Python/static, so the cap is a
    # single static ``lax.top_k`` selection; we never materialise a full
    # rank vector and we never call nested ``argsort`` (that lowering
    # triggered a JAX/XLA verifier failure on CUDA+x64, see
    # ``jax-ml/jax`` issue #34096).
    n_dof_static = max(n - m_eq - 1, 1)
    if m_ineq == 0:
        return LPECAResult(
            predicted=raw_predicted,
            valid=valid,
            capped=jnp.array(False),
            rho_bar=rho_bar,
        )

    n_dof = min(n_dof_static, m_ineq)

    # Score: most violated (smallest c_ineq_i) wins.  Non-predicted
    # entries get -inf so their indices, if they happen to appear in the
    # top-k positions when fewer than ``n_dof`` rows are predicted, are
    # filtered out by the ``raw_predicted & selected_mask`` guard below.
    scores = jnp.where(raw_predicted, -c_ineq, -jnp.inf)

    # Static ``top_k`` selection of the ``n_dof`` largest scores.  This
    # replaces the previous ``argsort(argsort(-scores))`` rank vector,
    # which lowered to a nested-sort pattern that the XLA
    # ``permutation_sort_simplifier`` pass rejected on CUDA+x64.
    _, top_indices = jax.lax.top_k(scores, n_dof)
    # ``astype(jnp.int32)`` normalises the scatter index dtype; it is
    # not part of the mathematical selection rule.
    selected_mask = (
        jnp.zeros_like(raw_predicted).at[top_indices.astype(jnp.int32)].set(True)
    )
    # ``raw_predicted &`` is the load-bearing guard: when fewer than
    # ``n_dof`` rows are predicted, ``top_k`` returns positions with
    # score ``-inf`` (non-predicted rows); the AND restores the
    # invariant that only predicted rows can survive the cap.
    capped_predicted = raw_predicted & selected_mask

    # Diagnostics are intentionally computed from ``raw_predicted`` so
    # ``capped_flag`` retains its "more eligible rows were predicted
    # than the LICQ-margin cap allowed" meaning, independent of the
    # ``top_k`` selection.
    predicted_count = jnp.sum(raw_predicted.astype(jnp.int32))
    capped_flag = predicted_count > n_dof

    return LPECAResult(
        predicted=capped_predicted,
        valid=valid,
        capped=capped_flag,
        rho_bar=rho_bar,
    )


@jaxtyped(typechecker=beartype)
def solve_lpeca_lp(
    c_ineq: Float[Array, " m_ineq"],
    c_eq: Float[Array, " m_eq"],
    grad: Vector,
    A_ineq: Float[Array, "m_ineq n"],
    A_eq: Float[Array, "m_eq n"],
    lambda_bound: float = 1e6,
    eps_abs: float = 1e-6,
    eps_rel: float = 1e-6,
    max_iter: int = 1000,
) -> tuple[Float[Array, " m_ineq"], Float[Array, " m_eq"]]:
    """Solve the LPEC-A LP to obtain tighter multiplier estimates.

    Solves the LP from Eq. 42 of Oberlin & Wright (2005), adapted to the
    ``c_ineq >= 0`` sign convention::

        min_{lambda, mu, u, v}  sum(c_ineq_i * lambda_i for feasible i)
                                + e^T u + e^T v
        s.t.  grad - A_ineq^T lambda - A_eq^T mu = u - v
              0 <= lambda <= K_1
              u, v >= 0

    The LP is solved using ``mpax.r2HPDHG``, the reflected restarted
    Halpern PDHG algorithm (Lu & Yang, 2024), which achieves
    accelerated linear convergence on LP.

    Args:
        c_ineq: Inequality constraint values at current point.
        c_eq: Equality constraint values at current point.
        grad: Objective gradient at current point.
        A_ineq: Inequality constraint Jacobian.
        A_eq: Equality constraint Jacobian.
        lambda_bound: Upper bound ``K_1`` on lambda.  Default 1e6.
        eps_abs: Absolute tolerance for the LP solver.
        eps_rel: Relative tolerance for the LP solver.
        max_iter: Maximum LP solver iterations.

    Returns:
        Tuple of ``(lambda_ineq, mu_eq)``: the LP-optimal multiplier
        estimates for inequality and equality constraints.

    Raises:
        ImportError: If ``mpax`` is not installed.
    """
    try:
        from mpax import create_lp, r2HPDHG
    except ImportError:
        raise ImportError(
            "The LPEC-A LP solve requires the 'mpax' package. "
            "Install it with:\n"
            " pip install slsqp-jax[extras]\n"
            "or\n"
            " uv sync --group extras"
        ) from None

    m_ineq = c_ineq.shape[0]
    m_eq = c_eq.shape[0]
    n = grad.shape[0]

    # Decision variables: z = [lambda (m_ineq), mu (m_eq), u (n), v (n)]
    n_vars = m_ineq + m_eq + 2 * n

    # Objective: sum(c_ineq_i * lambda_i for feasible i) + e^T u + e^T v
    # For feasible constraints (c_ineq > 0), the cost on lambda_i is c_ineq_i.
    # For violated constraints, the cost is 0 (they're not penalized in the LP
    # objective, only through the stationarity constraint).
    feasible = c_ineq > 0
    obj_lambda = jnp.where(feasible, c_ineq, 0.0)
    obj_mu = jnp.zeros(m_eq)
    obj_u = jnp.ones(n)
    obj_v = jnp.ones(n)
    obj_vec = jnp.concatenate([obj_lambda, obj_mu, obj_u, obj_v])

    # Equality constraint: grad - A_ineq^T lambda - A_eq^T mu - u + v = 0
    # Rewritten as: [-A_ineq^T, -A_eq^T, -I, I] z = -grad
    lp_A_eq = jnp.concatenate(
        [
            -A_ineq.T if m_ineq > 0 else jnp.zeros((n, 0)),
            -A_eq.T if m_eq > 0 else jnp.zeros((n, 0)),
            -jnp.eye(n),
            jnp.eye(n),
        ],
        axis=1,
    )
    lp_b_eq = -grad

    # Bounds: 0 <= lambda <= K_1, -inf <= mu <= inf, u >= 0, v >= 0
    lb = jnp.concatenate(
        [
            jnp.zeros(m_ineq),
            jnp.full(m_eq, -jnp.inf),
            jnp.zeros(n),
            jnp.zeros(n),
        ]
    )
    ub = jnp.concatenate(
        [
            jnp.full(m_ineq, lambda_bound),
            jnp.full(m_eq, jnp.inf),
            jnp.full(n, jnp.inf),
            jnp.full(n, jnp.inf),
        ]
    )

    lp = create_lp(
        c=obj_vec,
        A=lp_A_eq,
        b=lp_b_eq,
        G=jnp.zeros((0, n_vars)),
        h=jnp.zeros(0),
        l=lb,
        u=ub,
    )

    solver = r2HPDHG(
        eps_abs=eps_abs,
        eps_rel=eps_rel,
        iteration_limit=max_iter,
        jit=True,
        verbose=False,
    )
    result = solver.optimize(lp)

    # Extract multiplier estimates from primal solution
    lambda_opt = result.primal_solution[:m_ineq]
    mu_opt = result.primal_solution[m_ineq : m_ineq + m_eq]

    return lambda_opt, mu_opt


def compute_lpeca_active_set(
    c_ineq: Float[Array, " m_ineq"],
    c_eq: Float[Array, " m_eq"],
    grad: Vector,
    A_ineq: Float[Array, "m_ineq n"],
    A_eq: Float[Array, "m_eq n"],
    lambda_ineq: Float[Array, " m_ineq"],
    mu_eq: Float[Array, " m_eq"],
    sigma: float = 0.9,
    beta: float | None = None,
    trust_threshold: float = 1.0,
    use_lp: bool = False,
    lp_lambda_bound: float = 1e6,
    lp_eps: float = 1e-6,
    lp_max_iter: int = 1000,
) -> LPECAResult:
    """Compute the LPEC-A predicted active set, optionally refining multipliers.

    This is the main entry point for LPEC-A active set identification.
    It optionally solves the LPEC-A LP to obtain tighter multiplier
    estimates, then applies the threshold test (with the trust gate and
    rank-aware size cap from :func:`identify_active_set_lpeca`).

    Args:
        c_ineq: Inequality constraint values at current point.
        c_eq: Equality constraint values at current point.
        grad: Objective gradient at current point.
        A_ineq: Inequality constraint Jacobian.
        A_eq: Equality constraint Jacobian.
        lambda_ineq: Current multiplier estimates for inequalities.
        mu_eq: Current multiplier estimates for equalities.
        sigma: Threshold exponent (default 0.9).
        beta: Threshold scaling factor (default: paper recommendation).
        trust_threshold: Maximum ``rho_bar`` for which the prediction
            is trusted (see :func:`identify_active_set_lpeca`).
        use_lp: If True, solve the LPEC-A LP to refine multiplier
            estimates before the threshold test.  Requires ``mpax``.
        lp_lambda_bound: Upper bound K_1 on lambda in the LP.
        lp_eps: Absolute and relative tolerance for the LP solver.
        lp_max_iter: Maximum LP solver iterations.

    Returns:
        :class:`LPECAResult` with the predicted active mask and
        diagnostic flags.
    """
    if use_lp:
        lambda_ineq, mu_eq = solve_lpeca_lp(
            c_ineq=c_ineq,
            c_eq=c_eq,
            grad=grad,
            A_ineq=A_ineq,
            A_eq=A_eq,
            lambda_bound=lp_lambda_bound,
            eps_abs=lp_eps,
            eps_rel=lp_eps,
            max_iter=lp_max_iter,
        )

    return identify_active_set_lpeca(
        c_ineq=c_ineq,
        c_eq=c_eq,
        grad=grad,
        A_ineq=A_ineq,
        A_eq=A_eq,
        lambda_ineq=lambda_ineq,
        mu_eq=mu_eq,
        sigma=sigma,
        beta=beta,
        trust_threshold=trust_threshold,
    )