# Source code for slsqp_jax.diagnostics.runner

"""Manual-loop debug runner and ad-hoc state-snapshot helper.

This module replaces the on-device ``jax.lax.while_loop`` of
``optimistix.iterative_solve`` with a host-driven Python ``for`` loop
calling ``jit(step)``.  Every iteration ends with a single
device → host transfer that materialises a :class:`StepSummary` from
the live ``SLSQPState``.  The signal pipeline (Phase 2) layers
per-step + end-of-run evaluators on top of this skeleton.

Performance contract (load-bearing): the runner is a *debug* tool,
not a production loop.  On GPU the host sync per iteration can turn a
3 s production run into a 30-60 s diagnose run.  That cost is
acceptable because the alternative is the user reading verbose output
by eye for hours.  See ``AGENTS.md`` and the plan's "Performance
contract" section for the discipline that keeps it bounded (cheap
predicates over scalar :class:`StepSummary` fields, expensive
``build_artifacts`` only invoked when a signal actually fires).
"""

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional

import jax

from slsqp_jax.diagnostics.records import DebugRunResult, StepSummary
from slsqp_jax.diagnostics.signals import EvalContext
from slsqp_jax.results import RESULTS

if TYPE_CHECKING:
    from slsqp_jax.slsqp import SLSQP
    from slsqp_jax.state import SLSQPState


def _wrap_fn_with_aux(fn: Callable, has_aux: bool) -> Callable:
    """Coerce ``fn`` to the ``(x, args) -> (value, aux)`` signature.

    ``SLSQP`` (mirroring ``optimistix.minimise(... has_aux=True)``)
    always calls the objective as ``f, aux = fn(x, args)``.  Users
    coming from a SciPy-style ``f = fn(x, *args)`` interface may not
    have packaged their function that way; this wrapper bridges the
    two without requiring them to think about it.
    """
    if has_aux:
        return fn

    def _wrapped(x: Any, args: Any) -> tuple[Any, None]:
        return fn(x, args), None

    return _wrapped


def _eval_fn_struct(fn: Callable, x0: Any, args: Any) -> tuple[Any, Any]:
    """Trace ``fn(x0, args)`` once to recover ``(f_struct, aux_struct)``.

    ``SLSQP.init`` requires ``ShapeDtypeStruct`` placeholders matching
    the abstract shape of ``fn``'s return value.  Optimistix derives
    them in its driver before calling ``init``; we replicate that here
    so the runner has the same calling convention as the stock loop.
    """
    return jax.eval_shape(fn, x0, args)


def _resolve_max_steps(solver: "SLSQP", max_steps: Optional[int]) -> int:
    """Pick the iteration budget for the manual loop.

    Defaults to ``solver.max_steps`` (i.e. the same budget
    ``optimistix.minimise`` would honour) so the runner reproduces the
    original run's terminal step count when no override is given.
    """
    if max_steps is None:
        return int(solver.max_steps)
    if max_steps <= 0:
        raise ValueError(f"max_steps must be positive, got {max_steps}")
    return int(max_steps)


def _build_jitted_solver_callbacks(
    solver: "SLSQP",
    fn: Callable,
    f_struct: Any,
    aux_struct: Any,
) -> tuple[Callable, Callable, Callable]:
    """Return jit-compiled ``(init, step, terminate)`` for the manual loop.

    Capturing ``solver`` and ``fn`` inside the closure means each
    invocation of ``debug_run`` triggers a fresh JIT compile (no
    cross-run cache hits).  That is acceptable: the runner is a
    diagnose-then-fix tool, not a hot path.  Re-using a previously
    jitted ``step`` across runs of *different* solvers would require
    retracing anyway because ``solver`` is part of the closure.
    """

    @jax.jit
    def jit_init(x0: Any, args: Any) -> tuple[Any, "SLSQPState"]:
        state = solver.init(fn, x0, args, {}, f_struct, aux_struct, frozenset())
        return x0, state

    @jax.jit
    def jit_step(
        y: Any, args: Any, state: "SLSQPState"
    ) -> tuple[Any, "SLSQPState", Any]:
        return solver.step(fn, y, args, {}, state, frozenset())

    @jax.jit
    def jit_terminate(y: Any, args: Any, state: "SLSQPState") -> tuple[Any, Any]:
        return solver.terminate(fn, y, args, {}, state, frozenset())

    return jit_init, jit_step, jit_terminate


def debug_run(
    solver: "SLSQP",
    fn: Callable,
    x0: Any,
    *,
    args: Any = None,
    max_steps: Optional[int] = None,
    has_aux: bool = False,
    per_step_evaluators: tuple = (),
    end_of_run_evaluators: tuple = (),
) -> DebugRunResult:
    """Run ``solver`` under a manual Python loop and return a :class:`DebugRunResult`.

    ``solver.init`` is called once; the loop then alternates
    ``solver.step`` → ``solver.terminate`` until ``terminate`` reports
    done or the iteration budget is exhausted.  After every step the
    live ``SLSQPState`` is summarised into a :class:`StepSummary`
    (built from the new state and the state preceding the step) and
    each callable in ``per_step_evaluators`` is invoked.  After the loop
    exits, each callable in ``end_of_run_evaluators`` is invoked once.

    Signals are de-duplicated by ``sig.name``: only the first signal
    fired under a given name (across both evaluator groups) is kept.

    Args:
        solver: The :class:`slsqp_jax.SLSQP` instance to run.  Its
            ``rtol`` and ``atol`` attributes are copied into the
            :class:`EvalContext` handed to the evaluators.
        fn: Objective callable.  ``(x, args) -> value`` by default, or
            ``(x, args) -> (value, aux)`` when ``has_aux=True``.
        x0: Initial iterate.
        args: Extra payload threaded through ``fn`` and the solver
            callbacks.
        max_steps: Optional iteration budget override.  Defaults to
            ``solver.max_steps``.
        has_aux: Whether ``fn`` returns ``(value, aux)``.
        per_step_evaluators: Tuple of callables
            ``(ctx, state, summary, summaries) -> Signal | None``
            invoked after each step, where ``ctx`` is the run's
            :class:`EvalContext` and ``summaries`` is the list of
            summaries so far (including ``summary``).  Empty by default.
        end_of_run_evaluators: Tuple of callables
            ``(ctx, final_state, summaries, coarse_result) -> Signal | None``
            invoked once after the loop exits.  Empty by default.

    Returns:
        :class:`DebugRunResult` with the per-step ``StepSummary``
        trajectory, the terminal state and iterate, the granular
        (``state.termination_code``) and coarse (last
        ``solver.terminate`` result) termination codes, and any signals
        fired during the run.  ``terminated_at_step`` is the 0-based
        loop index of the last executed step (``budget - 1`` when the
        budget was exhausted), i.e. an index into ``summaries``.

    Raises:
        ValueError: If ``max_steps`` is given and is not positive.
    """
    wrapped_fn = _wrap_fn_with_aux(fn, has_aux)
    f_struct, aux_struct = _eval_fn_struct(wrapped_fn, x0, args)
    jit_init, jit_step, jit_terminate = _build_jitted_solver_callbacks(
        solver, wrapped_fn, f_struct, aux_struct
    )

    budget = _resolve_max_steps(solver, max_steps)
    ctx = EvalContext(
        solver=solver,
        rtol=float(solver.rtol),
        atol=float(solver.atol),
        max_steps=budget,
    )

    y, state = jit_init(x0, args)

    summaries: list[StepSummary] = []
    fired: list[Any] = []
    seen_signal_names: set[str] = set()

    def _record(sig: Any) -> None:
        # Keep only the first signal fired under any given name.
        if sig is not None and sig.name not in seen_signal_names:
            fired.append(sig)
            seen_signal_names.add(sig.name)

    # The first summary is computed relative to the freshly initialised state.
    prev_state: "SLSQPState" = state
    # Reported as-is if the loop body never runs.
    coarse_result = RESULTS.successful
    terminated_at_step = 0
    max_steps_reached = True

    for k in range(budget):
        y, state, _aux = jit_step(y, args, state)

        summary = StepSummary.from_state(state, prev_state=prev_state)
        summaries.append(summary)

        # Evaluators see the summary list *including* the current step.
        for evaluator in per_step_evaluators:
            _record(evaluator(ctx, state, summary, summaries))

        done, coarse_result = jit_terminate(y, args, state)
        # ``bool(done)`` brings the termination flag to the host so the
        # Python loop can branch on it; it is a single scalar read on
        # top of the summary built above.
        if bool(done):
            terminated_at_step = k
            max_steps_reached = False
            break
        prev_state = state
    else:
        # for-else: budget exhausted without ``break``.
        terminated_at_step = budget - 1

    for evaluator in end_of_run_evaluators:
        _record(evaluator(ctx, state, summaries, coarse_result))

    return DebugRunResult(
        solver=solver,
        fn=fn,
        x0=x0,
        args=args,
        has_aux=has_aux,
        final_state=state,
        final_y=y,
        final_result=state.termination_code,
        coarse_result=coarse_result,
        summaries=summaries,
        terminated_at_step=terminated_at_step,
        max_steps_reached=max_steps_reached,
        fired_signals=fired,
    )


def capture_state_at_step(
    solver: "SLSQP",
    fn: Callable,
    x0: Any,
    step: int,
    *,
    args: Any = None,
    has_aux: bool = False,
    expected_summary: Optional[StepSummary] = None,
) -> "SLSQPState":
    """Re-run ``solver`` to step ``step`` and return the live ``SLSQPState``.

    This is the public ad-hoc inspection tool: signals already build
    their artifacts inline at the moment they fire, so most users will
    never call this.  It exists for the case where the user wants to
    poke at a step their signals did not preserve a snapshot of.

    The replay calls ``solver.step`` exactly ``step`` times and never
    consults ``solver.terminate``, so asking for a step beyond the point
    where the original run stopped simply keeps stepping.

    Args:
        solver: The :class:`slsqp_jax.SLSQP` instance to re-run.
        fn: Objective callable (same shape as for :func:`debug_run`).
        x0: Initial iterate.
        step: Target iteration to stop at (1-indexed; ``step=k``
            returns the state immediately after the ``k``-th call to
            ``solver.step``).
        args: Extra payload threaded through ``fn``.
        has_aux: Whether ``fn`` returns ``(value, aux)``.
        expected_summary: Optional :class:`StepSummary` from the
            *original* :func:`debug_run` at the same step.  Because
            :func:`debug_run` appends one summary per step starting at
            list index 0, that is entry ``step - 1`` of its summary
            list.  When supplied, the recovered state's summary digest
            is compared against it; a mismatch raises ``RuntimeError``.
            This is the load-bearing reproducibility check that
            prevents the tool from silently lying about which iterate
            it is showing.

    Returns:
        The live ``SLSQPState`` immediately after step ``step``.

    Raises:
        ValueError: If ``step`` is non-positive.
        RuntimeError: If ``expected_summary`` is supplied and the
            recovered state's reproducibility digest does not match.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    wrapped_fn = _wrap_fn_with_aux(fn, has_aux)
    f_struct, aux_struct = _eval_fn_struct(wrapped_fn, x0, args)
    jit_init, jit_step, _ = _build_jitted_solver_callbacks(
        solver, wrapped_fn, f_struct, aux_struct
    )

    y, state = jit_init(x0, args)
    prev_state = state
    for _ in range(step):
        # Remember the state *before* this step so the final summary
        # can be built against it after the loop.
        prev_state = state
        y, state, _aux = jit_step(y, args, state)

    if expected_summary is not None:
        # Only the last step's summary is ever compared, so it is built
        # once here rather than after every replayed step.
        last_summary = StepSummary.from_state(state, prev_state=prev_state)
        recovered_digest = last_summary.reproducibility_digest()
        expected_digest = expected_summary.reproducibility_digest()
        if recovered_digest != expected_digest:
            diverging = _diff_summaries(last_summary, expected_summary)
            raise RuntimeError(
                "debug-run trajectory is not reproducible: "
                f"recovered digest {recovered_digest!r} != expected "
                f"{expected_digest!r} at step={step}. Diverging fields: "
                f"{diverging}. This usually indicates a JAX nondeterminism "
                "(e.g. GPU XLA reductions) between the original "
                "debug_run and this capture_state_at_step call."
            )

    return state


def _diff_summaries(a: StepSummary, b: StepSummary) -> dict[str, tuple[Any, Any]]:
    """Collect the dataclass fields on which ``a`` and ``b`` disagree.

    Called by :func:`capture_state_at_step` when the reproducibility
    digests differ, so the error message can name the offending fields.
    Two ``float`` values that are both NaN are treated as equal; every
    other pair is compared with ``!=``.

    Args:
        a: First summary; its dataclass fields drive the iteration.
        b: Second summary, read by the same field names.

    Returns:
        Mapping ``field name -> (value in a, value in b)`` for each
        differing field, in field-declaration order.
    """
    import dataclasses

    mismatches: dict[str, tuple[Any, Any]] = {}
    for field in dataclasses.fields(a):
        name = field.name
        lhs = getattr(a, name)
        rhs = getattr(b, name)
        both_floats = isinstance(lhs, float) and isinstance(rhs, float)
        # NaN is the only float unequal to itself; a NaN/NaN pair is
        # deliberately not reported as a difference.
        if both_floats and lhs != lhs and rhs != rhs:
            continue
        if lhs != rhs:
            mismatches[name] = (lhs, rhs)
    return mismatches


__all__ = [
    "debug_run",
    "capture_state_at_step",
]