Source code for slsqp_jax.slsqp.preconditioner
"""Preconditioner factories for the QP inner solver.
Two preconditioner types are offered, selected by
``PreconditionerConfig.type``:
* ``"lbfgs"`` — L-BFGS inverse Hessian via the two-loop recursion
(Algorithm 7.4 of Nocedal & Wright).
* ``"diagonal"`` — stochastic diagonal Hessian estimate
(Bekas, Kokiopoulou & Saad, 2007), probed each step with
``diagonal_n_probes`` Rademacher vectors.
When equality constraints are present *and* sSQP proximal stabilisation
is active (``ProximalConfig.tau > 0``), both options apply a Woodbury
correction to deliver ``B̃⁻¹`` for ``B̃ = B + (1/μ) Aᵀ A``. The
``m_eq × m_eq`` inner block (``m_eq`` = number of equality-constraint
rows) is factored once per factory call.
These functions are pure factories: they take the precomputed
``LBFGSHistory`` / Lagrangian HVP and return a closure ``v -> M v``.
"""
from __future__ import annotations
from collections.abc import Callable
import jax
import jax.numpy as jnp
from slsqp_jax.hessian import (
LBFGSHistory,
estimate_hessian_diagonal,
lbfgs_inverse_hvp,
)
from slsqp_jax.types import Scalar, Vector
[docs]
def build_lbfgs_preconditioner(
    *,
    lbfgs_history: LBFGSHistory,
    eq_jac: jnp.ndarray | None,
    proximal_active: bool,
    proximal_mu: Scalar | float,
) -> Callable[[Vector], Vector]:
    """Build a closure that applies the L-BFGS inverse-Hessian preconditioner.

    When proximal stabilisation is off, or there are no equality-constraint
    rows, the returned closure is a plain ``B⁻¹ v`` apply through
    ``lbfgs_inverse_hvp``. Otherwise the Woodbury identity is used so the
    closure applies ``(B + (1/μ) Aᵀ A)⁻¹``::

        B̃⁻¹ v = B⁻¹ v − B⁻¹ Aᵀ (μ I + A B⁻¹ Aᵀ)⁻¹ A B⁻¹ v

    Args:
        lbfgs_history: History object forwarded to ``lbfgs_inverse_hvp``.
        eq_jac: Equality-constraint Jacobian ``A`` (one row per constraint),
            or ``None`` when there are no equality constraints.
        proximal_active: Whether the proximal (Woodbury) correction is
            requested.
        proximal_mu: The ``μ`` appearing in ``B + (1/μ) Aᵀ A``.

    Returns:
        A function ``v -> M v`` applying the preconditioner to a vector.
    """

    def apply_binv(vec: Vector) -> Vector:
        # Single place where the L-BFGS inverse apply is invoked.
        return lbfgs_inverse_hvp(lbfgs_history, vec)

    if not (proximal_active and eq_jac is not None and eq_jac.shape[0] > 0):
        # No correction needed: the preconditioner is just B⁻¹.
        def preconditioner(v: Vector) -> Vector:
            return apply_binv(v)

        return preconditioner

    jac = eq_jac
    n_eq = jac.shape[0]

    # Row i holds B⁻¹ aᵢ, so ``binv_rows.T`` is B⁻¹ Aᵀ.
    binv_rows = jax.vmap(apply_binv)(jac)

    # Small (n_eq, n_eq) system μ I + A B⁻¹ Aᵀ; it is factored here, outside
    # the closure, so every apply reuses the same Cholesky factor.
    schur = proximal_mu * jnp.eye(n_eq) + binv_rows @ jac.T
    # Tiny diagonal shift added before factoring.
    chol_lower = jnp.linalg.cholesky(schur + 1e-10 * jnp.eye(n_eq))

    def preconditioner(v: Vector) -> Vector:
        binv_v = apply_binv(v)
        # ``True`` marks the factor as lower-triangular for cho_solve.
        multipliers = jax.scipy.linalg.cho_solve((chol_lower, True), jac @ binv_v)
        return binv_v - binv_rows.T @ multipliers

    return preconditioner
[docs]
def build_diagonal_preconditioner(
    *,
    lagrangian_hvp_fn: Callable[[Vector], Vector],
    n: int,
    step_count: jnp.ndarray,
    n_probes: int,
    eq_jac: jnp.ndarray | None,
    proximal_active: bool,
    proximal_mu: Scalar | float,
) -> Callable[[Vector], Vector]:
    """Build a closure that applies a stochastic diagonal preconditioner.

    The Hessian diagonal is obtained from ``estimate_hessian_diagonal``
    (the probing estimator of Bekas et al., 2007) using a PRNG key made by
    folding ``step_count`` into a fixed seed, so the same step always
    produces the same key. The estimate is replaced by its absolute value
    and then bounded below by ``max(1e-8, 1e-6 * median(|diag|))``, which
    keeps every entry of the inverse diagonal positive and finite.

    When proximal stabilisation is on and equality-constraint rows exist,
    the Woodbury identity is applied with ``D`` the safeguarded diagonal,
    so the closure applies ``(D + (1/μ) Aᵀ A)⁻¹``. Otherwise it applies
    ``D⁻¹`` elementwise.

    Args:
        lagrangian_hvp_fn: Hessian-vector product, forwarded to
            ``estimate_hessian_diagonal``.
        n: Problem dimension, forwarded to ``estimate_hessian_diagonal``.
        step_count: Value folded into the PRNG key.
        n_probes: Number of probes, forwarded to
            ``estimate_hessian_diagonal``.
        eq_jac: Equality-constraint Jacobian ``A`` (one row per constraint),
            or ``None`` when there are no equality constraints.
        proximal_active: Whether the proximal (Woodbury) correction is
            requested.
        proximal_mu: The ``μ`` appearing in ``D + (1/μ) Aᵀ A``.

    Returns:
        A function ``v -> M v`` applying the preconditioner to a vector.
    """
    probe_key = jax.random.fold_in(jax.random.PRNGKey(42), step_count)
    magnitudes = jnp.abs(
        estimate_hessian_diagonal(lagrangian_hvp_fn, n, probe_key, n_probes=n_probes)
    )
    # Relative floor (scaled by the median magnitude) with an absolute
    # minimum, so near-zero estimates do not blow up the reciprocal.
    lower_bound = jnp.maximum(1e-8, 1e-6 * jnp.median(magnitudes))
    inv_diag = 1.0 / jnp.maximum(magnitudes, lower_bound)

    if not (proximal_active and eq_jac is not None and eq_jac.shape[0] > 0):
        # No correction needed: plain elementwise scaling by D⁻¹.
        def preconditioner(v: Vector) -> Vector:
            return inv_diag * v

        return preconditioner

    jac = eq_jac
    n_eq = jac.shape[0]

    # D⁻¹ Aᵀ, built by scaling each column of A and transposing.
    scaled_jac_t = (jac * inv_diag[None, :]).T

    # Small (n_eq, n_eq) system μ I + A D⁻¹ Aᵀ; it is factored here, outside
    # the closure, so every apply reuses the same Cholesky factor.
    schur = proximal_mu * jnp.eye(n_eq) + jac @ scaled_jac_t
    # Tiny diagonal shift added before factoring.
    chol_lower = jnp.linalg.cholesky(schur + 1e-10 * jnp.eye(n_eq))

    def preconditioner(v: Vector) -> Vector:
        scaled_v = inv_diag * v
        # ``True`` marks the factor as lower-triangular for cho_solve.
        multipliers = jax.scipy.linalg.cho_solve((chol_lower, True), jac @ scaled_v)
        return scaled_v - scaled_jac_t @ multipliers

    return preconditioner
# Public API of this module: the two preconditioner factory functions.
__all__ = [
"build_diagonal_preconditioner",
"build_lbfgs_preconditioner",
]