Source code for slsqp_jax.slsqp.hvp

"""Lagrangian Hessian-vector product factories.

The QP subproblem needs ``v -> B v`` where ``B`` is some
positive-definite approximation to the Lagrangian Hessian.  Two
strategies are supported:

* **Frozen L-BFGS** (default): the L-BFGS history at the *start* of
  the SLSQP step is treated as constant for the duration of the inner
  CG loop.  ``v -> B_k v`` is implemented via the compact
  representation in :func:`slsqp_jax.hessian.lbfgs_hvp`.
* **Newton-CG** (``QPConfig.use_exact_hvp_in_qp = True``): the exact
  Lagrangian HVP at the current iterate is used directly inside the
  CG loop.  Each CG step costs one forward-over-reverse AD pass; on
  ill-conditioned problems this dramatically accelerates convergence.

The L-BFGS history is still updated regardless of which mode is
active, because (a) it is the canonical source for the preconditioner
and (b) it is the fallback in case a future revision skips an exact
HVP evaluation.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any, cast

from jaxtyping import Array, Float

from slsqp_jax.hessian import LBFGSHistory, lbfgs_hvp
from slsqp_jax.types import Vector


def build_lbfgs_lagrangian_hvp(
    lbfgs_history: LBFGSHistory,
) -> Callable[[Vector], Vector]:
    """Return ``v -> B v`` backed by a *frozen* L-BFGS history.

    The returned callable captures ``lbfgs_history`` by reference and
    forwards every call to :func:`slsqp_jax.hessian.lbfgs_hvp`.  Nothing
    is copied or updated here.

    Args:
        lbfgs_history: History object handed unchanged to ``lbfgs_hvp``
            on every call of the returned closure.

    Returns:
        A callable mapping ``v`` to ``lbfgs_hvp(lbfgs_history, v)``.
    """

    def apply_frozen_hessian(v: Vector) -> Vector:
        # The history is read from the enclosing scope on each call.
        return lbfgs_hvp(lbfgs_history, v)

    return apply_frozen_hessian
def build_exact_lagrangian_hvp(
    *,
    fn: Callable,
    y: Vector,
    args: Any,
    multipliers_eq: Float[Array, " m_eq"],
    multipliers_ineq: Float[Array, " m_ineq"],
    obj_hvp_impl: Callable | None,
    eq_hvp_contrib_impl: Callable,
    ineq_hvp_contrib_impl: Callable,
    n_ineq_general: int,
) -> Callable[[Vector], Vector]:
    """Build the exact Lagrangian HVP at the current iterate.

    ``H_L v = H_f v − Σ λ_i H_{c_eq_i} v − Σ μ_j H_{c_ineq_j} v``.

    The returned closure evaluates three terms per call and combines
    them as ``objective − equality − inequality``.  Whether each
    callable is user-supplied or AD-computed is decided upstream in
    :func:`slsqp_jax.slsqp.derivatives.make_derivative_closures`.

    Args:
        fn: Forwarded as the first argument of ``obj_hvp_impl``.
        y: Iterate at which all three terms are evaluated.
        args: Extra arguments forwarded to every callable.
        multipliers_eq: Passed whole to ``eq_hvp_contrib_impl``.
        multipliers_ineq: Only its first ``n_ineq_general`` entries are
            passed to ``ineq_hvp_contrib_impl``.
        obj_hvp_impl: Called as ``obj_hvp_impl(fn, y, v, args)``.  It is
            annotated optional but is not checked here; ``None`` would
            only fail once the returned closure is called.
        eq_hvp_contrib_impl: Called as
            ``eq_hvp_contrib_impl(y, v, args, multipliers_eq)``.
        ineq_hvp_contrib_impl: Called as
            ``ineq_hvp_contrib_impl(y, v, args, multipliers_ineq[:n_ineq_general])``.
        n_ineq_general: Length of the leading slice of
            ``multipliers_ineq`` that is used.
            NOTE(review): presumably the remaining entries belong to
            constraints handled elsewhere (e.g. bounds) -- confirm.

    Returns:
        A callable mapping ``v`` to the combined Lagrangian HVP.
    """
    # ``cast`` only narrows the optional type for the type checker; it
    # returns its argument unchanged at runtime.
    objective_hvp = cast(Callable, obj_hvp_impl)

    def lagrangian_hvp(v: Vector) -> Vector:
        objective_term = objective_hvp(fn, y, v, args)
        equality_term = eq_hvp_contrib_impl(y, v, args, multipliers_eq)
        # Slice at call time so only the leading multipliers are forwarded.
        general_multipliers = multipliers_ineq[:n_ineq_general]
        inequality_term = ineq_hvp_contrib_impl(y, v, args, general_multipliers)
        return objective_term - equality_term - inequality_term

    return lagrangian_hvp
# Public API of this module: the two HVP factory functions.
__all__ = [
    "build_exact_lagrangian_hvp",
    "build_lbfgs_lagrangian_hvp",
]