Source code for slsqp_jax.diagnostics.playbook

"""Playbook: termination-code scoping + multi-signal diagnoses.

The plan calls out two distinct pieces of post-fire interpretation:

1. **Termination-code scoping** (:data:`SCOPE_BY_TERMINATION`) — uses
   the granular ``slsqp_jax.RESULTS`` code from the failed run as
   *context* for the rest of the report.  Signals outside the scoped
   set for the run's termination code are still listed (never hidden)
   but visually de-prioritised under "less likely given the
   termination mode".  This is how the tool exploits the failure
   mode the user already arrived with.
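2. **Multi-signal diagnoses** (:data:`RULES`) — when several signals
   fire together they sometimes form a textbook pattern that earns a
   named diagnosis.  Single-signal cases are normally *not* given a
   rule here; they are surfaced directly by the report renderer.  The
   one deliberate exception is ``noise_floor_stationarity_stall``,
   whose recommended fix is not obvious from the signal name alone.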

Confidence ranking lives here too: :func:`magnitude_for` derives the
dynamic ``magnitude`` axis from the signal's evidence dict (ratio to
threshold), and :func:`confidence_for` collapses
``(specificity, magnitude)`` to a single ``low`` / ``medium`` / ``high``
tag via the documented lookup table.

Phase 3 populates :data:`RULES` with three starter multi-signal
patterns; Phase 1 / 2 ship empty so the runner / report renderer
have the same call surface from day one.
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional

from slsqp_jax.results import RESULTS

if TYPE_CHECKING:
    from slsqp_jax.diagnostics.signals import (
        Confidence,
        Magnitude,
        Signal,
        Specificity,
    )


# ---------------------------------------------------------------------------
# Scoping by termination code
# ---------------------------------------------------------------------------


def _enum_value(item: Any) -> Any:
    """Return the integer ``_value`` of an ``equinox.Enumeration`` item.

    Centralised here so the playbook can compare codes by value
    rather than by ``is``-identity (two items with the same value are
    equal but not ``is``-identical, depending on how / when they were
    constructed).  ``equinox.EnumerationItem`` is not hashable, so we
    must key the scope mapping by the integer value rather than the
    member directly.
    """
    raw = getattr(item, "_value", None)
    if raw is None:
        return None
    try:
        return int(raw)
    except (TypeError, ValueError):
        return None


# Signals considered "in scope" given each granular ``slsqp_jax.RESULTS``
# termination code.  Keyed by the integer ``_value`` of the
# corresponding ``RESULTS`` member because ``equinox.EnumerationItem``
# is not hashable.  A signal *outside* the scoped set is still
# listed in the report, but under a "less likely given the
# termination mode" sub-section — never hidden, just de-prioritised.
# An empty set (``RESULTS.nonlinear_max_steps_reached``,
# ``RESULTS.successful``) means "everything in scope" — no
# de-prioritisation applies.

# Built from a table of ``(RESULTS member, in-scope signal names)`` rows;
# each row becomes one ``{int value: set of names}`` entry, in row order.
# An empty row means "everything in scope".
SCOPE_BY_TERMINATION: dict[int, set[str]] = {
    _enum_value(code): set(names)
    for code, names in (
        (
            RESULTS.merit_stagnation,
            (
                "lbfgs_conditioning_extreme",
                "multiplier_recovery_noise",
                "line_search_collapse",
                "merit_oscillation",
            ),
        ),
        (
            RESULTS.line_search_failure,
            (
                "lbfgs_conditioning_extreme",
                "line_search_collapse",
                "eq_jacobian_rank_deficient",
            ),
        ),
        (
            RESULTS.qp_subproblem_failure,
            (
                "qp_budget_or_pingpong",
                "eq_jacobian_rank_deficient",
                "lbfgs_conditioning_extreme",
                "lpeca_overpredicting",
            ),
        ),
        (
            RESULTS.iterate_blowup,
            (
                "lbfgs_conditioning_extreme",
                "merit_oscillation",
                "line_search_collapse",
                "divergence_rollback_triggered",
                "merit_penalty_explosion",
                "penalty_starvation",
            ),
        ),
        (
            RESULTS.infeasible,
            (
                "infeasible_termination",
                "eq_jacobian_rank_deficient",
                # Same divergence-rollback chain that latches
                # ``iterate_blowup`` also routes here when the rolled-back
                # ``best_x`` is itself infeasible at termination.
                "divergence_rollback_triggered",
                "merit_penalty_explosion",
                "penalty_starvation",
                "lbfgs_conditioning_extreme",
                "line_search_collapse",
            ),
        ),
        (RESULTS.nonlinear_max_steps_reached, ()),
        (RESULTS.successful, ()),
    )
}


def _scope_for(termination_code: Any) -> Optional[set[str]]:
    """Fetch the scoped signal set for a termination code, by integer value.

    Returns ``None`` when the code has no usable integer value or when
    :data:`SCOPE_BY_TERMINATION` has no entry for it.
    """
    key = _enum_value(termination_code)
    return None if key is None else SCOPE_BY_TERMINATION.get(key)


def signals_in_scope(termination_code: Any, fired_names: set[str]) -> set[str]:
    """Return the fired signals that are in scope for ``termination_code``.

    Only the *in-scope* subset of ``fired_names`` is returned (as a new
    set).  Every fired signal counts as in-scope when the termination
    code has no entry in :data:`SCOPE_BY_TERMINATION`, or when its entry
    is the empty set — the "everything in scope" sentinel used for
    ``successful`` and ``nonlinear_max_steps_reached``.
    """
    allowed = _scope_for(termination_code)
    fired = set(fired_names)
    if allowed:
        return fired.intersection(allowed)
    # ``None`` (unknown code) and the empty sentinel both mean "no filter".
    return fired


def signals_out_of_scope(termination_code: Any, fired_names: set[str]) -> set[str]:
    """Return the fired signals that :func:`signals_in_scope` leaves out."""
    fired = set(fired_names)
    return fired.difference(signals_in_scope(termination_code, fired_names))


# ---------------------------------------------------------------------------
# Confidence: specificity x magnitude lookup
# ---------------------------------------------------------------------------

# Documented lookup table (see plan / README).  Indexed by
# ``(specificity, magnitude)`` strings.  Centralising it here means
# the report renderer / rule engine never has to re-derive it.
# One row per specificity; the tags in each row line up with the
# magnitude columns ("marginal", "moderate", "extreme") in that order.
_CONFIDENCE_TABLE: dict[tuple[str, str], "Confidence"] = {
    (specificity, magnitude): tag
    for specificity, tags in (
        ("specific", ("medium", "high", "high")),
        ("ambiguous", ("low", "medium", "medium")),
        ("generic", ("low", "low", "low")),
    )
    for magnitude, tag in zip(("marginal", "moderate", "extreme"), tags)
}


def confidence_for(specificity: "Specificity", magnitude: "Magnitude") -> "Confidence":
    """Map a ``(specificity, magnitude)`` pair to its confidence tag.

    The answer comes straight from the module-level lookup table.  A
    pair that is missing from the table yields ``"low"`` instead of
    raising, so new specificity / magnitude vocabulary degrades
    gracefully.
    """
    key = (specificity, magnitude)
    return _CONFIDENCE_TABLE.get(key, "low")


def magnitude_for(ratio: float) -> "Magnitude":
    """Bucket a ratio-to-threshold into the documented magnitude classes.

    ``ratio`` is the largest ratio of any single evidence value to its
    documented threshold.  Conventions:

    * ``< 10``    → ``"marginal"``
    * ``< 100``   → ``"moderate"``
    * otherwise   → ``"extreme"``

    Defaults to ``"marginal"`` when the ratio cannot be converted to a
    float, is non-finite (NaN, ``+inf``, ``-inf``) or is non-positive.
    A non-positive ratio would indicate a programming error in the
    evaluator (the ratio should always be ``>= 1`` whenever the signal
    fires), but we degrade gracefully rather than raise.

    Args:
        ratio: Evidence-to-threshold ratio; anything ``float()`` accepts.

    Returns:
        One of ``"marginal"``, ``"moderate"`` or ``"extreme"``.
    """
    try:
        r = float(ratio)
    except (TypeError, ValueError):
        return "marginal"
    # NaN fails every comparison, so this single range test rejects NaN,
    # non-positive values and +inf alike (the documented "non-finite"
    # fallback; previously +inf slipped through as "extreme").
    if not (0.0 < r < float("inf")):
        return "marginal"
    if r < 10.0:
        return "marginal"
    if r < 100.0:
        return "moderate"
    return "extreme"


# ---------------------------------------------------------------------------
# Multi-signal diagnoses
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class Diagnosis:
    """Named multi-signal diagnosis surfaced in the report.

    Instances are frozen: fields cannot be reassigned after construction.

    Attributes:
        name: Stable machine-readable identifier.
        cause: One-paragraph mechanism explanation.
        suggestions: Concrete starting-point fixes.
        related_signals: Names of the fired signals the rule combined.
    """

    name: str
    cause: str
    # ``default_factory`` gives every instance its own fresh list.
    suggestions: list[str] = field(default_factory=list)
    related_signals: list[str] = field(default_factory=list)
@dataclass(frozen=True)
class Rule:
    """An explicit if-then mapping fired-signal-set → :class:`Diagnosis`.

    Attributes:
        name: Stable machine-readable identifier (mirrors the diagnosis
            ``name`` it produces).
        predicate: Callable ``set[str] -> bool`` that returns ``True``
            iff the rule's signals are all present.
        build: Callable ``list[Signal] -> Diagnosis`` that constructs
            the diagnosis from the (filtered) fired signals.
    """

    name: str
    predicate: Callable[[set[str]], bool]
    build: Callable[[list["Signal"]], Diagnosis]


def _signals_by_name(signals: list["Signal"]) -> dict[str, "Signal"]:
    """Return ``{signal.name: signal}`` for the first occurrence of each name."""
    out: dict[str, "Signal"] = {}
    for s in signals:
        # Keep the first signal seen for a name; later duplicates are ignored.
        if s.name not in out:
            out[s.name] = s
    return out


def _build_stale_lbfgs_curvature(signals: list["Signal"]) -> Diagnosis:
    """Build the "stale L-BFGS curvature poisoning the QP" diagnosis.

    ``related_signals`` lists, in sorted order, whichever of the rule's
    candidate signal names are actually present in ``signals``.
    """
    by_name = _signals_by_name(signals)
    related = sorted(
        n
        for n in (
            "lbfgs_conditioning_extreme",
            "line_search_collapse",
            "merit_oscillation",
        )
        if n in by_name
    )
    return Diagnosis(
        name="stale_lbfgs_curvature",
        cause=(
            "The L-BFGS B0 diagonal has gone extremely ill-conditioned "
            "AND the resulting QP step is failing the line search or "
            "regressing the merit function. These two patterns "
            "together typically mean the curvature pairs in the L-BFGS "
            "history are stale: the gradient differences ``y_k`` no "
            "longer reflect the local Hessian at the current iterate. "
            "Soft-then-identity resets in the recovery chain may be "
            "discarding the only useful pair on each cycle "
            "(see ``AGENTS.md`` 'L-BFGS reset strategies')."
        ),
        suggestions=[
            "Try ``use_exact_hvp_in_qp=True`` so the QP no longer "
            "depends on the L-BFGS approximation. The Newton-CG "
            "mode pays one HVP per CG iteration in exchange for "
            "complete decoupling from L-BFGS quality.",
            "If the problem has equality constraints, try "
            "``inner_solver=MinresQLPSolver()`` which solves the "
            "saddle-point KKT system directly and is less sensitive "
            "to L-BFGS corruption.",
            "Inspect the per-step ``lbfgs_skipped`` flags in the "
            "trajectory: a long unbroken streak of ``True`` is the "
            "signature of the post-identity-reset skip lock.",
        ],
        related_signals=related,
    )


def _build_active_set_churn(signals: list["Signal"]) -> Diagnosis:
    """Build the "active-set churn from rank-deficient working Jacobian" diagnosis.

    ``related_signals`` lists, in sorted order, whichever of the rule's
    candidate signal names are actually present in ``signals``.
    """
    by_name = _signals_by_name(signals)
    related = sorted(
        n
        for n in (
            "eq_jacobian_rank_deficient",
            "qp_budget_or_pingpong",
        )
        if n in by_name
    )
    return Diagnosis(
        name="active_set_churn",
        cause=(
            "The equality Jacobian is near rank-deficient AND the QP "
            "active-set loop is exhausting its budget or ping-ponging "
            "on the same constraint pair. When ``J_eq`` is rank-"
            "deficient at the current iterate, the working-set "
            "Jacobian inherits the deficiency and the active-set "
            "logic cycles trying to add and drop the same constraints. "
            "The ping-pong detector + ``mult_drop_floor`` are designed "
            "to short-circuit this, but a chronic rank deficiency "
            "needs a structural fix."
        ),
        suggestions=[
            "Verify the equality constraints are linearly independent "
            "at the iterate (LICQ). If the constraints are "
            "syntactically distinct but algebraically degenerate, "
            "drop the redundant rows.",
            "Switch to ``inner_solver=MinresQLPSolver()`` which "
            "handles indefinite/singular saddle-point systems "
            "natively (no need for null-space projection).",
            "Raise ``mult_drop_floor`` so noise-flipped multipliers "
            "do not spuriously drop active constraints.",
        ],
        related_signals=related,
    )


def _build_penalty_starvation_cascade(signals: list["Signal"]) -> Diagnosis:
    """Build the "penalty-starved feasibility drift cascade" diagnosis.

    ``related_signals`` lists, in sorted order, whichever of the rule's
    candidate signal names are actually present in ``signals``.
    """
    by_name = _signals_by_name(signals)
    related = sorted(
        n
        for n in (
            "penalty_starvation",
            "merit_penalty_explosion",
            "divergence_rollback_triggered",
            "lbfgs_conditioning_extreme",
            "line_search_collapse",
        )
        if n in by_name
    )
    return Diagnosis(
        name="penalty_starvation_cascade",
        cause=(
            "The merit penalty ``rho`` stayed at its initial value "
            "while feasibility drifted (the Han-Powell directional-"
            "derivative test cannot trigger a ``rho`` increase when "
            "the ``f`` reduction alone satisfies it). Once "
            "feasibility decayed enough to demand a real correction, "
            "the merit-penalty update over-corrected by several "
            "orders of magnitude in a single step; that catch-up "
            "poisons subsequent QP directions and the line search "
            "collapses to the LS floor. This cascade typically "
            "ends with the best-iterate divergence rollback firing, "
            "leaving the user with an iterate slightly worse than "
            "any one the run actually visited. Started feasible "
            "but ended infeasible is the canonical signature."
        ),
        suggestions=[
            "Perturb the initial iterate slightly off-feasibility "
            "(e.g. ``x0 + atol * sign_vector``) so the merit "
            "penalty update mechanism warms ``rho`` up before the "
            "feasibility drift accumulates.",
            "Bump the initial ``merit_penalty`` so the feasibility "
            "term has weight from step 1 and the directional-"
            "derivative test cannot ignore constraint violation.",
            "Try ``use_exact_hvp_in_qp=True`` so the QP step "
            "stays clean even if the merit penalty over-corrects "
            "later — this breaks the L-BFGS poisoning leg of the "
            "cascade.",
        ],
        related_signals=related,
    )


def _build_noise_floor_stall(signals: list["Signal"]) -> Diagnosis:
    """Build the "noise-floor stationarity stall" diagnosis.

    ``signals`` is accepted for signature parity with the other builders
    but is not inspected; ``related_signals`` is fixed.
    """
    return Diagnosis(
        name="noise_floor_stationarity_stall",
        cause=(
            "The classical Lagrangian-gradient stationarity test "
            "stalled above ``rtol`` while the inner solver's "
            "projected-gradient norm already passed it. This is "
            "the textbook signature of multiplier-recovery noise "
            "contaminating ``∇L = ∇f − A^T λ`` (see ``AGENTS.md`` "
            "'Inexact stationarity disjunct'): the recovered ``λ`` "
            "carries ``O(eps · cond(A A^T))`` error and that error "
            "swamps the true stationarity residual at high "
            "precision."
        ),
        suggestions=[
            "Set ``inner_solver=HRInexactSTCG(inner=ProjectedCGCholesky())`` "
            "so the noise-aware projected gradient is computed "
            "natively.",
            "Set ``use_inexact_stationarity=True`` so the "
            "convergence test admits the projected-gradient "
            "disjunct. The two together rescue this exact pattern.",
            "If you cannot change the inner solver, loosening "
            "``rtol`` so it sits above the multiplier-recovery noise "
            "floor is a valid (if less satisfying) workaround.",
        ],
        related_signals=["multiplier_recovery_noise"],
    )


# Single-signal cases do *not* earn a rule — the report renderer
# surfaces them directly from the lone fired signal.  The exception
# is "noise-floor stationarity stall" which warrants its own
# diagnosis because the recommended fix is concrete and non-obvious
# from the signal name alone.
RULES: list[Rule] = [
    Rule(
        name="stale_lbfgs_curvature",
        predicate=lambda fired: (
            "lbfgs_conditioning_extreme" in fired
            and ("line_search_collapse" in fired or "merit_oscillation" in fired)
        ),
        build=_build_stale_lbfgs_curvature,
    ),
    Rule(
        name="active_set_churn",
        predicate=lambda fired: (
            "eq_jacobian_rank_deficient" in fired and "qp_budget_or_pingpong" in fired
        ),
        build=_build_active_set_churn,
    ),
    Rule(
        name="noise_floor_stationarity_stall",
        predicate=lambda fired: "multiplier_recovery_noise" in fired,
        build=_build_noise_floor_stall,
    ),
    Rule(
        name="penalty_starvation_cascade",
        predicate=lambda fired: (
            "penalty_starvation" in fired and "merit_penalty_explosion" in fired
        ),
        build=_build_penalty_starvation_cascade,
    ),
]


def evaluate_diagnoses(signals: list["Signal"]) -> list[Diagnosis]:
    """Run every registered :class:`Rule` against ``signals``.

    Returns the diagnoses produced by the rules whose predicate fired,
    in registration order.  No deduplication or ranking happens here;
    the report renderer is responsible for sorting by confidence and
    applying the termination-code scope filter.

    Args:
        signals: The fired signals; only each signal's ``name`` is used
            to decide which rules match.

    Returns:
        One :class:`Diagnosis` per matching rule (possibly empty).
    """
    fired = {s.name for s in signals}
    return [rule.build(signals) for rule in RULES if rule.predicate(fired)]


# Re-export the granular result class for convenience: callers writing
# their own scope filters / rules can compare against
# ``RESULTS.merit_stagnation`` etc. without a second import.
__all__ = [
    "Diagnosis",
    "RESULTS",
    "RULES",
    "Rule",
    "SCOPE_BY_TERMINATION",
    "confidence_for",
    "evaluate_diagnoses",
    "magnitude_for",
    "signals_in_scope",
    "signals_out_of_scope",
]