Source code for slsqp_jax.qp.inequality

"""Inequality-only QP strategy.

Used when ``m_eq == 0``: there is no equality block to absorb / project,
just an active-set loop on inequality constraints.
"""

from __future__ import annotations

from collections.abc import Callable

import jax.numpy as jnp
from jaxtyping import Array, Bool, Float

from slsqp_jax.inner.base import AbstractInnerSolver
from slsqp_jax.inner.krylov import solve_unconstrained_cg
from slsqp_jax.qp._inner_check import inner_ok
from slsqp_jax.qp.active_set import ActiveSetInnerResult, run_active_set_loop
from slsqp_jax.state import QPSolverResult, QPState
from slsqp_jax.types import Scalar, Vector


[docs] def solve_qp_inequality( hvp_fn: Callable[[Vector], Vector], g: Vector, A_ineq: Float[Array, "m_ineq n"], b_ineq: Float[Array, " m_ineq"], m_ineq: int, max_iter: int, max_cg_iter: int, tol: Scalar | float, expand_factor: float, initial_active_set: Bool[Array, " m_ineq"] | None, kkt_residual: Scalar | float, inner_solver: AbstractInnerSolver, precond_fn: Callable[[Vector], Vector] | None = None, cg_tol: Scalar | float | None = None, cg_regularization: float = 1e-6, predicted_active_set: Bool[Array, " m_ineq"] | None = None, use_expand: bool = True, mult_drop_floor: float = 1e-6, ping_pong_threshold: int = 2**31 - 1, ) -> QPSolverResult: """Solve the QP with inequality constraints only (no equality block).""" inner_cg_tol: Scalar | float = cg_tol if cg_tol is not None else tol # Initial unconstrained solve. d_init, _ = solve_unconstrained_cg( hvp_fn, g, max_cg_iter, inner_cg_tol, precond_fn=precond_fn, cg_regularization=cg_regularization, ) init_inner_failure = ~jnp.isfinite(d_init).all() kkt_res = jnp.asarray(kkt_residual, dtype=jnp.float64) base_tol = tol + jnp.minimum(kkt_res, 1.0) * tol residuals_init = A_ineq @ d_init - b_ineq if predicted_active_set is not None: init_active = predicted_active_set | (residuals_init < -base_tol) elif initial_active_set is not None: init_active = initial_active_set | (residuals_init < -base_tol) else: init_active = residuals_init < -base_tol init_converged = ~jnp.any(init_active) init_state = QPState( # ty: ignore[invalid-return-type] d=d_init, active_set=init_active, multipliers_eq=jnp.zeros((0,)), multipliers_ineq=jnp.zeros((m_ineq,)), iteration=jnp.array(0), converged=init_converged, any_inner_failure=init_inner_failure, last_add_idx=jnp.array(-1), last_drop_idx=jnp.array(-1), ping_pong_count=jnp.array(0), ping_ponged=jnp.array(False), proj_residual=jnp.asarray(0.0, dtype=jnp.float64), n_proj_refinements=jnp.asarray(0), projected_grad_norm=jnp.asarray(jnp.inf, dtype=jnp.float64), ) tau = base_tol * expand_factor / jnp.maximum(max_iter, 1) effective_tau = tau if use_expand else 0.0 drop_floor = jnp.asarray(mult_drop_floor, dtype=jnp.float64) empty_eq = jnp.zeros((0,)) def _inner_solve(state: QPState, _working_tol: Scalar) -> ActiveSetInnerResult: result = inner_solver.solve( hvp_fn, g, A_ineq, b_ineq, state.active_set, precond_fn=precond_fn, adaptive_tol=cg_tol, ) return ActiveSetInnerResult( d=result.d, multipliers_eq=empty_eq, multipliers_ineq=result.multipliers, inner_failed=~inner_ok(result), proj_residual=result.proj_residual.astype(jnp.float64), n_proj_refinements=result.n_proj_refinements, projected_grad_norm=result.projected_grad_norm.astype(jnp.float64), ) final_state = run_active_set_loop( init_state=init_state, # ty: ignore[invalid-argument-type] inner_solve_fn=_inner_solve, A_ineq=A_ineq, b_ineq=b_ineq, max_iter=max_iter, base_tol=base_tol, effective_tau=effective_tau, drop_floor=drop_floor, ping_pong_threshold=ping_pong_threshold, ) reached_max_iter = final_state.iteration >= max_iter final_converged = final_state.converged & ~reached_max_iter final_working_tol = base_tol + final_state.iteration * effective_tau return QPSolverResult( d=final_state.d, multipliers_eq=final_state.multipliers_eq, multipliers_ineq=final_state.multipliers_ineq, active_set=final_state.active_set, converged=final_converged, iterations=final_state.iteration, ping_ponged=final_state.ping_ponged, reached_max_iter=reached_max_iter, final_working_tol=jnp.asarray(final_working_tol, dtype=jnp.float64), proj_residual=final_state.proj_residual, n_proj_refinements=final_state.n_proj_refinements, projected_grad_norm=final_state.projected_grad_norm, )
__all__ = ["solve_qp_inequality"]