# Source code for slsqp_jax.diagnostics

"""Post-mortem diagnostics for SLSQP-JAX runs.

This sub-package layers a Python-driven debug runner, a per-step
scalar trajectory recorder, an artifact-eager signal pipeline, and a
small explicit playbook on top of the existing :class:`SLSQP` API.
Nothing here lives on the hot path: the additions are computed
post-hoc from existing state fields and never reached by
``step()`` / ``terminate()`` / verbose output.

Public entry points:

- :func:`debug_run` — re-execute a solver under a manual loop and
  return a :class:`DebugRunResult`.
- :func:`diagnose` — convenience wrapper that runs ``debug_run`` and
  builds a :class:`DebugReport`.
- :func:`capture_state_at_step` — re-run the solver to a specific
  step and return the live ``SLSQPState`` (with a mandatory
  reproducibility hash check).

The signal-evaluator tuples are exposed as
:data:`PER_STEP_EVALUATORS` and :data:`END_OF_RUN_EVALUATORS`; both
are empty in Phase 1 and populated by Phase 2.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional

from slsqp_jax.diagnostics.intercept import (
    DiagnosticContext,
    diagnose_minimize_like_scipy,
    diagnostic_run,
)
from slsqp_jax.diagnostics.playbook import (
    SCOPE_BY_TERMINATION,
    Diagnosis,
    confidence_for,
    evaluate_diagnoses,
    magnitude_for,
    signals_in_scope,
    signals_out_of_scope,
)
from slsqp_jax.diagnostics.records import DebugRunResult, StepSummary
from slsqp_jax.diagnostics.report import DebugReport
from slsqp_jax.diagnostics.runner import capture_state_at_step, debug_run
from slsqp_jax.diagnostics.signals import (
    END_OF_RUN_EVALUATORS,
    PER_STEP_EVALUATORS,
    SIGNAL_REGISTRY,
    EvalContext,
    Signal,
    SignalRegistration,
    register_evaluator,
)

if TYPE_CHECKING:
    from slsqp_jax.slsqp import SLSQP


def diagnose(
    solver: "SLSQP",
    fn: Callable,
    x0: Any,
    *,
    args: Any = None,
    max_steps: Optional[int] = None,
    has_aux: bool = False,
) -> DebugReport:
    """Drive ``solver`` through :func:`debug_run` and wrap the outcome in a report.

    This is a thin convenience layer: it calls :func:`debug_run` with the
    module-level :data:`PER_STEP_EVALUATORS` and
    :data:`END_OF_RUN_EVALUATORS`, feeds the signals that fired into
    :func:`evaluate_diagnoses`, and attaches the resulting diagnoses to a
    :class:`DebugReport` built from the run.

    A report is returned unconditionally -- there is no branch on whether
    the run succeeded -- so it can also be used to investigate runs that
    converged but are suspected of being slow.

    Args:
        solver: The solver instance to run; forwarded to ``debug_run``.
        fn: The callable handed to ``debug_run`` as its second positional
            argument.
        x0: The value handed to ``debug_run`` as its third positional
            argument.
        args: Forwarded unchanged as ``debug_run(args=...)``.
        max_steps: Forwarded unchanged as ``debug_run(max_steps=...)``.
        has_aux: Forwarded unchanged as ``debug_run(has_aux=...)``.

    Returns:
        The :class:`DebugReport` created by ``DebugReport.from_run``, with
        its ``diagnoses`` attribute set to a list of the evaluated
        diagnoses.
    """
    outcome = debug_run(
        solver,
        fn,
        x0,
        args=args,
        max_steps=max_steps,
        has_aux=has_aux,
        per_step_evaluators=PER_STEP_EVALUATORS,
        end_of_run_evaluators=END_OF_RUN_EVALUATORS,
    )

    # The playbook is consulted before the report object is built; the
    # result is materialised into a list only when it is attached.
    matched = evaluate_diagnoses(outcome.fired_signals)
    summary = DebugReport.from_run(outcome)
    summary.diagnoses = list(matched)
    return summary
# Names re-exported as the public API of this package; every entry is
# either imported above or defined in this module.
__all__ = [
    "DebugReport",
    "DebugRunResult",
    "Diagnosis",
    "DiagnosticContext",
    "END_OF_RUN_EVALUATORS",
    "EvalContext",
    "PER_STEP_EVALUATORS",
    "SCOPE_BY_TERMINATION",
    "SIGNAL_REGISTRY",
    "Signal",
    "SignalRegistration",
    "StepSummary",
    "capture_state_at_step",
    "confidence_for",
    "debug_run",
    "diagnose",
    "diagnose_minimize_like_scipy",
    "diagnostic_run",
    "evaluate_diagnoses",
    "magnitude_for",
    "register_evaluator",
    "signals_in_scope",
    "signals_out_of_scope",
]