Source code for slsqp_jax.diagnostics.records

"""Per-step scalar summaries and aggregate run records.

The diagnostics layer never stores per-step ``SLSQPState`` (or any
device-resident array) because doing so would add hundreds of MB of
memory for large ``n``.  Instead it records a tiny scalar
:class:`StepSummary` per iteration plus the *final* ``SLSQPState`` in
full.  All signal evaluators that need to inspect a non-final iterate
either compute their evidence at the moment the signal fires (when
the live state is still in scope) or trigger a re-run via
:func:`slsqp_jax.diagnostics.runner.capture_state_at_step`.

The fields on :class:`StepSummary` are deliberately host-resident
plain Python scalars so the diagnostics loop can branch on them
without forcing a device sync per access.  ``StepSummary.from_state``
performs the (single) device → host transfer per step.
"""

from __future__ import annotations

import dataclasses
import hashlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional

import jax.numpy as jnp

from slsqp_jax.merit import compute_merit
from slsqp_jax.results import RESULTS
from slsqp_jax.slsqp.termination import compute_mu_max

if TYPE_CHECKING:
    from slsqp_jax.slsqp import SLSQP
    from slsqp_jax.state import SLSQPDiagnostics, SLSQPState

# Number of decimal digits the reproducibility hash retains.  fp64
# arithmetic across two runs of an identical computation should agree
# bit-for-bit on the same hardware/jax-config; rounding to 12 digits
# absorbs the rare last-bit difference (e.g. when a fused-multiply-add
# is reordered) without masking genuine divergence.
_HASH_PRECISION_DECIMALS = 12
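The rounding-before-hashing idea in the comment above can be sketched with the standard library alone. This is an illustrative sketch, not the module's implementation: `TinySummary`, `PRECISION`, and `digest` are hypothetical names, and the two-field dataclass only stands in for the real summary record.

```python
import dataclasses
import hashlib
from dataclasses import dataclass

PRECISION = 12  # assumed to mirror ``_HASH_PRECISION_DECIMALS``


@dataclass(frozen=True)
class TinySummary:
    """Hypothetical two-field stand-in for a per-step summary."""

    step_count: int
    f_val: float


def digest(summary) -> str:
    """Hash every field, rounding finite floats to ``PRECISION`` decimals."""
    parts = []
    for f in dataclasses.fields(summary):
        val = getattr(summary, f.name)
        # NaN (val != val) and +/-inf cannot be rounded meaningfully,
        # so they pass through unchanged.
        if isinstance(val, float) and val == val and abs(val) != float("inf"):
            val = round(val, PRECISION)
        parts.append(f"{f.name}={val!r}")
    return hashlib.sha256("|".join(parts).encode("utf-8")).hexdigest()[:16]


a = TinySummary(step_count=3, f_val=1.0)
b = TinySummary(step_count=3, f_val=1.0 + 1e-15)  # last-bit noise
c = TinySummary(step_count=3, f_val=1.0 + 1e-9)   # genuine divergence
```

Here `digest(a)` and `digest(b)` agree because the perturbation vanishes under rounding, while `digest(c)` differs. Values that straddle a rounding boundary can still hash differently, so a digest mismatch is a prompt to inspect the run rather than proof of nondeterminism.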


@dataclass(frozen=True)
class StepSummary:
    """Host-resident scalar summary of a single SLSQP iteration.

    Every field is a plain Python ``int`` / ``float`` / ``bool`` so the
    diagnostics loop can branch on it without an extra device sync.  The
    only device → host transfer happens once inside
    :meth:`StepSummary.from_state`.

    Attributes:
        step_count: Iteration number after this step (1-indexed).
        f_val: Objective value at the post-step iterate.
        merit: L1 merit value at the post-step iterate.
        last_alpha: Line-search step size accepted at this step.
        qp_iterations_total: Cumulative QP active-set iterations across
            all steps so far (matches ``state.qp_iterations``).
        qp_iterations_step: QP active-set iterations consumed by *this*
            step alone (delta of ``qp_iterations_total``).
        qp_converged: Whether the QP solver reported success this step.
        qp_real_failure: ``True`` iff the QP did not converge AND did
            not exhaust its iteration budget (the "real failure"
            distinction the L-BFGS reset chain uses; see ``AGENTS.md``
            "Real-vs-budget QP failure distinction").
        qp_reached_max_iter: Whether the QP active-set loop exhausted
            ``qp_max_iter`` this step (delta of
            ``diagnostics.n_qp_budget_exhausted``).
        qp_ping_ponged: Whether the QP ping-pong short-circuit fired
            this step (delta of ``diagnostics.n_qp_ping_pong``).
        ls_success: Whether the line search reported success.
        consecutive_qp_failures: Outer-loop failure-streak counter.
        consecutive_ls_failures: Outer-loop failure-streak counter.
        consecutive_zero_steps: Zero-step convergence-detection counter.
        grad_norm: Euclidean norm of the objective gradient at the
            post-step iterate.
        grad_lagrangian_norm: Euclidean norm of the Lagrangian gradient
            at the post-step iterate (used by the classical
            stationarity criterion).
        lagrangian_value: Value of the Lagrangian
            ``L = f - lambda_eq^T c_eq - mu_ineq^T c_ineq`` at the
            post-step iterate.
        rel_kkt: ``grad_lagrangian_norm / max(|L|, 1)``, the legacy
            relative-stationarity proxy.  Surfaced for
            backward-compatible diagnostics; note that the live
            convergence check now compares
            ``grad_lagrangian_norm / max(mu_max, 1)`` against ``rtol``
            (filterSQP eqs. 5–6, see
            :func:`slsqp_jax.slsqp.termination.compute_mu_max`), so
            ``rel_kkt`` and the actual termination criterion may differ
            when ``|L|`` and ``mu_max`` disagree.
        kkt_scale: filterSQP ``mu_max`` denominator at this step.
            Because ``state.ineq_jac`` already contains the general
            inequality rows plus the unit-norm bound rows, this summary
            computes the same maximum by treating the whole inequality
            block as row-norm-scaled terms.
        kkt_ratio: Active relative-stationarity ratio used by the
            convergence test:
            ``grad_lagrangian_norm / max(kkt_scale, 1)``.
        gamma: Scalar L-BFGS initial-Hessian scaling.
        min_diag: Minimum entry of the L-BFGS per-variable diagonal.
        max_diag: Maximum entry of the L-BFGS per-variable diagonal.
        diag_kappa: ``max_diag / max(min_diag, 1e-30)``, the L-BFGS
            initial-Hessian condition-number proxy.
        lbfgs_count: Number of curvature pairs currently stored in the
            ring buffer.
        lbfgs_skipped: Whether the L-BFGS append skipped its curvature
            pair this step (delta of ``diagnostics.n_lbfgs_skips``).
        max_abs_mult_eq: ``max(|lambda_eq_ls|)`` at the post-step
            iterate (the stationarity-quality multiplier; the same
            vector exposed via ``Solution.stats["multipliers_eq"]``).
        max_abs_mult_ineq: ``max(|mu_ineq_ls|)`` at the post-step
            iterate.
        qp_vs_ls_multiplier_ratio: Maximum
            ``|λ_qp[i]| / max(|λ_ls[i]|, eps)`` across active equality
            + general-inequality rows.  Surfaces "the QP multiplier was
            ``N x`` larger than the stationarity multiplier this step",
            which is the textbook signature of QP-multiplier inflation
            (poor L-BFGS conditioning + auto-scaling).  Always ``1.0``
            on the very first step where both are seeded from the same
            ``lstsq`` initialisation; settles around ``1.0`` near a
            clean KKT point and grows on noisy / ill-conditioned
            iterates.
        n_active_ineq: Number of active inequality constraints
            (general + bounds) at the post-step iterate.
        eq_jac_min_sv_est: Lower bound on the smallest singular value
            of ``J_eq`` estimated from the Cholesky of ``J_eq J_eq^T``
            (cumulative low-water mark from
            ``diagnostics.eq_jac_min_sv_est``).
        projected_grad_norm: Norm of the inner solver's projected
            gradient (``HRInexactSTCG`` only; ``inf`` otherwise).
        merit_penalty: L1 merit penalty parameter ``rho`` after this
            step.
        max_eq_violation: ``max|c_eq|`` at the post-step iterate.
        max_ineq_violation: ``max(0, -c_ineq)`` at the post-step
            iterate.
        proj_residual_high_water: Cumulative high-water mark of the
            ``MinresQLPSolver`` M-metric projection residual.  Always
            ``0.0`` for null-space solvers.
        diverging: Whether the best-iterate divergence rollback fired
            this step.
        blowup_count: Consecutive blowup events at this step.
        merit_regression_step: Whether ``compute_merit`` exceeded the
            best merit despite the line search reporting success
            (delta of ``diagnostics.n_merit_regressions``).
    """

    step_count: int
    f_val: float
    merit: float
    last_alpha: float
    qp_iterations_total: int
    qp_iterations_step: int
    qp_converged: bool
    qp_real_failure: bool
    qp_reached_max_iter: bool
    qp_ping_ponged: bool
    ls_success: bool
    consecutive_qp_failures: int
    consecutive_ls_failures: int
    consecutive_zero_steps: int
    grad_norm: float
    grad_lagrangian_norm: float
    lagrangian_value: float
    rel_kkt: float
    gamma: float
    min_diag: float
    max_diag: float
    diag_kappa: float
    lbfgs_count: int
    lbfgs_skipped: bool
    max_abs_mult_eq: float
    max_abs_mult_ineq: float
    qp_vs_ls_multiplier_ratio: float
    n_active_ineq: int
    eq_jac_min_sv_est: float
    projected_grad_norm: float
    merit_penalty: float
    max_eq_violation: float
    max_ineq_violation: float
    proj_residual_high_water: float
    diverging: bool
    blowup_count: int
    merit_regression_step: bool
    kkt_scale: float = 0.0
    kkt_ratio: float = float("inf")

    @classmethod
    def from_state(
        cls,
        state: "SLSQPState",
        *,
        prev_state: Optional["SLSQPState"] = None,
    ) -> "StepSummary":
        """Materialise a :class:`StepSummary` from a live ``SLSQPState``.

        This is the *single* device → host transfer per debug-loop
        iteration.  Everything the runner inspects for control flow
        below this point reads from the returned :class:`StepSummary`.

        ``prev_state`` is the ``SLSQPState`` from the previous iteration
        (or the initial-state from ``solver.init`` for step 1).  It is
        used solely to compute single-step deltas of the cumulative
        diagnostics counters (``n_qp_budget_exhausted``,
        ``n_qp_ping_pong``, ``n_lbfgs_skips``, ``n_merit_regressions``)
        and ``state.qp_iterations``.  Passing ``None`` treats every
        cumulative counter as starting from 0, which matches the
        behaviour of ``solver.init`` (which produces a state with
        all-zero diagnostics).
        """
        diag: SLSQPDiagnostics = state.diagnostics
        prev_diag: Optional[SLSQPDiagnostics] = (
            prev_state.diagnostics if prev_state is not None else None
        )

        f_val = float(state.f_val)
        eq_val = state.eq_val
        ineq_val = state.ineq_val
        merit_penalty = float(state.merit_penalty)
        merit_arr = compute_merit(state.f_val, eq_val, ineq_val, state.merit_penalty)
        merit = float(merit_arr)

        m_eq = int(eq_val.shape[0])
        m_ineq = int(ineq_val.shape[0])
        max_eq_violation = float(jnp.max(jnp.abs(eq_val))) if m_eq > 0 else 0.0
        max_ineq_violation = (
            float(jnp.max(jnp.maximum(0.0, -ineq_val))) if m_ineq > 0 else 0.0
        )

        lagrangian_value = f_val
        if m_eq > 0:
            lagrangian_value -= float(jnp.dot(state.multipliers_eq_ls, eq_val))
        if m_ineq > 0:
            lagrangian_value -= float(jnp.dot(state.multipliers_ineq_ls, ineq_val))

        grad_lagrangian_norm = float(jnp.linalg.norm(state.grad_lagrangian))
        rel_kkt = grad_lagrangian_norm / max(abs(lagrangian_value), 1.0)
        kkt_scale = float(
            compute_mu_max(
                grad_f=state.grad,
                eq_jac=state.eq_jac,
                ineq_jac_general=state.ineq_jac,
                mult_eq=state.multipliers_eq_ls,
                mult_ineq_general=state.multipliers_ineq_ls,
                mult_bound=jnp.zeros((0,), dtype=state.grad.dtype),
            )
        )
        kkt_ratio = grad_lagrangian_norm / max(kkt_scale, 1.0)

        diagonal = state.lbfgs_history.diagonal
        min_diag = float(jnp.min(diagonal))
        max_diag = float(jnp.max(diagonal))
        diag_kappa = max_diag / max(min_diag, 1e-30)

        max_abs_mult_eq = (
            float(jnp.max(jnp.abs(state.multipliers_eq_ls))) if m_eq > 0 else 0.0
        )
        max_abs_mult_ineq = (
            float(jnp.max(jnp.abs(state.multipliers_ineq_ls))) if m_ineq > 0 else 0.0
        )

        # qp_vs_ls_multiplier_ratio: max ratio across all eq +
        # general-inequality rows (the bound block agrees identically
        # between _qp and _ls so it does not contribute).  Computed in
        # one pass over the concatenated multiplier vectors with a
        # numerical floor so a near-zero LS multiplier does not
        # produce a spurious infinity.
        ratio_eps = 1e-30
        if m_eq > 0 or m_ineq > 0:
            mult_qp_concat_parts = []
            mult_ls_concat_parts = []
            if m_eq > 0:
                mult_qp_concat_parts.append(state.multipliers_eq_qp)
                mult_ls_concat_parts.append(state.multipliers_eq_ls)
            if m_ineq > 0:
                mult_qp_concat_parts.append(state.multipliers_ineq_qp)
                mult_ls_concat_parts.append(state.multipliers_ineq_ls)
            mult_qp_concat = jnp.concatenate(mult_qp_concat_parts, axis=0)
            mult_ls_concat = jnp.concatenate(mult_ls_concat_parts, axis=0)
            denom = jnp.maximum(jnp.abs(mult_ls_concat), ratio_eps)
            qp_vs_ls_multiplier_ratio = float(jnp.max(jnp.abs(mult_qp_concat) / denom))
        else:
            qp_vs_ls_multiplier_ratio = 1.0

        n_active_ineq = int(jnp.sum(state.prev_active_set.astype(jnp.int32)))

        qp_iterations_total = int(state.qp_iterations)
        prev_qp_iters = int(prev_state.qp_iterations) if prev_state is not None else 0
        qp_iterations_step = qp_iterations_total - prev_qp_iters

        prev_n_budget = int(prev_diag.n_qp_budget_exhausted) if prev_diag else 0
        qp_reached_max_iter = (int(diag.n_qp_budget_exhausted) - prev_n_budget) > 0
        prev_n_pingpong = int(prev_diag.n_qp_ping_pong) if prev_diag else 0
        qp_ping_ponged = (int(diag.n_qp_ping_pong) - prev_n_pingpong) > 0
        prev_n_lbfgs_skips = int(prev_diag.n_lbfgs_skips) if prev_diag else 0
        lbfgs_skipped = (int(diag.n_lbfgs_skips) - prev_n_lbfgs_skips) > 0
        prev_n_merit_regressions = (
            int(prev_diag.n_merit_regressions) if prev_diag else 0
        )
        merit_regression_step = (
            int(diag.n_merit_regressions) - prev_n_merit_regressions
        ) > 0

        qp_converged = bool(state.qp_converged)
        qp_real_failure = (not qp_converged) and (not qp_reached_max_iter)

        return cls(
            step_count=int(state.step_count),
            f_val=f_val,
            merit=merit,
            last_alpha=float(state.last_alpha),
            qp_iterations_total=qp_iterations_total,
            qp_iterations_step=qp_iterations_step,
            qp_converged=qp_converged,
            qp_real_failure=qp_real_failure,
            qp_reached_max_iter=qp_reached_max_iter,
            qp_ping_ponged=qp_ping_ponged,
            ls_success=bool(state.ls_success),
            consecutive_qp_failures=int(state.consecutive_qp_failures),
            consecutive_ls_failures=int(state.consecutive_ls_failures),
            consecutive_zero_steps=int(state.consecutive_zero_steps),
            grad_norm=float(jnp.linalg.norm(state.grad)),
            grad_lagrangian_norm=grad_lagrangian_norm,
            lagrangian_value=lagrangian_value,
            rel_kkt=rel_kkt,
            kkt_scale=kkt_scale,
            kkt_ratio=kkt_ratio,
            gamma=float(state.lbfgs_history.gamma),
            min_diag=min_diag,
            max_diag=max_diag,
            diag_kappa=diag_kappa,
            lbfgs_count=int(state.lbfgs_history.count),
            lbfgs_skipped=lbfgs_skipped,
            max_abs_mult_eq=max_abs_mult_eq,
            max_abs_mult_ineq=max_abs_mult_ineq,
            qp_vs_ls_multiplier_ratio=qp_vs_ls_multiplier_ratio,
            n_active_ineq=n_active_ineq,
            eq_jac_min_sv_est=float(diag.eq_jac_min_sv_est),
            projected_grad_norm=float(state.last_projected_grad_norm),
            merit_penalty=merit_penalty,
            max_eq_violation=max_eq_violation,
            max_ineq_violation=max_ineq_violation,
            proj_residual_high_water=float(diag.max_proj_residual),
            diverging=bool(state.diverging),
            blowup_count=int(state.blowup_count),
            merit_regression_step=merit_regression_step,
        )

    def reproducibility_digest(self) -> str:
        """Return a short hex digest used by ``capture_state_at_step``.

        Two ``StepSummary`` instances produced by independent runs of
        the same ``(solver, x0, args)`` should hash to the same digest
        on the same hardware / ``jax.config`` settings.  Floats are
        rounded to ``_HASH_PRECISION_DECIMALS`` decimals before hashing
        so a last-bit difference (e.g. from a fused-multiply-add
        reorder) does not false-positive on nondeterminism.
        """
        parts: list[str] = []
        for f in dataclasses.fields(self):
            val = getattr(self, f.name)
            if isinstance(val, float):
                if val != val or val in (float("inf"), float("-inf")):
                    parts.append(f"{f.name}={val!r}")
                else:
                    parts.append(f"{f.name}={round(val, _HASH_PRECISION_DECIMALS)!r}")
            else:
                parts.append(f"{f.name}={val!r}")
        joined = "|".join(parts).encode("utf-8")
        return hashlib.sha256(joined).hexdigest()[:16]


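The per-step flags that `StepSummary.from_state` derives from cumulative counters, and the real-versus-budget QP failure distinction built on them, can be sketched in plain Python. This is a minimal sketch under stated assumptions: `step_flags` and its parameter names are hypothetical, and plain integers stand in for the device-resident counters.

```python
from typing import Optional, Tuple


def step_flags(
    n_budget_now: int,
    n_budget_prev: Optional[int],
    qp_converged: bool,
) -> Tuple[bool, bool]:
    """Return ``(qp_reached_max_iter, qp_real_failure)`` for one step.

    ``n_budget_now`` / ``n_budget_prev`` play the role of the cumulative
    ``n_qp_budget_exhausted`` counter after and before the step.
    """
    # No previous state: treat the cumulative counter as starting at 0.
    prev = 0 if n_budget_prev is None else n_budget_prev
    # The counter only grows, so a positive delta means the event
    # happened during this step.
    reached_max_iter = (n_budget_now - prev) > 0
    # A "real" failure is a non-converged QP that did NOT simply run
    # out of iterations.
    real_failure = (not qp_converged) and (not reached_max_iter)
    return reached_max_iter, real_failure


first_step = step_flags(1, None, qp_converged=False)
later_step = step_flags(2, 2, qp_converged=False)
clean_step = step_flags(2, 2, qp_converged=True)
```

In this sketch `first_step` is a budget exhaustion, `later_step` is a real failure, and `clean_step` raises neither flag.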
@dataclass
class DebugRunResult:
    """Aggregate result of a manual-loop debug run.

    Carries everything a downstream signal evaluator or report renderer
    needs to do its job, plus the handles
    ``(solver, fn, x0, args, has_aux)`` required to re-execute the run
    for ad-hoc inspection via
    :func:`slsqp_jax.diagnostics.runner.capture_state_at_step`.

    Attributes:
        solver: The :class:`slsqp_jax.SLSQP` instance the run used.
        fn: The objective callable passed to ``debug_run``.
        x0: The initial iterate.
        args: Extra positional payload threaded through ``fn`` /
            constraint callables.
        has_aux: Whether ``fn`` returns ``(value, aux)``.
        final_state: The terminal ``SLSQPState`` after the manual loop
            exited.  Carries the full ``SLSQPDiagnostics`` accumulator
            and every device array the end-of-run evaluators need.
        final_y: The terminal iterate ``y`` returned by ``step()``.
        final_result: The granular ``slsqp_jax.RESULTS`` termination
            code stored on ``final_state.termination_code``.
        coarse_result: The coarse ``optimistix.RESULTS`` code that
            ``terminate()`` returned (always a member of the parent
            enum, suitable for the optimistix driver-style check).
        summaries: Per-step :class:`StepSummary` records, one per
            iteration the loop actually executed.
        terminated_at_step: Index of the iteration where the loop
            exited (0-based; ``len(summaries) - 1`` on success).
        max_steps_reached: ``True`` iff the loop ran out of iterations
            without ``terminate()`` returning ``done``.
        fired_signals: Signals emitted by the per-step + end-of-run
            evaluators.  Empty list when the diagnostics layer is run
            without signals (Phase 1 default).
    """

    solver: "SLSQP"
    fn: Any
    x0: Any
    args: Any
    has_aux: bool
    final_state: "SLSQPState"
    final_y: Any
    final_result: Any
    coarse_result: Any
    summaries: list[StepSummary]
    terminated_at_step: int
    max_steps_reached: bool
    fired_signals: list[Any] = field(default_factory=list)

    @property
    def diagnostics(self) -> "SLSQPDiagnostics":
        """Convenience accessor for ``final_state.diagnostics``."""
        return self.final_state.diagnostics

    @property
    def n_steps(self) -> int:
        """Number of iterations the loop actually executed."""
        return len(self.summaries)

    @property
    def terminated_successfully(self) -> bool:
        """``True`` iff the run actually converged.

        Both conditions must hold: the granular ``final_result`` is
        :attr:`RESULTS.successful` *and* the manual loop did not run out
        of its iteration budget without ``terminate()`` returning
        ``done=True``.  The second clause matters because
        :attr:`RESULTS.successful` is the *default* termination_code on
        a fresh ``SLSQPState`` (it means "no failure has latched yet"),
        so a truncated run with no failure flag set will report
        ``successful`` even though it never actually converged.
        """
        if self.max_steps_reached:
            return False
        try:
            promoted = (
                self.final_result
                if isinstance(self.final_result, RESULTS)
                else RESULTS.promote(self.final_result)
            )
            return bool(promoted == RESULTS.successful)
        except Exception:  # pragma: no cover -- defensive
            return False


__all__ = [
    "StepSummary",
    "DebugRunResult",
]
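The reasoning behind `DebugRunResult.terminated_successfully` (a truncated run still carries the default "successful" code, so the budget flag is checked first) can be sketched as follows. `Code`, its `line_search_failure` member, and `converged` are hypothetical stand-ins for illustration only; they are not the library's `RESULTS` enum or API.

```python
from enum import Enum


class Code(Enum):
    """Hypothetical stand-in for a granular termination-code enum."""

    successful = 0           # also the default on a fresh state
    line_search_failure = 1  # illustrative failure code


def converged(final_result: Code, max_steps_reached: bool) -> bool:
    """Report convergence only when the loop was not truncated."""
    # A run that hit its iteration budget never latched a failure,
    # so its code is still the default ``successful``; reject it here.
    if max_steps_reached:
        return False
    return final_result == Code.successful


truncated = converged(Code.successful, max_steps_reached=True)
finished = converged(Code.successful, max_steps_reached=False)
failed = converged(Code.line_search_failure, max_steps_reached=False)
```

Only `finished` evaluates to `True`; checking the termination code alone would also have accepted the truncated run.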