# Source code for slsqp_jax.slsqp.verbose

"""Verbose-output callbacks for the SLSQP outer loop.

Two callables are exposed:

* :func:`slsqp_verbose` — print one line per SLSQP step, with each
  field formatted by an optional ``fmt_spec`` provided as the third
  tuple element.  Wraps :func:`jax.debug.print`, so it is safe to call
  inside a JAX-traced ``step``.
* :func:`no_verbose` — no-op alternative used when ``verbose=False``
  is requested at construction time.

Both are kept tiny on purpose: the format specifiers are stripped at
class-construction time so that PyTree equality (and therefore
``optimistix`` JIT cache hits) is preserved across runs that differ
only in the verbose printer payload.
"""

from __future__ import annotations

from typing import Any

import jax


def slsqp_verbose(**kwargs: tuple) -> None:
    """Print one line summarising an SLSQP step via :func:`jax.debug.print`.

    Each keyword argument's value is either ``(label, value)`` or
    ``(label, value, fmt_spec)``.  ``label`` is rendered literally: any
    ``{``/``}`` characters are escaped so they cannot be mistaken for
    format placeholders.  ``fmt_spec`` (e.g. ``".3e"``) is inserted into
    the ``str.format`` placeholder for ``value``.  A ``fmt_spec`` of
    ``None`` is treated like the two-element form.

    Fields are printed in keyword-argument order as ``"label: value"``,
    joined by ``", "``.  Nothing is printed when no keyword arguments are
    given.  Output goes through :func:`jax.debug.print`, so the call is
    safe inside a JAX-traced function.

    Raises:
        ValueError: if an entry does not have exactly 2 or 3 elements.
    """
    templates: list[str] = []
    values: list[Any] = []
    for key, entry in kwargs.items():
        if len(entry) == 3:
            label, value, fmt_spec = entry
        elif len(entry) == 2:
            label, value = entry
            fmt_spec = None
        else:
            raise ValueError(
                f"slsqp_verbose entry {key!r} must be (label, value) or "
                f"(label, value, fmt_spec); got {len(entry)} elements"
            )
        # Escape braces so a label such as "f{x}" is printed literally instead
        # of being parsed by str.format as a replacement field (which would
        # raise or shift the positional values out of alignment).
        safe_label = str(label).replace("{", "{{").replace("}", "}}")
        placeholder = "{}" if fmt_spec is None else f"{{:{fmt_spec}}}"
        templates.append(f"{safe_label}: {placeholder}")
        values.append(value)
    if templates:
        jax.debug.print(", ".join(templates), *values)
def no_verbose(**_kwargs: tuple) -> None:
    """Silently discard every verbose field.

    This is the do-nothing counterpart of ``slsqp_verbose``.  It accepts
    arbitrary keyword arguments so it can be swapped in at the same call
    sites, and it ignores all of them.
    """
    del _kwargs  # intentionally unused
    return None
# Public API: the printing callback and its silent counterpart.
__all__ = [
    "no_verbose",
    "slsqp_verbose",
]