Source code for slsqp_jax.scaling

"""Automatic problem scaling at the initial point.

This module implements gradient-based automatic scaling of the objective
and constraint functions evaluated at ``x0`` (IPOPT/KNITRO-style).  The
scaling is applied at the SLSQP boundary so the inner solver always sees
a balanced problem (``||grad||_inf ~ target_gradient`` per row); the
outputs are unscaled before being returned to the user.

The motivating failure mode is documented in the diagnostic notes for
the feasible-start divergence run: a ``||J_eq|| ~ 70`` vs
``||grad_f|| ~ 0.018`` magnitude mismatch (~4000x) drives a
``penalty_starvation -> merit_penalty_explosion -> divergence_rollback``
cascade that no amount of solver tuning can avoid.  Manually rescaling
the constraint fixes it; doing this automatically is the standard
recipe used by IPOPT, KNITRO, and SNOPT.

API surface (re-exported from :mod:`slsqp_jax`):

* :class:`ScaleFactors` -- frozen dataclass carrying the computed
  ``s_f``, ``s_eq``, ``s_ineq`` factors plus the user/internal ``atol``
  pair and book-keeping for skipped (near-zero gradient) rows.
* :class:`ScaledProblem` -- frozen dataclass bundling the scaled
  callables (``fn``, ``eq_constraint_fn``, ``ineq_constraint_fn``,
  derivative wrappers) plus the :class:`ScaleFactors` they were built
  from.
* :class:`ScalingConfig` -- frozen dataclass carrying the three
  scalars (``target_gradient``, ``max_factor``, ``grad_floor``) that
  parameterise :func:`compute_scale_factors_at_x0`.
* :func:`compute_scale_factors_at_x0` -- evaluate gradients at ``x0``
  and pick per-component scale factors.
* :func:`auto_scaled_problem` -- return a :class:`ScaledProblem` whose
  callables wrap the user's by ``s * f`` / ``s * c`` / ``s[:, None] * J``.
* :func:`unscale_solution` -- post-process a raw ``optx.Solution`` so
  multipliers and KKT residuals are returned in user units, with a
  ``scale_factors`` entry on ``sol.stats`` exposing the applied
  scaling.
* :func:`wrap_verbose_for_scaling` -- adapt the built-in verbose
  printer (or a user-supplied callable) so the per-step log shows
  user-unit values for ``f`` / ``|c|`` / ``|grad_f|`` / ``|grad_L|``
  / ``|d|``, and ``(scaled)``-labelled values for merit / rho / gamma
  / L-BFGS internals.
* :func:`auto_scaled_minimise` -- convenience wrapper around
  ``optx.minimise`` for lower-level users that mirrors the
  :func:`slsqp_jax.minimize_like_scipy` default-on auto-scaling path.
* :func:`resolve_scaling_mode` -- string-or-bool mode resolver that
  maps ``True``, ``False``, ``"balanced"``, ``"knitro"``, ``"ipopt"``,
  ``"aggressive"`` to a ``ScalingConfig``.

Mathematics
-----------

For each component (objective and each constraint row) with gradient
``g`` evaluated at ``x0``::

    norm = max(||g||_inf, grad_floor)
    s = clip(target_gradient / norm, eps, max_factor)

If ``||g||_inf < grad_floor``, the row is *skipped* (``s = 1.0`` and a
counter is incremented; a :class:`UserWarning` is emitted on the
objective).

Scaled wrappers::

    f_scaled(x)        = s_f * f(x)
    c_eq_scaled(x)     = s_eq * c_eq(x)             (element-wise)
    c_ineq_scaled(x)   = s_ineq * c_ineq(x)
    grad_scaled(x)     = s_f * grad(x)
    eq_jac_scaled(x)   = s_eq[:, None] * eq_jac(x)
    ineq_jac_scaled(x) = s_ineq[:, None] * ineq_jac(x)

The constraint HVP convention in
:mod:`slsqp_jax.slsqp.derivatives` is *per-component*: the wrapper
``ConstraintHVPFn`` returns an ``(m, n)`` stack of ``H_{c_i}(x) @ v``.
Per-row scaling therefore multiplies row ``i`` by ``s_eq[i]`` (or
``s_ineq[i]``); the contraction with multipliers happens later.

``atol`` compensation::

    s_min = min(min(s_eq), min(s_ineq), 1.0)
    atol_internal = atol_user * s_min

This guarantees ``|c_scaled[i]| <= atol_internal => |c_user[i]| <=
atol_user`` for the worst-scaled row, so the user-perceived
feasibility tolerance is preserved even when the inner solver sees a
shrunken constraint.

Output unscaling
----------------

Multipliers are unscaled by the recovered identity ``lambda_user =
(s / s_f) * lambda_scaled`` (derived from ``L_scaled = s_f * L_user``
and ``c_scaled = s * c_user``).  The Lagrangian gradient norm
unscales as ``||grad_L||_user = ||grad_L||_scaled / s_f``.  The merit
penalty stays in scaled units with an explicit ``(scaled units)``
note on ``sol.stats``.

This module is import-safe (no eager evaluation of user callables);
all derivative work happens inside
:func:`compute_scale_factors_at_x0`.
"""

from __future__ import annotations

import warnings
from collections.abc import Callable
from dataclasses import dataclass, replace
from typing import Any, Optional, Union

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optimistix as optx

from slsqp_jax.types import (
    ConstraintFn,
    ConstraintHVPFn,
    GradFn,
    HVPFn,
    JacobianFn,
)

# Smallest scale factor we will ever emit.  ``s = 0`` would zero the
# row and is mathematically catastrophic; ``eps`` lets the clip avoid
# emitting exactly zero when the user-supplied ``grad_floor`` is very
# small *and* the gradient is very large.
_SCALE_FLOOR = 1e-300


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------


[docs] @dataclass(frozen=True) class ScalingConfig: """Parameters for :func:`compute_scale_factors_at_x0` (per-row) and :func:`compute_uniform_scale_factors_at_x0` (uniform). Attributes: target_gradient: Desired ``||grad||_inf`` after scaling. ``s = target_gradient / ||grad||`` (clipped). Under ``uniform=True`` this same target is consumed by both the objective scalar ``s_f`` and the single shared constraint scalar ``s_c``. max_factor: Bound on the scale factor. Under ``uniform=False`` (per-row modes) this is a one-sided amplification cap, so ``s in [eps, max_factor]``; ``max_factor=1.0`` means "shrink-only" (KNITRO/IPOPT). Under ``uniform=True`` the bound is **symmetric** so ``s in [1/max_factor, max_factor]`` and the value must satisfy ``max_factor >= 1.0``; passing ``max_factor=1.0`` under uniform disables scaling entirely and emits a UserWarning. The project default ``max_factor=1e3`` is well below typical AD relative-error noise floors (~1e-12) so amplification cannot promote roundoff to signal under any reasonable AD. grad_floor: Rows whose ``||grad||_inf`` falls below this value are *skipped* (left at ``s = 1.0``). ``1e-12`` is the right default: it is comfortably above ``eps`` so machine-zero is caught, but small enough that genuinely tiny-but-non-degenerate gradients still get scaled. uniform: When ``True``, apply a single shared scalar ``s_c`` across **all** constraint rows (equality and inequality unioned) preserving inter-row ratios; clip ``s_c`` and ``s_f`` symmetrically by ``max_factor``; set ``atol_internal = s_c * atol_user`` exactly (no ``min(., 1.0)`` cap, so ``atol_internal`` can exceed ``atol_user``). When ``False`` (the legacy default) each constraint row gets its own factor and ``atol_internal = atol_user * min(min(s_eq), min(s_ineq), 1.0)``. """ target_gradient: float = 1.0 max_factor: float = 1e3 grad_floor: float = 1e-12 uniform: bool = False
# Modes accepted by :func:`resolve_scaling_mode`. ``True`` resolves # to ``"uniform"`` and is the default for the user-facing # ``minimize_like_scipy`` / ``auto_scaled_minimise`` entry points. # ``"balanced"`` and the other per-row modes remain available for # users who want the old row-flattening behavior. _MODE_TABLE: dict[str, ScalingConfig] = { "uniform": ScalingConfig(target_gradient=1.0, max_factor=1e3, uniform=True), "balanced": ScalingConfig(target_gradient=1.0, max_factor=1e3), "knitro": ScalingConfig(target_gradient=1.0, max_factor=1.0), "ipopt": ScalingConfig(target_gradient=100.0, max_factor=1.0), "aggressive": ScalingConfig(target_gradient=1.0, max_factor=1e6), }
[docs] def resolve_scaling_mode( mode: Union[bool, str], *, target_gradient: Optional[float] = None, max_factor: Optional[float] = None, ) -> Optional[ScalingConfig]: """Map a user-facing ``auto_scale`` argument to a :class:`ScalingConfig`. Args: mode: ``True`` -> ``"uniform"`` (the default), ``False`` -> ``None`` (no scaling), or one of the string aliases ``"uniform"``, ``"balanced"``, ``"knitro"``, ``"ipopt"``, ``"aggressive"``. target_gradient: Optional explicit override of the mode's default target. ``None`` uses the mode default. max_factor: Optional explicit override of the mode's default cap. ``None`` uses the mode default. Returns: :class:`ScalingConfig` or ``None`` if ``mode`` is ``False``. Raises: ValueError: If ``mode`` is a string that is not one of the recognised aliases. TypeError: If ``mode`` is neither a bool nor a string. """ if mode is False: return None if mode is True: key = "uniform" elif isinstance(mode, str): key = mode.lower() else: raise TypeError( f"auto_scale must be a bool or one of {sorted(_MODE_TABLE)}, " f"got {type(mode).__name__}" ) if key not in _MODE_TABLE: raise ValueError( f"auto_scale={mode!r} is not recognised; expected one of " f"{sorted(_MODE_TABLE)} or a bool." ) base = _MODE_TABLE[key] if target_gradient is None and max_factor is None: return base return replace( base, target_gradient=base.target_gradient if target_gradient is None else float(target_gradient), max_factor=base.max_factor if max_factor is None else float(max_factor), )
# --------------------------------------------------------------------------- # ScaleFactors / ScaledProblem # ---------------------------------------------------------------------------
[docs] @dataclass(frozen=True) class ScaleFactors: """The factors :func:`compute_scale_factors_at_x0` (per-row) or :func:`compute_uniform_scale_factors_at_x0` (uniform) returned. Attributes: s_f: Scalar objective-scaling factor (``f_scaled = s_f * f``). ``1.0`` when no scaling was applied or the objective gradient was below ``grad_floor``. s_eq: Equality-constraint factors, shape ``(m_eq,)``. Empty array when ``m_eq == 0``. Per-row varying under ``uniform=False``; constant-valued (every entry equals the shared ``s_c``) under ``uniform=True``. s_ineq: Inequality-constraint factors, shape ``(m_ineq,)``. Empty array when ``m_ineq == 0``. Note that this is the *general* inequality count -- bound constraints are scaled separately (and trivially) by the bound-handling machinery. Per-row under ``uniform=False``; constant and equal to the same shared ``s_c`` as ``s_eq`` under ``uniform=True``. atol_user: The user-supplied feasibility tolerance. atol_internal: The compensated tolerance handed to the inner solver. Under ``uniform=False`` this is ``atol_user * min(min(s_eq), min(s_ineq), 1.0)`` (worst-row conservative). Under ``uniform=True`` this is ``s_c * atol_user`` exactly (can exceed ``atol_user`` when ``s_c > 1``). target_gradient: Echo of the :class:`ScalingConfig` field actually used. max_factor: Echo of the :class:`ScalingConfig` field actually used. grad_floor: Echo of the :class:`ScalingConfig` field actually used. n_skipped_eq: Number of equality rows whose ``||grad_eq[i]||_inf`` was below ``grad_floor``. Always ``0`` under ``uniform=True`` (uniform mode does not skip individual rows). n_skipped_ineq: Number of inequality rows whose ``||grad_ineq[i]||_inf`` was below ``grad_floor``. Always ``0`` under ``uniform=True``. skipped_obj: ``True`` when ``||grad_f||_inf`` was below ``grad_floor`` (a :class:`UserWarning` was emitted). uniform: ``True`` when the factors were produced by :func:`compute_uniform_scale_factors_at_x0` (single shared ``s_c`` across constraints, symmetric ``max_factor`` clipping, exact-equivalence ``atol_internal``). ``False`` for the per-row :func:`compute_scale_factors_at_x0`. """ s_f: float s_eq: jax.Array s_ineq: jax.Array atol_user: float atol_internal: float target_gradient: float max_factor: float grad_floor: float n_skipped_eq: int = 0 n_skipped_ineq: int = 0 skipped_obj: bool = False uniform: bool = False
[docs] @dataclass(frozen=True) class ScaledProblem: """Bundle of scaled callables + the :class:`ScaleFactors` they came from. Attributes mirror the ``SLSQP`` constructor's per-callable slots so the wrapper can be threaded through :func:`auto_scaled_minimise` or :func:`slsqp_jax.minimize_like_scipy` with minimal plumbing. The objective ``fn`` returns ``(s_f * value, aux)`` to match the ``has_aux=True`` convention used inside ``optx.minimise``. Constraint callables, Jacobians, and HVPs return scaled values directly. """ fn: Callable eq_constraint_fn: Optional[ConstraintFn] ineq_constraint_fn: Optional[ConstraintFn] obj_grad_fn: Optional[GradFn] eq_jac_fn: Optional[JacobianFn] ineq_jac_fn: Optional[JacobianFn] obj_hvp_fn: Optional[HVPFn] eq_hvp_fn: Optional[ConstraintHVPFn] ineq_hvp_fn: Optional[ConstraintHVPFn] factors: ScaleFactors
# --------------------------------------------------------------------------- # Scale-factor computation # --------------------------------------------------------------------------- def _scale_one( grad_inf: float, *, target_gradient: float, max_factor: float, grad_floor: float ) -> tuple[float, bool]: """Compute one scale factor from a gradient infinity-norm (per-row mode). Returns ``(s, skipped)`` where ``skipped`` is ``True`` iff the row was below the floor and ``s = 1.0`` was emitted. Uses the legacy one-sided clipping ``s in [_SCALE_FLOOR, max_factor]``. """ if grad_inf < grad_floor: return 1.0, True s = target_gradient / max(grad_inf, _SCALE_FLOOR) s = float(np.clip(s, _SCALE_FLOOR, max_factor)) return s, False def _scale_one_symmetric( grad_inf: float, *, target_gradient: float, max_factor: float, grad_floor: float ) -> tuple[float, bool]: """Compute one scale factor with **symmetric** clipping (uniform mode). Returns ``(s, skipped)`` where ``skipped`` is ``True`` iff ``grad_inf < grad_floor`` (then ``s = 1.0``). Otherwise ``s = clip(target_gradient / grad_inf, 1/max_factor, max_factor)``, i.e. the scale can amplify *or* shrink by up to ``max_factor``. The caller is expected to have validated ``max_factor >= 1.0`` so the symmetric interval is non-empty. """ if grad_inf < grad_floor: return 1.0, True s = target_gradient / max(grad_inf, _SCALE_FLOOR) lower = 1.0 / max_factor s = float(np.clip(s, lower, max_factor)) return s, False def _grad_inf_norm(arr: jax.Array) -> float: """Infinity norm of a (host-resident) gradient as a Python float.""" if arr.size == 0: return 0.0 return float(jnp.max(jnp.abs(arr))) def _row_inf_norms(jac: jax.Array) -> np.ndarray: """Infinity norm of each row of a Jacobian.""" if jac.size == 0: return np.zeros((0,), dtype=float) return np.asarray(jnp.max(jnp.abs(jac), axis=1))
[docs] def compute_scale_factors_at_x0( fn: Callable, x0: jax.Array, args: Any, has_aux: bool, *, eq_constraint_fn: Optional[ConstraintFn] = None, ineq_constraint_fn: Optional[ConstraintFn] = None, obj_grad_fn: Optional[GradFn] = None, eq_jac_fn: Optional[JacobianFn] = None, ineq_jac_fn: Optional[JacobianFn] = None, target_gradient: float = 1.0, max_factor: float = 1e3, grad_floor: float = 1e-12, atol_user: float = 1e-6, ) -> ScaleFactors: """Evaluate gradients at ``x0`` and return per-component scale factors. Args: fn: Objective. ``(x, args) -> value`` or ``(x, args) -> (value, aux)`` when ``has_aux=True``. x0: Initial iterate. args: Extra payload threaded through ``fn`` / constraints. has_aux: Whether ``fn`` returns ``(value, aux)``. eq_constraint_fn: Optional equality constraint function. ineq_constraint_fn: Optional inequality constraint function. obj_grad_fn: Optional user-supplied objective gradient. When ``None`` we fall back to ``jax.grad(fn)``. eq_jac_fn: Optional user-supplied equality Jacobian. ``jax.jacrev`` fallback. ineq_jac_fn: Optional user-supplied inequality Jacobian. ``jax.jacrev`` fallback. target_gradient: See :class:`ScalingConfig`. max_factor: See :class:`ScalingConfig`. grad_floor: See :class:`ScalingConfig`. atol_user: User-perceived feasibility tolerance. The returned :attr:`ScaleFactors.atol_internal` compensates for the constraint-shrinking factors. Returns: A :class:`ScaleFactors` instance. Notes: Emits a :class:`UserWarning` per component whose gradient is below ``grad_floor`` (objective and any constraint rows). Skipped rows keep ``s = 1.0``; the user is expected to either pick a different ``x0`` or pass ``auto_scale=False`` and provide their own scaling. """ x0 = jnp.asarray(x0, dtype=float) # Objective gradient. if obj_grad_fn is not None: grad_f = jnp.asarray(obj_grad_fn(x0, args)) else: if has_aux: def _scalar_fn(x: jax.Array) -> jax.Array: return fn(x, args)[0] else: def _scalar_fn(x: jax.Array) -> jax.Array: return fn(x, args) grad_f = jax.grad(_scalar_fn)(x0) grad_f_inf = _grad_inf_norm(jnp.asarray(grad_f)) s_f, skipped_obj = _scale_one( grad_f_inf, target_gradient=target_gradient, max_factor=max_factor, grad_floor=grad_floor, ) if skipped_obj: warnings.warn( "auto_scale: ||grad_f(x0)||_inf = " f"{grad_f_inf:.3e} is below grad_floor = " f"{grad_floor:.0e}; the objective will not be scaled " "(s_f = 1.0). Either pick a different starting point " "or pass auto_scale=False and supply your own scaling.", UserWarning, stacklevel=2, ) # Equality constraint Jacobian (per-row scaling). s_eq_list: list[float] = [] n_skipped_eq = 0 if eq_constraint_fn is not None: if eq_jac_fn is not None: eq_jac = jnp.asarray(eq_jac_fn(x0, args)) else: eq_jac = jax.jacrev(lambda x: eq_constraint_fn(x, args))(x0) eq_jac_arr = jnp.asarray(eq_jac) if eq_jac_arr.ndim == 1: eq_jac_arr = eq_jac_arr[None, :] row_norms = _row_inf_norms(eq_jac_arr) for i, r in enumerate(row_norms): s_i, skipped_i = _scale_one( float(r), target_gradient=target_gradient, max_factor=max_factor, grad_floor=grad_floor, ) if skipped_i: n_skipped_eq += 1 warnings.warn( f"auto_scale: ||grad eq[{i}](x0)||_inf = {r:.3e} is " f"below grad_floor = {grad_floor:.0e}; row will " "not be scaled (s_eq[i] = 1.0).", UserWarning, stacklevel=2, ) s_eq_list.append(s_i) s_eq = jnp.asarray(s_eq_list, dtype=float) # Inequality constraint Jacobian (per-row scaling). s_ineq_list: list[float] = [] n_skipped_ineq = 0 if ineq_constraint_fn is not None: if ineq_jac_fn is not None: ineq_jac = jnp.asarray(ineq_jac_fn(x0, args)) else: ineq_jac = jax.jacrev(lambda x: ineq_constraint_fn(x, args))(x0) ineq_jac_arr = jnp.asarray(ineq_jac) if ineq_jac_arr.ndim == 1: ineq_jac_arr = ineq_jac_arr[None, :] row_norms = _row_inf_norms(ineq_jac_arr) for i, r in enumerate(row_norms): s_i, skipped_i = _scale_one( float(r), target_gradient=target_gradient, max_factor=max_factor, grad_floor=grad_floor, ) if skipped_i: n_skipped_ineq += 1 warnings.warn( f"auto_scale: ||grad ineq[{i}](x0)||_inf = {r:.3e} is " f"below grad_floor = {grad_floor:.0e}; row will " "not be scaled (s_ineq[i] = 1.0).", UserWarning, stacklevel=2, ) s_ineq_list.append(s_i) s_ineq = jnp.asarray(s_ineq_list, dtype=float) # ``atol`` compensation: the worst-scaled constraint row drives # the tolerance shrink so that ``|c_scaled| <= atol_internal`` # implies ``|c_user| <= atol_user`` for every row. Capping the # min at ``1.0`` means well-scaled or amplified rows never # *loosen* the tolerance. pieces: list[float] = [1.0] if s_eq.size > 0: pieces.append(float(jnp.min(s_eq))) if s_ineq.size > 0: pieces.append(float(jnp.min(s_ineq))) s_min = min(pieces) atol_internal = float(atol_user) * float(min(s_min, 1.0)) return ScaleFactors( s_f=float(s_f), s_eq=s_eq, s_ineq=s_ineq, atol_user=float(atol_user), atol_internal=atol_internal, target_gradient=float(target_gradient), max_factor=float(max_factor), grad_floor=float(grad_floor), n_skipped_eq=n_skipped_eq, n_skipped_ineq=n_skipped_ineq, skipped_obj=skipped_obj, )
def _evaluate_obj_grad_inf_norm( fn: Callable, x0: jax.Array, args: Any, has_aux: bool, obj_grad_fn: Optional[GradFn], ) -> float: """Evaluate ``||grad f(x0)||_inf`` via user-supplied grad or AD fallback.""" if obj_grad_fn is not None: grad_f = jnp.asarray(obj_grad_fn(x0, args)) else: if has_aux: def _scalar_fn(x: jax.Array) -> jax.Array: return fn(x, args)[0] else: def _scalar_fn(x: jax.Array) -> jax.Array: return fn(x, args) grad_f = jax.grad(_scalar_fn)(x0) return _grad_inf_norm(jnp.asarray(grad_f)) def _evaluate_constraint_row_norms( constraint_fn: Optional[ConstraintFn], jac_fn: Optional[JacobianFn], x0: jax.Array, args: Any, ) -> np.ndarray: """Per-row inf-norms of a constraint Jacobian at ``x0`` (empty if absent).""" if constraint_fn is None: return np.zeros((0,), dtype=float) if jac_fn is not None: jac = jnp.asarray(jac_fn(x0, args)) else: jac = jax.jacrev(lambda x: constraint_fn(x, args))(x0) jac_arr = jnp.asarray(jac) if jac_arr.ndim == 1: jac_arr = jac_arr[None, :] return _row_inf_norms(jac_arr)
[docs] def compute_uniform_scale_factors_at_x0( fn: Callable, x0: jax.Array, args: Any, has_aux: bool, *, eq_constraint_fn: Optional[ConstraintFn] = None, ineq_constraint_fn: Optional[ConstraintFn] = None, obj_grad_fn: Optional[GradFn] = None, eq_jac_fn: Optional[JacobianFn] = None, ineq_jac_fn: Optional[JacobianFn] = None, target_gradient: float = 1.0, max_factor: float = 1e3, grad_floor: float = 1e-12, atol_user: float = 1e-6, ) -> ScaleFactors: """Evaluate gradients at ``x0`` and return **uniform** scale factors. Uniform mode applies a single shared scalar ``s_c`` to every constraint row (equality and inequality unioned) and an independent scalar ``s_f`` to the objective; both are clipped symmetrically by ``max_factor`` so they can amplify *or* shrink. The feasibility tolerance is propagated as ``atol_internal = s_c * atol_user`` exactly, so the inner solver's feasibility test corresponds 1-1 to the user-unit test ``|c_user[i]| <= atol_user`` regardless of which direction ``s_c`` moved. Args: fn: Objective. ``(x, args) -> value`` or ``(x, args) -> (value, aux)`` when ``has_aux=True``. x0: Initial iterate. args: Extra payload threaded through ``fn`` / constraints. has_aux: Whether ``fn`` returns ``(value, aux)``. eq_constraint_fn: Optional equality constraint function. ineq_constraint_fn: Optional inequality constraint function. obj_grad_fn: Optional user-supplied objective gradient. When ``None`` we fall back to ``jax.grad(fn)``. eq_jac_fn: Optional user-supplied equality Jacobian. ``jax.jacrev`` fallback. ineq_jac_fn: Optional user-supplied inequality Jacobian. ``jax.jacrev`` fallback. target_gradient: See :class:`ScalingConfig`. Consumed by both ``s_f`` and ``s_c`` derivations. max_factor: See :class:`ScalingConfig`. Must be ``>= 1.0``. ``max_factor == 1.0`` is legal but disables scaling and emits a :class:`UserWarning`. grad_floor: See :class:`ScalingConfig`. atol_user: User-perceived feasibility tolerance. The returned :attr:`ScaleFactors.atol_internal` is ``s_c * atol_user`` and may exceed ``atol_user`` when ``s_c > 1``. Returns: A :class:`ScaleFactors` instance with ``uniform=True``, ``s_eq = jnp.full((m_eq,), s_c)``, and ``s_ineq = jnp.full((m_ineq,), s_c)``. Raises: ValueError: If ``max_factor < 1.0`` (the symmetric interval ``[1/max_factor, max_factor]`` would be empty). Notes: Emits a :class:`UserWarning` when ``||grad_f(x0)||_inf`` is below ``grad_floor`` (``s_f = 1.0``), when every constraint row's gradient inf-norm is below ``grad_floor`` (``s_c = 1.0``), or when ``max_factor == 1.0`` (scaling disabled). No per-row warnings are emitted under uniform mode -- the cross-row max dominates the per-row check, and per-row magnitudes are intentionally preserved. """ if max_factor < 1.0: raise ValueError( f"compute_uniform_scale_factors_at_x0: max_factor must be >= 1.0 " f"for uniform mode (the symmetric clipping interval " f"[1/max_factor, max_factor] would otherwise be empty); got " f"max_factor={max_factor!r}. For shrink-only per-row scaling " "pass auto_scale='knitro' instead." ) if max_factor == 1.0: warnings.warn( "auto_scale: uniform mode with max_factor=1.0 disables " "scaling entirely (s_f and s_c are pinned to 1.0). " "Consider auto_scale='knitro' for shrink-only per-row " "behavior, or a larger max_factor to allow scaling.", UserWarning, stacklevel=2, ) x0 = jnp.asarray(x0, dtype=float) grad_f_inf = _evaluate_obj_grad_inf_norm(fn, x0, args, has_aux, obj_grad_fn) s_f, skipped_obj = _scale_one_symmetric( grad_f_inf, target_gradient=target_gradient, max_factor=max_factor, grad_floor=grad_floor, ) if skipped_obj: warnings.warn( "auto_scale: ||grad_f(x0)||_inf = " f"{grad_f_inf:.3e} is below grad_floor = " f"{grad_floor:.0e}; the objective will not be scaled " "(s_f = 1.0). Either pick a different starting point " "or pass auto_scale=False and supply your own scaling.", UserWarning, stacklevel=2, ) eq_row_norms = _evaluate_constraint_row_norms(eq_constraint_fn, eq_jac_fn, x0, args) ineq_row_norms = _evaluate_constraint_row_norms( ineq_constraint_fn, ineq_jac_fn, x0, args ) m_eq = int(eq_row_norms.shape[0]) m_ineq = int(ineq_row_norms.shape[0]) if m_eq + m_ineq == 0: # No constraints: s_c trivially 1.0, atol_internal = atol_user. s_c = 1.0 s_eq = jnp.zeros((0,), dtype=float) s_ineq = jnp.zeros((0,), dtype=float) atol_internal = float(atol_user) else: # ``max_row_norm`` over the union of equality + general-inequality # rows. This is the value that drives ``s_c``; preserving the # *ratio* between rows is the whole point of uniform mode, so # individual rows below ``grad_floor`` do not get special # treatment -- they ride the same ``s_c`` as everything else. if m_eq > 0 and m_ineq > 0: max_row_norm = float(np.maximum(eq_row_norms.max(), ineq_row_norms.max())) elif m_eq > 0: max_row_norm = float(eq_row_norms.max()) else: max_row_norm = float(ineq_row_norms.max()) if max_row_norm < grad_floor: warnings.warn( "auto_scale: all constraint Jacobian row inf-norms at " f"x0 are below grad_floor = {grad_floor:.0e} " f"(max = {max_row_norm:.3e}); constraints will not be " "scaled (s_c = 1.0). Either pick a different starting " "point or pass auto_scale=False and supply your own " "scaling.", UserWarning, stacklevel=2, ) s_c = 1.0 else: s_c_raw = target_gradient / max(max_row_norm, _SCALE_FLOOR) s_c = float(np.clip(s_c_raw, 1.0 / max_factor, max_factor)) s_eq = jnp.full((m_eq,), s_c, dtype=float) s_ineq = jnp.full((m_ineq,), s_c, dtype=float) atol_internal = float(s_c) * float(atol_user) return ScaleFactors( s_f=float(s_f), s_eq=s_eq, s_ineq=s_ineq, atol_user=float(atol_user), atol_internal=atol_internal, target_gradient=float(target_gradient), max_factor=float(max_factor), grad_floor=float(grad_floor), n_skipped_eq=0, n_skipped_ineq=0, skipped_obj=skipped_obj, uniform=True, )
# --------------------------------------------------------------------------- # Wrapping helpers # --------------------------------------------------------------------------- def _wrap_objective(fn: Callable, s_f: float, has_aux: bool) -> Callable: """Return a wrapped objective that scales the value by ``s_f``.""" s_f_arr = jnp.asarray(s_f, dtype=float) if has_aux: def wrapped(x: Any, args: Any) -> tuple[Any, Any]: value, aux = fn(x, args) return s_f_arr * jnp.asarray(value), aux else: def wrapped(x: Any, args: Any) -> tuple[Any, None]: value = fn(x, args) return s_f_arr * jnp.asarray(value), None return wrapped def _wrap_objective_grad(obj_grad_fn: GradFn, s_f: float) -> GradFn: """Return a wrapped gradient that scales by ``s_f``.""" s_f_arr = jnp.asarray(s_f, dtype=float) def wrapped(x: Any, args: Any) -> Any: return s_f_arr * jnp.asarray(obj_grad_fn(x, args)) return wrapped def _wrap_objective_hvp(obj_hvp_fn: HVPFn, s_f: float) -> HVPFn: """Return a wrapped HVP that scales by ``s_f``.""" s_f_arr = jnp.asarray(s_f, dtype=float) def wrapped(x: Any, v: Any, args: Any) -> Any: return s_f_arr * jnp.asarray(obj_hvp_fn(x, v, args)) return wrapped def _wrap_constraint_fn(constraint_fn: ConstraintFn, s_row: jax.Array) -> ConstraintFn: """Return a wrapped constraint function with element-wise row scaling.""" def wrapped(x: Any, args: Any) -> Any: c = jnp.asarray(constraint_fn(x, args)) return s_row * c return wrapped def _wrap_constraint_jac(jac_fn: JacobianFn, s_row: jax.Array) -> JacobianFn: """Return a wrapped Jacobian with per-row scaling.""" def wrapped(x: Any, args: Any) -> Any: J = jnp.asarray(jac_fn(x, args)) return s_row[:, None] * J return wrapped def _wrap_constraint_hvp(hvp_fn: ConstraintHVPFn, s_row: jax.Array) -> ConstraintHVPFn: """Return a wrapped per-component constraint HVP with per-row scaling. The constraint HVP convention in :mod:`slsqp_jax.slsqp.derivatives` is *per-component*: the callable returns an ``(m, n)`` stack whose row ``i`` is ``H_{c_i}(x) @ v``. Per-row scaling therefore multiplies row ``i`` of that stack by ``s_row[i]``. When the L-Bagrangian HVP contracts that stack with multipliers (``mu @ H``), the sum ``Σ mu_i * s_row[i] * H_{c_i}(x) v`` matches the natural HVP of the scaled constraint ``c_scaled = s_row * c``. """ def wrapped(x: Any, v: Any, args: Any) -> Any: H = jnp.asarray(hvp_fn(x, v, args)) return s_row[:, None] * H return wrapped
[docs] def auto_scaled_problem( fn: Callable, x0: jax.Array, args: Any, has_aux: bool, *, eq_constraint_fn: Optional[ConstraintFn] = None, ineq_constraint_fn: Optional[ConstraintFn] = None, obj_grad_fn: Optional[GradFn] = None, eq_jac_fn: Optional[JacobianFn] = None, ineq_jac_fn: Optional[JacobianFn] = None, obj_hvp_fn: Optional[HVPFn] = None, eq_hvp_fn: Optional[ConstraintHVPFn] = None, ineq_hvp_fn: Optional[ConstraintHVPFn] = None, scaling_config: ScalingConfig, atol_user: float = 1e-6, ) -> ScaledProblem: """Build a :class:`ScaledProblem` from the user's callables and ``x0``. Computes scale factors via either :func:`compute_scale_factors_at_x0` (when ``scaling_config.uniform`` is ``False`` — the legacy per-row behavior) or :func:`compute_uniform_scale_factors_at_x0` (when ``True`` — a single shared scalar across all constraint rows), wraps every supplied callable, and returns the bundle. Callables left as ``None`` stay as ``None`` -- the SLSQP solver will fall back to its AD paths, which automatically pick up the scaling from the wrapped ``fn`` / ``constraint_fn`` callables. The returned :attr:`ScaledProblem.fn` adheres to the ``has_aux=True`` convention (returning ``(value, aux)`` even when the user's ``fn`` returned just a value), matching what ``optimistix.minimise`` expects on the SLSQP path. """ if scaling_config.uniform: factors = compute_uniform_scale_factors_at_x0( fn=fn, x0=x0, args=args, has_aux=has_aux, eq_constraint_fn=eq_constraint_fn, ineq_constraint_fn=ineq_constraint_fn, obj_grad_fn=obj_grad_fn, eq_jac_fn=eq_jac_fn, ineq_jac_fn=ineq_jac_fn, target_gradient=scaling_config.target_gradient, max_factor=scaling_config.max_factor, grad_floor=scaling_config.grad_floor, atol_user=atol_user, ) else: factors = compute_scale_factors_at_x0( fn=fn, x0=x0, args=args, has_aux=has_aux, eq_constraint_fn=eq_constraint_fn, ineq_constraint_fn=ineq_constraint_fn, obj_grad_fn=obj_grad_fn, eq_jac_fn=eq_jac_fn, ineq_jac_fn=ineq_jac_fn, target_gradient=scaling_config.target_gradient, max_factor=scaling_config.max_factor, grad_floor=scaling_config.grad_floor, atol_user=atol_user, ) fn_scaled = _wrap_objective(fn, factors.s_f, has_aux) eq_fn_scaled = ( _wrap_constraint_fn(eq_constraint_fn, factors.s_eq) if eq_constraint_fn is not None and factors.s_eq.size > 0 else None ) ineq_fn_scaled = ( _wrap_constraint_fn(ineq_constraint_fn, factors.s_ineq) if ineq_constraint_fn is not None and factors.s_ineq.size > 0 else None ) obj_grad_scaled = ( _wrap_objective_grad(obj_grad_fn, factors.s_f) if obj_grad_fn is not None else None ) eq_jac_scaled = ( _wrap_constraint_jac(eq_jac_fn, factors.s_eq) if eq_jac_fn is not None and factors.s_eq.size > 0 else None ) ineq_jac_scaled = ( _wrap_constraint_jac(ineq_jac_fn, factors.s_ineq) if ineq_jac_fn is not None and factors.s_ineq.size > 0 else None ) obj_hvp_scaled = ( _wrap_objective_hvp(obj_hvp_fn, factors.s_f) if obj_hvp_fn is not None else None ) eq_hvp_scaled = ( _wrap_constraint_hvp(eq_hvp_fn, factors.s_eq) if eq_hvp_fn is not None and factors.s_eq.size > 0 else None ) ineq_hvp_scaled = ( _wrap_constraint_hvp(ineq_hvp_fn, factors.s_ineq) if ineq_hvp_fn is not None and factors.s_ineq.size > 0 else None ) return ScaledProblem( fn=fn_scaled, eq_constraint_fn=eq_fn_scaled, ineq_constraint_fn=ineq_fn_scaled, obj_grad_fn=obj_grad_scaled, eq_jac_fn=eq_jac_scaled, ineq_jac_fn=ineq_jac_scaled, obj_hvp_fn=obj_hvp_scaled, eq_hvp_fn=eq_hvp_scaled, ineq_hvp_fn=ineq_hvp_scaled, factors=factors, )
# --------------------------------------------------------------------------- # Output unscaling # ---------------------------------------------------------------------------
[docs] def unscale_solution(sol: optx.Solution, factors: ScaleFactors) -> optx.Solution: """Post-process ``sol`` so user-facing ``stats`` use unscaled units. The primary iterate ``sol.value`` lives in ``x``-space, which is *not* scaled by this module (variable scaling is out-of-scope -- see the deferred section of the auto-scaling plan). Multipliers and the Lagrangian gradient norm are converted from scaled to user units via ``lambda_user = (s / s_f) * lambda_scaled``. The merit penalty ``rho`` and its history live entirely in scaled units; we leave them as-is and add a ``merit_penalty_note`` entry flagging the unit. Args: sol: Raw solution returned by ``optimistix.minimise`` (or built post-hoc by the diagnostics layer). factors: The :class:`ScaleFactors` used to wrap the problem. Returns: A new :class:`optx.Solution` with the augmented ``stats`` dict. ``sol`` is not mutated. """ s_f = factors.s_f s_eq = factors.s_eq s_ineq = factors.s_ineq stats = dict(sol.stats) if sol.stats is not None else {} if "multipliers_eq" in stats and s_eq.size > 0: stats["multipliers_eq_user"] = (s_eq / s_f) * jnp.asarray( stats["multipliers_eq"] ) # Same scale recipe as the LS variants: general portion scales by # ``s_ineq / s_f``; the bound portion is already in user units # because bound rows are not scaled. if "multipliers_ineq" in stats and s_ineq.size > 0: mults = jnp.asarray(stats["multipliers_ineq"]) n_general = s_ineq.size if mults.shape[0] >= n_general: scale_vec = jnp.concatenate( [s_ineq / s_f, jnp.ones(mults.shape[0] - n_general)] ) stats["multipliers_ineq_user"] = scale_vec * mults else: # pragma: no cover -- defensive stats["multipliers_ineq_user"] = mults # QP-side multipliers (Han-Powell / LPEC-A / next-QP warm-start # view) — surfaced for advanced diagnostics so users can compare # them against the LS variant in user units. if "multipliers_eq_qp" in stats and s_eq.size > 0: stats["multipliers_eq_qp_user"] = (s_eq / s_f) * jnp.asarray( stats["multipliers_eq_qp"] ) if "multipliers_ineq_qp" in stats and s_ineq.size > 0: mults_qp = jnp.asarray(stats["multipliers_ineq_qp"]) n_general = s_ineq.size if mults_qp.shape[0] >= n_general: scale_vec_qp = jnp.concatenate( [s_ineq / s_f, jnp.ones(mults_qp.shape[0] - n_general)] ) stats["multipliers_ineq_qp_user"] = scale_vec_qp * mults_qp else: # pragma: no cover -- defensive stats["multipliers_ineq_qp_user"] = mults_qp if "final_grad_norm" in stats: stats["final_grad_norm_user"] = jnp.asarray(stats["final_grad_norm"]) / s_f if "final_lagrangian_grad_norm" in stats: stats["final_lagrangian_grad_norm_user"] = ( jnp.asarray(stats["final_lagrangian_grad_norm"]) / s_f ) # filterSQP eq. (5) ``μ_max`` scales linearly with ``s_f``: every # candidate in the max (``||∇f||``, ``|ν_i|``, ``||a_i||·|λ_i|``) # carries one factor of ``s_f`` under our scaling convention # (``∇f`` scales by ``s_f``, ``ν`` and ``λ`` carry compensating # ``s_f/s_c`` factors that combine with the per-row ``s_c`` on # ``||a_i||`` to leave a net ``s_f``). Unscaling is therefore a # plain division by ``s_f``. Keep the public ``kkt_scale`` key in # user units, matching the rest of the unscaled solution contract. if "kkt_scale" in stats: stats["kkt_scale"] = jnp.asarray(stats["kkt_scale"]) / s_f if "final_objective" in stats: stats["final_objective_user"] = jnp.asarray(stats["final_objective"]) / s_f stats["scale_factors"] = factors stats["merit_penalty_note"] = "scaled units" return optx.Solution( # ty: ignore[invalid-return-type] value=sol.value, result=sol.result, aux=sol.aux, stats=stats, state=sol.state, )
# --------------------------------------------------------------------------- # Verbose-printer adapter # --------------------------------------------------------------------------- # Keys (from the ``slsqp_verbose`` callback in # :mod:`slsqp_jax.slsqp._step_body`) that have a clean unscaled # equivalent. The mapping is to a recovery function ``(value, factors, # state) -> unscaled_value``. Keys not listed are passed through with # a ``(scaled)`` suffix on the label. _UNSCALABLE_KEYS_OBJ_DIVIDE = ( # ``f_scaled = s_f * f``: divide by ``s_f``. "objective", # ``|grad_f|_scaled = s_f * |grad_f|``: divide by ``s_f``. "grad_norm", # ``|grad_L|_scaled = s_f * |grad_L|``: divide by ``s_f``. "kkt_residual", # ``|grad_L|/|L|`` is dimensionless under uniform ``s_f`` scaling # of both numerator and denominator's ``L = f - lambda . c``; # the ratio is preserved exactly. No transform required. # ``|projected_grad|_scaled = s_f * |projected_grad|``. "proj_grad_norm", # filterSQP eq. (5) ``μ_max`` scales linearly with ``s_f`` (every # candidate in the max carries one ``s_f``); divide by ``s_f``. "kkt_scale", ) # Keys to flag with ``(scaled)`` in the printed label. Anything not # listed in the unscale set and not here is printed with its raw # label (e.g. step counters, booleans). _SCALED_LABEL_KEYS = ( "merit", "merit_delta", "penalty", "lbfgs_gamma", "lbfgs_sty", "lbfgs_relcurv", "lbfgs_diag_cond", # ``|c|`` is the max over scaled user constraints and (unscaled) # bound rows. No single ``s_f``-style transform recovers the # user-unit max from the scalar; flag it so users do not compare # this directly against ``atol_user``. "constraint_violation", ) def _make_user_unit_value(key: str, value: Any, factors: ScaleFactors) -> Any: """Compute the user-unit value for a known unscalable key.""" s_f = factors.s_f if s_f == 1.0: return value if key in _UNSCALABLE_KEYS_OBJ_DIVIDE: return value / s_f return value def _adapt_entry( key: str, entry: tuple, factors: ScaleFactors, needs_scaled_suffix: bool, ) -> tuple: """Rewrite a single ``(label, value[, fmt])`` tuple for the verbose call. ``needs_scaled_suffix`` is precomputed once by :func:`wrap_verbose_for_scaling` from ``factors`` (which never changes across the run) and threaded in here as a Python ``bool``. Computing it lazily inside this function would require reading ``factors.s_eq`` / ``factors.s_ineq`` (both ``jax.Array``) as concrete values, which fails under :func:`jax.jit` tracing of the enclosing ``step`` -- the verbose callback runs inside the jitted step (Optimistix's outer driver and ``debug_run``'s inner ``jit_step`` both jit ``step``). """ if len(entry) == 3: label, value, fmt = entry else: label, value = entry fmt = None if key in _UNSCALABLE_KEYS_OBJ_DIVIDE: new_value = _make_user_unit_value(key, value, factors) new_label = label elif key in _SCALED_LABEL_KEYS and needs_scaled_suffix: new_value = value new_label = f"{label}(s)" else: new_value = value new_label = label return (new_label, new_value, fmt) if fmt is not None else (new_label, new_value) def _scaling_is_active(factors: ScaleFactors) -> bool: """Return ``True`` iff at least one factor is non-trivial (``!= 1``). Reads ``factors.s_eq`` and ``factors.s_ineq`` (both ``jax.Array``) as concrete host values; must be called only outside any :func:`jax.jit` trace context. """ if factors.s_f != 1.0: return True if factors.s_eq.size > 0 and float(jnp.min(factors.s_eq)) != 1.0: return True if factors.s_ineq.size > 0 and float(jnp.min(factors.s_ineq)) != 1.0: return True return False
[docs] def wrap_verbose_for_scaling( user_verbose: Union[bool, Callable], factors: ScaleFactors, ) -> Callable: """Adapt a verbose callback to print user-unit values when scaling is on. Args: user_verbose: Either a boolean (``True`` / ``False``) or a user-supplied callable. Booleans select the built-in printer (``slsqp_verbose`` from :mod:`slsqp_jax.slsqp.verbose`); a callable is forwarded after the tuple values have been rewritten. factors: The :class:`ScaleFactors` to undo. Returns: A callback with the same ``(**kwargs)`` signature as :func:`slsqp_jax.slsqp.verbose.slsqp_verbose`. Quantities with a clean unscaled equivalent are converted to user units in place; quantities that live in scaled space (``merit``, ``rho``, L-BFGS internals) are printed with the suffix ``(s)`` appended to their label so the reader knows the unit. """ from slsqp_jax.slsqp.verbose import no_verbose, slsqp_verbose if user_verbose is False: # Return a no-op that nonetheless carries the marker attribute # so the diagnostics layer's introspection can still pick up # the factors from the solver's ``verbose`` slot. def wrapped_silent(**_kwargs: tuple) -> None: no_verbose(**_kwargs) # Setattr keeps the assignment on a single statement so the # ``ty: ignore`` directive lands on the unresolved-attribute # site (ty otherwise emits the warning on the assignment line # while the suppression sits on the value line after ruff's # multi-line split). setattr( # noqa: B010 -- need ty-ignore on a single line wrapped_silent, "_slsqp_scale_factors", factors, ) return wrapped_silent if user_verbose is True: target = slsqp_verbose else: target = user_verbose # Print a one-line preamble at module-import time so the user sees # the active factors before the first step's verbose line. We # print to ``stderr`` via ``warnings``-free path; the verbose # printer itself uses ``jax.debug.print``. s_f = factors.s_f if factors.uniform: # Under uniform mode ``s_eq`` and ``s_ineq`` are constant-valued # and equal to the same shared scalar ``s_c``; collapse the # display to a single value. When neither group has rows # ``s_c`` is the trivial 1.0. if factors.s_eq.size > 0: s_c = float(factors.s_eq[0]) elif factors.s_ineq.size > 0: s_c = float(factors.s_ineq[0]) else: s_c = 1.0 preamble = ( f"[auto-scale] (uniform) s_f={s_f:.3e}, s_c={s_c:.3e}, " f"atol_internal={factors.atol_internal:.3e} " f"(atol_user={factors.atol_user:.3e}); " "merit/rho/gamma/L-BFGS columns are in scaled units (suffix '(s)')." ) else: s_eq_min = float(jnp.min(factors.s_eq)) if factors.s_eq.size > 0 else 1.0 s_ineq_min = float(jnp.min(factors.s_ineq)) if factors.s_ineq.size > 0 else 1.0 preamble = ( f"[auto-scale] s_f={s_f:.3e}, " f"min(s_eq)={s_eq_min:.3e}, min(s_ineq)={s_ineq_min:.3e}; " "merit/rho/gamma/L-BFGS columns are in scaled units (suffix '(s)')." ) # Precompute the "needs (s) suffix" decision once on the host. # ``factors`` is invariant across the run, so reading its array # fields is safe here (eager) but would crash inside the jitted # step where ``_adapt_entry`` runs. See the docstring on # ``_adapt_entry`` for the full rationale. needs_scaled_suffix = _scaling_is_active(factors) # Stash the preamble + factors on the wrapper so the runner can # surface it at iteration 0; we cannot eagerly print because the # verbose callback is invoked under JIT trace. preamble_state = {"emitted": False} def wrapped(**kwargs: tuple) -> None: if not preamble_state["emitted"]: # Print the preamble exactly once. ``jax.debug.print`` is # the safe "JIT-friendly" mechanism, but we want this to # surface even under the host-driven debug runner so we # emit it here on the host side via stderr. _emit_preamble(preamble) preamble_state["emitted"] = True if isinstance(user_verbose, bool): adapted: dict[str, tuple] = {} for k, v in kwargs.items(): adapted[k] = _adapt_entry(k, v, factors, needs_scaled_suffix) target(**adapted) else: # User-supplied callable: pass scaled values through with # an explicit ``scale_factors`` keyword for downstream # consumption. We do *not* rewrite the tuples for them # (consensus from rounds 1-2 of the verbose-log debate: # avoid action-at-a-distance on user code). user_verbose(scale_factors=factors, **kwargs) # Stash the factors on the wrapper so the diagnostics layer can # surface them in the report without needing a separate # plumbing path. This is the load-bearing introspection hook # used by ``intercept._run_via_debug``. wrapped._slsqp_scale_factors = factors # ty: ignore[unresolved-attribute] return wrapped
def _emit_preamble(text: str) -> None: """Best-effort host-side preamble emission for the verbose wrapper. Routes through :mod:`sys.stderr` to avoid the project-wide ``no-print-statements`` pre-commit hook (which is intentionally strict; the verbose preamble is one of the few legitimate exceptions). """ import sys sys.stderr.write(text + "\n") sys.stderr.flush() # --------------------------------------------------------------------------- # Convenience minimiser # ---------------------------------------------------------------------------
[docs] def auto_scaled_minimise( fn: Callable, solver: Any, x0: Any, args: Any = None, *, auto_scale: Union[bool, str] = True, auto_scale_target_gradient: Optional[float] = None, auto_scale_max_factor: Optional[float] = None, has_aux: bool = False, options: Optional[dict] = None, max_steps: Optional[int] = 256, throw: bool = True, tags: frozenset = frozenset(), adjoint: Any = None, ) -> optx.Solution: """Convenience wrapper around ``optx.minimise`` with auto-scaling. Mirrors the :func:`slsqp_jax.minimize_like_scipy` default-on auto-scaling path for users who construct :class:`SLSQP` directly. Builds a :class:`ScaledProblem` from the user's callables, replaces the solver's eq/ineq/jac/hvp slots with the scaled wrappers, overrides the SLSQPConfig.atol with ``atol_internal``, and unscales the returned :class:`optx.Solution`. When ``auto_scale=False`` the call passes through to ``optx.minimise`` unchanged. Args: fn: Objective. Same convention as ``optx.minimise`` -- ``(x, args) -> value`` with ``has_aux=False`` or ``(x, args) -> (value, aux)`` with ``has_aux=True``. solver: An :class:`SLSQP` instance. The constraint, Jacobian, HVP, and verbose slots are read off and replaced with their scaled counterparts. x0: Initial iterate. args: Extra payload threaded through ``fn`` / constraints. auto_scale: ``True`` (default) -> ``"uniform"`` mode (a single shared ``s_c`` across all constraint rows + independent ``s_f`` for the objective, both symmetrically clipped by ``max_factor``). ``False`` -> no scaling (passthrough). String -> explicit mode name; pass ``"balanced"`` to recover the legacy per-row default. See :func:`resolve_scaling_mode` for the full table. auto_scale_target_gradient: Optional explicit target gradient override. Under ``uniform`` mode this is consumed by both ``s_f`` and ``s_c`` derivations. auto_scale_max_factor: Optional explicit max-factor override. Under ``uniform`` mode the symmetric bound requires ``max_factor >= 1.0`` (smaller raises ``ValueError``; ``== 1.0`` warns). has_aux: Whether ``fn`` returns ``(value, aux)``. options: Forwarded to ``optx.minimise``. max_steps: Forwarded to ``optx.minimise``. throw: Forwarded to ``optx.minimise``. tags: Forwarded to ``optx.minimise``. adjoint: Forwarded to ``optx.minimise`` when not ``None``. Returns: An :class:`optx.Solution`; when scaling was applied, the ``stats`` dict carries ``scale_factors`` and the ``_user`` suffixed fields documented on :func:`unscale_solution`. Under ``uniform`` mode the ``ScaleFactors`` instance has ``uniform=True`` and ``s_eq`` / ``s_ineq`` constant-valued and equal to the same shared ``s_c``. """ from slsqp_jax.slsqp import SLSQP cfg = resolve_scaling_mode( auto_scale, target_gradient=auto_scale_target_gradient, max_factor=auto_scale_max_factor, ) if cfg is None: kwargs = { "has_aux": has_aux, "max_steps": max_steps, "throw": throw, "tags": tags, } if adjoint is not None: kwargs["adjoint"] = adjoint return optx.minimise(fn, solver, x0, args, options, **kwargs) if not isinstance(solver, SLSQP): raise TypeError( "auto_scaled_minimise: solver must be an SLSQP instance to " "pick up the scaled constraint/Jacobian/HVP slots; got " f"{type(solver).__name__}. Pass auto_scale=False to use " "any other minimiser." ) user_atol = float(solver.atol) scaled = auto_scaled_problem( fn=fn, x0=x0, args=args, has_aux=has_aux, eq_constraint_fn=solver.eq_constraint_fn, ineq_constraint_fn=solver.ineq_constraint_fn, obj_grad_fn=solver.obj_grad_fn, eq_jac_fn=solver.eq_jac_fn, ineq_jac_fn=solver.ineq_jac_fn, obj_hvp_fn=solver.obj_hvp_fn, eq_hvp_fn=solver.eq_hvp_fn, ineq_hvp_fn=solver.ineq_hvp_fn, scaling_config=cfg, atol_user=user_atol, ) new_solver = _replace_solver_callables(solver, scaled) kwargs = { "has_aux": True, # scaled.fn always returns (value, aux) "max_steps": max_steps, "throw": throw, "tags": tags, } if adjoint is not None: kwargs["adjoint"] = adjoint sol = optx.minimise(scaled.fn, new_solver, x0, args, options, **kwargs) return unscale_solution(sol, scaled.factors)
def _replace_solver_callables(solver: Any, scaled: ScaledProblem) -> Any: """Return a copy of ``solver`` with constraint/derivative slots replaced. All replaced slots are static ``eqx.field`` instances, so neither :func:`equinox.tree_at` nor :func:`dataclasses.replace` can swap them without re-triggering ``__check_init__``. :func:`object.__setattr__` is the idiomatic escape hatch for static-field overrides on a frozen ``eqx.Module``. The user's solver is *not* mutated -- we operate on a shallow copy so the same solver can be re-used with different scaling settings (or no scaling at all) on subsequent calls. """ import copy as _copy solver = _copy.copy(solver) factors = scaled.factors new_verbose = wrap_verbose_for_scaling(solver.verbose, factors) # Override atol to atol_internal so the inner solver's # convergence checks match the user-perceived feasibility. We # leave rtol untouched: it tests ``|grad_L| / max(mu_max, 1)`` # (filterSQP eqs. 5–6), and every term in ``mu_max`` carries one # factor of ``s_f``. The hard ``max(., 1)`` floor remains in # internal units, so ``postprocess`` exposes the exact internal # dimensionless residual separately as ``stats["kkt_ratio"]``. from slsqp_jax.config import ToleranceConfig new_tol = ToleranceConfig( rtol=solver.config.tolerance.rtol, atol=factors.atol_internal, max_steps=solver.config.tolerance.max_steps, min_steps=solver.config.tolerance.min_steps, stagnation_tol=solver.config.tolerance.stagnation_tol, divergence_factor=solver.config.tolerance.divergence_factor, divergence_patience=solver.config.tolerance.divergence_patience, ) new_config = eqx.tree_at(lambda c: c.tolerance, solver.config, new_tol) overrides: dict[str, Any] = { "eq_constraint_fn": scaled.eq_constraint_fn, "ineq_constraint_fn": scaled.ineq_constraint_fn, "obj_grad_fn": scaled.obj_grad_fn, "eq_jac_fn": scaled.eq_jac_fn, "ineq_jac_fn": scaled.ineq_jac_fn, "obj_hvp_fn": scaled.obj_hvp_fn, "eq_hvp_fn": scaled.eq_hvp_fn, "ineq_hvp_fn": scaled.ineq_hvp_fn, "verbose": new_verbose, "config": new_config, } for key, value in overrides.items(): object.__setattr__(solver, key, value) # Also re-derive the cached derivative closures so they see the # newly-installed scaled callables instead of the originals. from slsqp_jax.slsqp.derivatives import ( build_grad_impl, build_hvp_contrib_impl, build_jacobian_impl, build_obj_hvp_impl, ) object.__setattr__(solver, "_grad_impl", build_grad_impl(scaled.obj_grad_fn)) object.__setattr__( solver, "_eq_jac_impl", build_jacobian_impl( user_jac=scaled.eq_jac_fn, constraint_fn=scaled.eq_constraint_fn, n_constraints=solver.n_eq_constraints, ), ) object.__setattr__( solver, "_ineq_jac_impl", build_jacobian_impl( user_jac=scaled.ineq_jac_fn, constraint_fn=scaled.ineq_constraint_fn, n_constraints=solver.n_ineq_constraints, ), ) object.__setattr__( solver, "_eq_hvp_contrib_impl", build_hvp_contrib_impl( user_hvp=scaled.eq_hvp_fn, constraint_fn=scaled.eq_constraint_fn, n_constraints=solver.n_eq_constraints, ), ) object.__setattr__( solver, "_ineq_hvp_contrib_impl", build_hvp_contrib_impl( user_hvp=scaled.ineq_hvp_fn, constraint_fn=scaled.ineq_constraint_fn, n_constraints=solver.n_ineq_constraints, ), ) object.__setattr__( solver, "_obj_hvp_impl", build_obj_hvp_impl( user_obj_hvp=scaled.obj_hvp_fn, use_exact_hvp_in_qp=new_config.qp.use_exact_hvp, ), ) return solver __all__ = [ "ScaleFactors", "ScaledProblem", "ScalingConfig", "auto_scaled_minimise", "auto_scaled_problem", "compute_scale_factors_at_x0", "compute_uniform_scale_factors_at_x0", "resolve_scaling_mode", "unscale_solution", "wrap_verbose_for_scaling", ]