Source code for slsqp_jax.hessian

"""L-BFGS Hessian Approximation for SLSQP.

This module implements the Limited-memory BFGS (L-BFGS) algorithm for
maintaining a matrix-free approximation to the Hessian of the Lagrangian.

Instead of storing a dense n x n matrix (O(n^2) memory), L-BFGS stores
the last k (s, y) pairs (O(kn) memory) and computes Hessian-vector
products in O(k^2 n) time, since the small inner matrix is rebuilt on
each call, using the compact representation (Byrd, Nocedal, Schnabel 1994):

    B = B_0 - [B_0 S, Y] @ M^{-1} @ [S^T B_0; Y^T]

where B_0 = diag(diagonal) is the initial Hessian (per-variable scaling)
and M is a small 2k x 2k matrix built from inner products of the stored
vectors.  During normal operation ``diagonal = gamma * ones(n)`` and this
reduces to the scalar-scaled form.  After an SNOPT-style diagonal reset,
``diagonal`` captures per-variable curvature from the discarded history.

VARCHEN-style Powell damping (Lotfi et al., 2020) is applied to each
(s, y) pair before storage, damping toward B_0 = diag(diagonal) instead
of the full L-BFGS approximation B.  This is cheaper (O(n) vs O(k^2 n)),
always well-conditioned, and avoids the circular dependency where a
badly-conditioned B poisons its own damping.
"""

import equinox as eqx
import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Callable
from jaxtyping import Array, Bool, Float, Int, jaxtyped

from slsqp_jax.types import Scalar, Vector


class LBFGSHistory(eqx.Module):
    """L-BFGS history buffer for matrix-free Hessian approximation.

    Stores the last k (s, y) pairs in a circular buffer and provides
    efficient Hessian-vector products via the compact representation.

    Attributes:
        s_history: Stored step vectors s_i = x_{i+1} - x_i.
        y_history: Stored (damped) gradient differences y_i.
        gamma: Scalar summary of the initial Hessian scaling.
        diagonal: Per-variable initial Hessian scaling (B_0 = diag(d)).
            During normal operation this equals ``gamma * ones(n)``.
            After an SNOPT-style reset it stores per-variable curvature.
        count: Number of valid pairs stored (0 to memory size).
        next_idx: Next write position in the circular buffer.
        eig_lower: Estimate of lambda_min(H) from VARCHEN Theorem 2.
        eig_upper: Estimate of lambda_max(H) from VARCHEN Theorem 2.
    """

    s_history: Float[Array, "memory n"]
    y_history: Float[Array, "memory n"]
    gamma: Scalar
    diagonal: Float[Array, " n"]
    count: Int[Array, ""]
    next_idx: Int[Array, ""]
    eig_lower: Scalar
    eig_upper: Scalar


def lbfgs_init(n: int, memory: int) -> LBFGSHistory:
    """Initialize an empty L-BFGS history buffer.

    Args:
        n: Dimension of the parameter space.
        memory: Maximum number of (s, y) pairs to store (typically 5-20).

    Returns:
        An initialized LBFGSHistory with no stored pairs and gamma=1.
    """
    return LBFGSHistory(  # ty: ignore[invalid-return-type]  # equinox @dataclass_transform
        s_history=jnp.zeros((memory, n)),
        y_history=jnp.zeros((memory, n)),
        gamma=jnp.array(1.0),
        diagonal=jnp.ones(n),
        count=jnp.array(0),
        next_idx=jnp.array(0),
        eig_lower=jnp.array(1.0),
        eig_upper=jnp.array(1.0),
    )
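
The circular-buffer bookkeeping behind `count` and `next_idx` can be sketched in plain Python. This is only an illustration under stated assumptions: `push` and `chronological` are hypothetical helpers, not functions of this module, and lists of labels stand in for the `(memory, n)` arrays. The index arithmetic is the same `start = (next_idx - count + k) % k` reordering that the Hessian-vector product routines use.

```python
# Plain-Python model of the circular-buffer bookkeeping (no JAX needed).
# `push` and `chronological` are illustrative helpers, not library functions.


def push(buffer, count, next_idx, item):
    """Write `item` at `next_idx`, overwriting the oldest entry when full."""
    k = len(buffer)
    buffer = list(buffer)
    buffer[next_idx] = item
    return buffer, min(count + 1, k), (next_idx + 1) % k


def chronological(buffer, count, next_idx):
    """Return the valid entries ordered oldest to newest."""
    k = len(buffer)
    start = (next_idx - count + k) % k
    indices = [(start + j) % k for j in range(k)]
    return [buffer[i] for i in indices[:count]]


# Push five pairs into a buffer with memory 3: the two oldest are overwritten.
buffer, count, next_idx = [None] * 3, 0, 0
for item in ["p0", "p1", "p2", "p3", "p4"]:
    buffer, count, next_idx = push(buffer, count, next_idx, item)

ordered = chronological(buffer, count, next_idx)
```

After the loop the raw storage order is `p3, p4, p2`, while `ordered` lists the surviving pairs oldest first.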


@jaxtyped(typechecker=beartype)
def lbfgs_hvp(
    history: LBFGSHistory,
    v: Vector,
) -> Vector:
    """Compute B @ v using the L-BFGS compact representation.

    Uses the compact form with diagonal initial Hessian B_0 = diag(d)
    (Byrd, Nocedal & Schnabel, 1994, Theorem 2.2):

        B = B_0 - [B_0 S, Y] @ M^{-1} @ [S^T B_0; Y^T]

    where:
        M = [[S^T B_0 S, L], [L^T, -D_sy]]   (2k x 2k)
        L_{ij} = s_i^T y_j for i > j          (strictly lower triangular)
        D_sy = diag(s_i^T y_i)                (diagonal)

    When no pairs are stored (count=0), this reduces to B = diag(d).

    Complexity: O(k^2 n) where k is the number of stored pairs.

    Args:
        history: L-BFGS history buffer.
        v: Vector to multiply by the Hessian approximation.

    Returns:
        B @ v, the Hessian-vector product.
    """
    k = history.s_history.shape[0]
    d = history.diagonal
    count = history.count

    # Reorder to chronological order from the circular buffer
    start = (history.next_idx - count + k) % k
    indices = (start + jnp.arange(k)) % k
    S = history.s_history[indices]  # (k, n)
    Y = history.y_history[indices]  # (k, n)

    # Zero out invalid entries (positions >= count are not valid)
    valid_mask = (jnp.arange(k) < count)[:, None]  # (k, 1)
    S = S * valid_mask
    Y = Y * valid_mask

    # Compute the diagonal part of the Hessian approximation
    # DS[i, :] = d * s_i  (B_0 applied row-wise)
    DS = S * d[None, :]  # (k, n)

    # Build compact form inner matrices
    SY = S @ Y.T  # (k, k)
    SDSS = DS @ S.T  # (k, k): S^T B_0 S

    L = jnp.tril(SY, k=-1)
    D_diag = jnp.diag(SY)

    # Invalid slots get a 1 on the diagonal so that M stays invertible.
    invalid_diag = jnp.where(jnp.arange(k) < count, 0.0, 1.0)

    top_left = SDSS + jnp.diag(invalid_diag)
    top_right = L
    bottom_left = L.T
    bottom_right = -jnp.diag(D_diag) + jnp.diag(invalid_diag)

    top = jnp.concatenate([top_left, top_right], axis=1)
    bottom = jnp.concatenate([bottom_left, bottom_right], axis=1)
    M = jnp.concatenate([top, bottom], axis=0)

    # Small regularization for numerical stability
    M = M + 1e-10 * jnp.eye(2 * k)

    # p = [S^T B_0 v; Y^T v].  With the pairs stored as rows this is
    # [S @ (d * v); Y @ v]; the first block equals DS @ v.
    dv = d * v
    p = jnp.concatenate([S @ dv, Y @ v])

    q = jnp.linalg.solve(M, p)

    # B v = B_0 v - [B_0 S, Y] @ q = d*v - DS.T @ q[:k] - Y.T @ q[k:]
    result = dv - DS.T @ q[:k] - Y.T @ q[k:]

    # Guard against NaN/inf from numerical blowup in the compact form solve.
    # The magnitude threshold scales with the diagonal so that it is never
    # triggered by legitimate Hessian eigenvalues.  B₀ = diag(d), so the
    # expected scale of ||Bv|| is O(max|d| · ||v||); using a 1000× margin
    # on top of that catches only genuine numerical failures.
    v_norm = jnp.linalg.norm(v)
    result_norm = jnp.linalg.norm(result)
    max_diag = jnp.max(jnp.abs(d))
    threshold = 1000.0 * jnp.maximum(max_diag, 1.0)
    is_bad = jnp.any(~jnp.isfinite(result)) | (
        result_norm > threshold * jnp.maximum(v_norm, 1e-10)
    )
    # On failure, fall back to the identity Hessian (B v = v).
    result = jnp.where(is_bad, v, result)

    return result
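
As a sanity check on the compact form, the single-pair case (k = 1) can be worked out by hand: the strictly lower-triangular block `L` is empty, `M` is diagonal, and the product collapses to the textbook BFGS update of `B_0`. The sketch below uses plain Python lists rather than JAX arrays, omits the `1e-10` regularization and the NaN guard, and its helper names are hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def compact_hvp_single_pair(d, s, y, v):
    """B @ v from the compact form with one stored pair (k = 1).

    With k = 1 the block L vanishes, so M = [[s^T B_0 s, 0], [0, -s^T y]]
    and applying M^{-1} is a pair of scalar divisions.
    """
    b0s = [di * si for di, si in zip(d, s)]  # B_0 s
    b0v = [di * vi for di, vi in zip(d, v)]  # B_0 v
    p = (dot(s, b0v), dot(y, v))  # [S^T B_0 v; Y^T v]
    q = (p[0] / dot(s, b0s), p[1] / -dot(s, y))  # q = M^{-1} p
    return [b0v[i] - b0s[i] * q[0] - y[i] * q[1] for i in range(len(v))]


def bfgs_hvp(d, s, y, v):
    """B @ v from the textbook BFGS update of B_0 = diag(d)."""
    b0s = [di * si for di, si in zip(d, s)]
    b0v = [di * vi for di, vi in zip(d, v)]
    a = dot(b0s, v) / dot(s, b0s)
    b = dot(y, v) / dot(s, y)
    return [b0v[i] - a * b0s[i] + b * y[i] for i in range(len(v))]


d, s, y, v = [2.0, 1.0], [1.0, 1.0], [3.0, 0.5], [0.5, -2.0]
compact = compact_hvp_single_pair(d, s, y, v)
direct = bfgs_hvp(d, s, y, v)
secant = compact_hvp_single_pair(d, s, y, s)  # secant condition: B s = y
```

Both routes give the same vector, and multiplying by `s` reproduces `y`, which is the secant condition every BFGS-type update must satisfy.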


@jaxtyped(typechecker=beartype)
def lbfgs_inverse_hvp(
    history: LBFGSHistory,
    v: Vector,
) -> Vector:
    """Compute H @ v = B^{-1} @ v via the L-BFGS two-loop recursion.

    Implements Nocedal & Wright Algorithm 7.4 with diagonal initial
    scaling H_0 = diag(1/diagonal) instead of scalar (1/gamma) I.

    Complexity: O(kn) where k is the number of stored pairs.

    Args:
        history: L-BFGS history buffer.
        v: Vector to multiply by the inverse Hessian approximation.

    Returns:
        H @ v = B^{-1} @ v, the inverse Hessian-vector product.
    """
    k = history.s_history.shape[0]
    d = history.diagonal
    count = history.count

    start = (history.next_idx - count + k) % k
    indices = (start + jnp.arange(k)) % k
    S = history.s_history[indices]  # (k, n) chronological
    Y = history.y_history[indices]  # (k, n) chronological

    valid_mask = (jnp.arange(k) < count)[:, None]
    S = S * valid_mask
    Y = Y * valid_mask

    sTy = jnp.sum(S * Y, axis=1)  # (k,)
    pair_ok = (jnp.arange(k) < count) & (sTy > 1e-12)
    rho = jnp.where(pair_ok, 1.0 / sTy, 0.0)

    # Backward loop: q = v; for i = k-1,...,0:
    #   alpha_i = rho_i s_i^T q;  q -= alpha_i y_i
    alphas_init = jnp.zeros(k)

    def backward_step(carry, idx):
        q, alphas = carry
        rev_idx = k - 1 - idx
        s_i = S[rev_idx]
        y_i = Y[rev_idx]
        rho_i = rho[rev_idx]
        is_valid = rev_idx < count
        alpha_i = jnp.where(is_valid, rho_i * jnp.dot(s_i, q), 0.0)
        q = q - alpha_i * y_i
        alphas = alphas.at[rev_idx].set(alpha_i)
        return (q, alphas), None

    (q, alphas), _ = jax.lax.scan(backward_step, (v, alphas_init), jnp.arange(k))

    # Apply initial inverse Hessian: H_0 = diag(1/d)
    d_safe = jnp.maximum(d, 1e-30)
    r = q / d_safe

    # Forward loop: for i = 0,...,k-1:
    #   beta = rho_i y_i^T r;  r += s_i (alpha_i - beta)
    def forward_step(r, idx):
        s_i = S[idx]
        y_i = Y[idx]
        rho_i = rho[idx]
        alpha_i = alphas[idx]
        is_valid = idx < count
        beta = jnp.where(is_valid, rho_i * jnp.dot(y_i, r), 0.0)
        r = r + s_i * (alpha_i - beta)
        return r, None

    r, _ = jax.lax.scan(forward_step, r, jnp.arange(k))

    v_norm = jnp.linalg.norm(v)
    r_norm = jnp.linalg.norm(r)
    # Fall back to identity if the result is non-finite or unreasonably large.
    # H₀ = diag(1/d), so expected scale of ||H v|| is O(max(1/d) · ||v||).
    max_inv_diag = 1.0 / jnp.min(d_safe)
    threshold = 1000.0 * jnp.maximum(max_inv_diag, 1.0)
    is_bad = jnp.any(~jnp.isfinite(r)) | (
        r_norm > threshold * jnp.maximum(v_norm, 1e-10)
    )
    r = jnp.where(is_bad, v, r)

    return r
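
The two-loop recursion is easier to follow without the masking and `jax.lax.scan` machinery. The following is a minimal plain-Python sketch, assuming every stored pair has positive curvature so no masking is needed; `two_loop` is a hypothetical name, and the pairs are passed oldest first, matching the chronological order used above.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def two_loop(pairs, d, v):
    """Return H @ v for pairs [(s_0, y_0), ...] ordered oldest to newest,
    with initial inverse Hessian H_0 = diag(1 / d)."""
    q = list(v)
    alphas = []
    # Backward pass: newest pair first.
    for s, y in reversed(pairs):
        alpha = dot(s, q) / dot(s, y)
        q = [qi - alpha * yi for qi, yi in zip(q, y)]
        alphas.append(alpha)
    alphas.reverse()  # realign with the order of `pairs`
    r = [qi / di for qi, di in zip(q, d)]  # apply H_0
    # Forward pass: oldest pair first.
    for (s, y), alpha in zip(pairs, alphas):
        beta = dot(y, r) / dot(s, y)
        r = [ri + si * (alpha - beta) for ri, si in zip(r, s)]
    return r


pairs = [([1.0, 0.0], [2.0, 0.5]), ([0.5, 1.0], [1.0, 3.0])]
d = [1.0, 4.0]
s_new, y_new = pairs[-1]
h_y = two_loop(pairs, d, y_new)  # inverse secant condition: H y_new = s_new
h0_only = two_loop([], d, [2.0, 4.0])  # no pairs: H = diag(1 / d)
```

Applying the result to the newest `y` returns the newest `s`, which is a quick way to test any implementation of the recursion.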


@jaxtyped(typechecker=beartype)
def lbfgs_should_skip(
    s: Vector,
    y: Vector,
    skip_threshold: float = 1e-8,
) -> Bool[Array, ""]:
    """Return the scalar ``should_skip`` decision that ``lbfgs_append`` uses.

    Exposed so callers can *observe* the skip decision without having to
    diff fields on the returned history (which becomes meaningless once
    the circular buffer saturates, since ``count`` stays at ``memory``
    from then on regardless of whether a real append happened).
    """
    s_norm = jnp.linalg.norm(s)
    y_norm = jnp.linalg.norm(y)

    step_too_small = s_norm < skip_threshold
    grad_diff_too_small = y_norm < skip_threshold
    curvature_ratio = y_norm / jnp.maximum(s_norm, 1e-30)
    curvature_too_extreme = (curvature_ratio > 1e8) | (curvature_ratio < 1e-8)

    sTy_raw = jnp.dot(s, y)
    relative_curvature = jnp.abs(sTy_raw) / jnp.maximum(s_norm * y_norm, 1e-30)
    # Floor lowered from ``1e-8`` to ``1e-12`` (machine-precision floor
    # for ``float64``).  After an identity reset on a near-KKT iterate
    # ``||s|| ~ rtol`` and ``||y|| = ||B s|| = ||s||`` (since ``B = I``),
    # so ``s.y / (||s|| ||y||)`` is dominated by floating-point
    # cancellation in the dot product and can dip below ``1e-8`` purely
    # from rounding even when the pair carries genuine curvature
    # information.  ``1e-12`` is the smallest threshold where the pair
    # is truly *noise* rather than under-resolved curvature.
    curvature_too_small = (relative_curvature < 1e-12) & (sTy_raw > 0)

    has_bad_values = ~(
        jnp.isfinite(s_norm) & jnp.isfinite(y_norm) & jnp.isfinite(sTy_raw)
    )

    return (
        step_too_small
        | grad_diff_too_small
        | curvature_too_extreme
        | curvature_too_small
        | has_bad_values
    )


@jaxtyped(typechecker=beartype)
def lbfgs_curvature_diagnostics(
    s: Vector,
    y: Vector,
    skip_threshold: float = 1e-8,
) -> tuple[Scalar, Scalar, Bool[Array, ""]]:
    """Return ``(s.y, relative_curvature, skipped)`` for the curvature pair.

    Exposes the same intermediate quantities ``lbfgs_should_skip`` uses
    so the verbose callback can surface them without re-implementing
    the skip predicate.  ``relative_curvature = |s.y| / (||s|| * ||y||)``
    is the absolute cosine of the angle between ``s`` and ``y`` and the
    dominant contributor to the skip decision (a near-zero value with
    ``s.y > 0`` indicates the curvature pair is at the floating-point
    noise floor and gets skipped by the ``relative_curvature < 1e-12``
    branch of ``lbfgs_should_skip``).

    Args:
        s: Step vector ``x_{k+1} - x_k``.
        y: Gradient difference ``\u2207L_{k+1} - \u2207L_k`` (Lagrangian
            gradient difference for SLSQP).
        skip_threshold: Threshold passed to ``lbfgs_should_skip``, where
            it bounds ``||s||`` and ``||y||`` from below.

    Returns:
        ``(sty, relative_curvature, skipped)``, where ``sty = s.y`` is
        the raw curvature inner product, ``relative_curvature`` is the
        normalised version used in the skip predicate, and ``skipped``
        is the same boolean ``lbfgs_should_skip`` returns.
    """
    s_norm = jnp.linalg.norm(s)
    y_norm = jnp.linalg.norm(y)
    sty = jnp.dot(s, y)
    relative_curvature = jnp.abs(sty) / jnp.maximum(s_norm * y_norm, 1e-30)
    skipped = lbfgs_should_skip(s, y, skip_threshold=skip_threshold)
    return sty, relative_curvature, skipped


@jaxtyped(typechecker=beartype)
def lbfgs_append(
    history: LBFGSHistory,
    s: Vector,
    y: Vector,
    damping_threshold: float = 0.2,
    skip_threshold: float = 1e-8,
    diag_floor: float = 1e-4,
    diag_ceil: float = 1e6,
) -> LBFGSHistory:
    """Append a new (s, y) pair to the L-BFGS history with Powell damping.

    Uses VARCHEN-style damping toward B0 (Lotfi et al., 2020, eq 7-8):
    damping is computed against the diagonal initial Hessian
    ``B_0 = diag(diagonal)`` instead of the full L-BFGS approximation B.
    This is O(n) instead of O(k^2 n), always well-conditioned (the
    diagonal is clipped to ``[diag_floor, diag_ceil]``), and avoids the
    circular dependency where a badly-conditioned B poisons its own
    damping.

    The damped gradient difference is:

        y_damped = theta * y + (1 - theta) * B0 s

    where theta in [0, 1] is chosen to satisfy:

        s^T y_damped >= threshold * s^T B0 s

    If ``||s||`` or ``||y||`` is too small or the curvature ratio is too
    extreme, the update is skipped entirely to avoid numerical issues.
    The ``curvature_ratio`` bracket is ``[1e-8, 1e8]`` (more permissive
    than the previous ``[1e-6, 1e6]``) so we no longer drop
    noisy-but-informative pairs on high-dimensional problems; these
    pairs instead flow through Powell damping.  The
    ``relative_curvature < 1e-12`` skip is gated on ``sᵀy > 0``
    (positive curvature) so negative-curvature pairs are handled by
    damping rather than dropped.

    After appending, the scalar ``gamma`` is updated to
    ``y_damped^T y_damped / (y_damped^T s)`` and clipped to
    ``[diag_floor, diag_ceil]``.  The per-variable diagonal is updated
    using a component-wise secant:
    ``d[i] = |y_damped[i] * s[i]| / max(s[i]^2, eps)`` and clipped to
    ``[max(1e-2 * gamma, diag_floor), min(1e2 * gamma, diag_ceil)]``.
    This gives ``d[i] ≈ |H_{ii}|`` for diagonal Hessians regardless of
    step direction, unlike the classical Shanno-Phua formula
    ``d[i] = y_i^2 / (y^T s)`` which produces ``d ∝ h_i^2`` for
    multi-component steps.

    The old ``1e-6`` floor on ``clip_lo`` allowed the inverse-Hessian
    diagonal ``H_0 = 1/d`` to balloon to ``1e6``, producing very large
    CG directions that the line search could not backtrack.  Raising
    the floor to ``1e-4`` keeps the inverse diagonal bounded by ``1e4``
    and dramatically reduces the "non-descent direction + chronic
    line-search backtracking" failure mode.

    Args:
        history: Current L-BFGS history.
        s: Step vector s = x_{k+1} - x_k.
        y: Gradient difference y = nabla L_{k+1} - nabla L_k.
        damping_threshold: Powell damping threshold (default 0.2).
        skip_threshold: Minimum step norm for update (default 1e-8).
        diag_floor: Minimum per-variable diagonal entry (default 1e-4).
        diag_ceil: Maximum per-variable diagonal entry (default 1e6).

    Returns:
        Updated L-BFGS history with the new pair appended.
    """
    should_skip = lbfgs_should_skip(s, y, skip_threshold=skip_threshold)

    def do_append():
        # VARCHEN-style damping toward B0 = diag(diagonal).
        # O(n) and always well-conditioned, unlike the full lbfgs_hvp.
        B0s = history.diagonal * s
        sTB0s = jnp.dot(s, B0s)
        sTy = jnp.dot(s, y)

        sTB0s_safe = jnp.maximum(sTB0s, 1e-12)

        use_damping = sTy < damping_threshold * sTB0s_safe
        theta = jax.lax.cond(
            use_damping,
            lambda: (1.0 - damping_threshold)
            * sTB0s_safe
            / (sTB0s_safe - sTy + 1e-12),
            lambda: jnp.array(1.0),
        )
        theta = jnp.clip(theta, 0.0, 1.0)

        y_damped = theta * y + (1.0 - theta) * B0s

        # Update gamma (scalar summary of average Hessian eigenvalue).
        # γ = yᵀy / sᵀy  (Byrd, Nocedal & Schnabel 1994, eq 3.6).
        yTy = jnp.dot(y_damped, y_damped)
        yTs = jnp.dot(y_damped, s)
        gamma_candidate = yTy / jnp.maximum(yTs, 1e-12)
        gamma_new = jax.lax.cond(
            (yTs > 1e-12) & jnp.isfinite(gamma_candidate),
            lambda: jnp.clip(gamma_candidate, diag_floor, diag_ceil),
            lambda: history.gamma,
        )

        # Per-variable diagonal update (component-wise secant).
        # d[i] = |y_i * s_i| / max(s_i^2, eps) ≈ |H_{ii}| for diagonal
        # Hessians, regardless of the step direction.  The classical
        # Shanno-Phua formula d[i] = y_i^2 / (y^T s) uses a single
        # scalar normalizer, which produces d ∝ h_i^2 for
        # multi-component steps on diagonal problems, severely
        # under-estimating curvature and causing 10-100x slowdowns.
        #
        # Clipping is relative to gamma (the scalar average curvature)
        # rather than absolute.  On non-diagonal Hessians (e.g.
        # Rosenbrock) the component-wise estimates are noisy: some
        # variables see near-cancellation of off-diagonal terms while
        # others see reinforcement, producing condition numbers of
        # 1e7-1e8.  Absolute clipping to [1e-2, 1e6] preserves this
        # extreme spread, triggering the VARCHEN kappa > 1e6 soft
        # reset every step and limiting L-BFGS to 1 pair.
        # Gamma-relative clipping keeps kappa(B0) <= 1e4, preventing
        # spurious resets while preserving genuine per-variable
        # curvature differences (critical for ill-conditioned
        # diagonal problems like WeightedQuad).
        s_sq = s**2
        per_var_estimate = jnp.abs(y_damped * s) / jnp.maximum(s_sq, 1e-12)
        clip_lo = jnp.maximum(gamma_new * 1e-2, diag_floor)
        clip_hi = jnp.minimum(gamma_new * 1e2, diag_ceil)
        per_var_clipped = jnp.clip(per_var_estimate, clip_lo, clip_hi)
        # Components with (numerically) zero step keep their old entry.
        has_signal = s_sq > 1e-20
        new_diagonal = jnp.where(
            has_signal & jnp.isfinite(per_var_estimate),
            per_var_clipped,
            history.diagonal,
        )

        # Write to circular buffer at next_idx
        k = history.s_history.shape[0]
        idx = history.next_idx

        new_s_history = history.s_history.at[idx].set(s)
        new_y_history = history.y_history.at[idx].set(y_damped)
        new_count = jnp.minimum(history.count + 1, jnp.array(k))
        new_idx = (idx + 1) % k

        # Build temporary history for condition estimation, then construct
        # the final one with updated eigenvalue bounds.
        tmp = LBFGSHistory(
            s_history=new_s_history,
            y_history=new_y_history,
            gamma=gamma_new,
            diagonal=new_diagonal,
            count=new_count,
            next_idx=new_idx,
            eig_lower=history.eig_lower,
            eig_upper=history.eig_upper,
        )
        eig_lo, eig_hi = lbfgs_estimate_condition(tmp)  # ty: ignore[invalid-argument-type]  # equinox @dataclass_transform

        return LBFGSHistory(
            s_history=new_s_history,
            y_history=new_y_history,
            gamma=gamma_new,
            diagonal=new_diagonal,
            count=new_count,
            next_idx=new_idx,
            eig_lower=eig_lo,
            eig_upper=eig_hi,
        )

    def skip():
        return history

    return jax.lax.cond(~should_skip, do_append, skip)
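
Two pieces of arithmetic in `lbfgs_append` are worth checking on small numbers: the Powell damping toward `B_0`, and the difference between the component-wise secant and the Shanno-Phua diagonal. The sketch below is plain Python under simplifying assumptions: it drops the `1e-12` safeguards and all clipping, and `damp_toward_b0` is a hypothetical helper, not part of this module.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def damp_toward_b0(s, y, d, threshold=0.2):
    """Return (theta, y_damped) for Powell damping toward B_0 = diag(d)."""
    b0s = [di * si for di, si in zip(d, s)]
    s_b0_s = dot(s, b0s)
    s_y = dot(s, y)
    if s_y < threshold * s_b0_s:
        theta = (1.0 - threshold) * s_b0_s / (s_b0_s - s_y)
    else:
        theta = 1.0  # enough curvature already: keep y unchanged
    y_damped = [theta * yi + (1.0 - theta) * bi for yi, bi in zip(y, b0s)]
    return theta, y_damped


# Negative curvature: s^T y = -1 while s^T B_0 s = 1.
s, y, d = [1.0, 0.0], [-1.0, 0.5], [1.0, 1.0]
theta, y_damped = damp_toward_b0(s, y, d)
curvature_after = dot(s, y_damped)  # lands on threshold * s^T B_0 s

# Diagonal Hessian H = diag(1, 100) probed with the step s = (1, 1).
h = [1.0, 100.0]
step = [1.0, 1.0]
grad_diff = [hi * si for hi, si in zip(h, step)]  # y = H s
componentwise = [abs(yi * si) / (si * si) for yi, si in zip(grad_diff, step)]
shanno_phua = [yi * yi / dot(grad_diff, step) for yi in grad_diff]
```

Here `theta` comes out as 0.4 and the damped pair sits exactly on the curvature bound `0.2 * s^T B_0 s`. The component-wise secant recovers the true diagonal `(1, 100)`, whereas the Shanno-Phua entries differ by a factor of `1e4`, the square of the true curvature ratio.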


@jaxtyped(typechecker=beartype)
def lbfgs_estimate_condition(
    history: LBFGSHistory,
    damping_threshold: float = 0.2,
) -> tuple[Scalar, Scalar]:
    """Estimate eigenvalue bounds of the inverse Hessian approximation H.

    Uses the diagonal ``B_0 = diag(d)`` as a proxy for the full L-BFGS
    condition number.  The inverse Hessian eigenvalues are bounded by
    ``[1/max(d), 1/min(d)]``, giving ``kappa(H) = max(d) / min(d)``.

    This is a practical simplification of the VARCHEN Theorem 2 bounds
    (Lotfi et al., 2020).  The full recursive bounds are theoretically
    tight but overly pessimistic in practice because they propagate
    worst-case Lipschitz estimates, leading to false soft-reset
    triggers on moderately conditioned problems.  Since we damp toward
    B0 (not the full B), the diagonal condition number is the most
    relevant quantity.

    Returns:
        (lambda_min_est, lambda_max_est) bounding the eigenvalues of H.
    """
    d = history.diagonal
    d_safe = jnp.maximum(d, 1e-30)
    lam_min = jnp.min(1.0 / d_safe)
    lam_max = jnp.max(1.0 / d_safe)
    return lam_min, lam_max


@jaxtyped(typechecker=beartype)
def lbfgs_soft_reset(
    history: LBFGSHistory,
) -> LBFGSHistory:
    """VARCHEN-style soft reset: keep only the most recent (s, y) pair.

    When the estimated condition number of the inverse Hessian exceeds
    a threshold, this drops all but the newest curvature pair.  This is
    less aggressive than :func:`lbfgs_reset` (which extracts the
    diagonal and drops everything) or :func:`lbfgs_identity_reset`
    (which restores ``B = I``), preserving the most relevant curvature
    information.

    Based on VARCHEN Algorithm 1, Step 7 (Lotfi et al., 2020).
    """
    k, n = history.s_history.shape
    newest_idx = (history.next_idx - 1 + k) % k

    newest_s = history.s_history[newest_idx]
    newest_y = history.y_history[newest_idx]

    new_s_history = jnp.zeros((k, n)).at[0].set(newest_s)
    new_y_history = jnp.zeros((k, n)).at[0].set(newest_y)

    has_pairs = history.count > 0
    new_count = jnp.where(has_pairs, jnp.array(1), jnp.array(0))

    d_safe = jnp.maximum(history.diagonal, 1e-30)
    eig_lo = jnp.min(1.0 / d_safe)
    eig_hi = jnp.max(1.0 / d_safe)

    return LBFGSHistory(  # ty: ignore[invalid-return-type]  # equinox @dataclass_transform
        s_history=new_s_history,
        y_history=new_y_history,
        gamma=history.gamma,
        diagonal=history.diagonal,
        count=new_count,
        next_idx=jnp.where(has_pairs, jnp.array(1), jnp.array(0)),
        eig_lower=eig_lo,
        eig_upper=eig_hi,
    )


@jaxtyped(typechecker=beartype)
def lbfgs_compute_diagonal(
    history: LBFGSHistory,
) -> Float[Array, " n"]:
    """Extract diag(B_k) from the L-BFGS compact representation.

    From the compact form ``B = B_0 - W M^{-1} W^T``, the diagonal is

        diag(B_k) = diagonal - diag(W M^{-1} W^T)

    where ``W = [B_0 S, Y]`` is ``(n, 2k)`` and ``M`` is the ``2k x 2k``
    inner matrix.  The correction term is computed in O(k^2 n) by
    forming ``Q = W M^{-1}`` and summing ``(Q * W)`` row-wise.

    This is used by :func:`lbfgs_reset` to implement the SNOPT diagonal
    reset strategy (Gill, Murray & Saunders, 2005, Section 3.3).
    """
    k = history.s_history.shape[0]
    d = history.diagonal
    count = history.count

    start = (history.next_idx - count + k) % k
    indices = (start + jnp.arange(k)) % k
    S = history.s_history[indices]
    Y = history.y_history[indices]

    valid_mask = (jnp.arange(k) < count)[:, None]
    S = S * valid_mask
    Y = Y * valid_mask

    DS = S * d[None, :]  # (k, n): B_0 applied row-wise

    SY = S @ Y.T
    SDSS = DS @ S.T

    L_mat = jnp.tril(SY, k=-1)
    D_diag = jnp.diag(SY)

    invalid_diag = jnp.where(jnp.arange(k) < count, 0.0, 1.0)

    top_left = SDSS + jnp.diag(invalid_diag)
    top_right = L_mat
    bottom_left = L_mat.T
    bottom_right = -jnp.diag(D_diag) + jnp.diag(invalid_diag)

    top = jnp.concatenate([top_left, top_right], axis=1)
    bottom = jnp.concatenate([bottom_left, bottom_right], axis=1)
    M = jnp.concatenate([top, bottom], axis=0)
    M = M + 1e-10 * jnp.eye(2 * k)

    M_inv = jnp.linalg.inv(M)  # (2k, 2k) — tiny

    # W is (n, 2k): columns are [d*s_0, ..., d*s_{k-1}, y_0, ..., y_{k-1}]
    W = jnp.concatenate([DS.T, Y.T], axis=1)  # (n, 2k)

    # Q = W M^{-1}, shape (n, 2k)
    Q = W @ M_inv

    # diag(W M^{-1} W^T) = row-wise sum of Q * W
    diag_correction = jnp.sum(Q * W, axis=1)  # (n,)

    return d - diag_correction


@jaxtyped(typechecker=beartype)
def lbfgs_reset(
    history: LBFGSHistory,
    diag_floor: float = 1e-4,
    diag_ceil: float = 1e6,
) -> LBFGSHistory:
    """SNOPT-style diagonal reset of the L-BFGS history.

    Extracts ``diag(B_k)`` from the current approximation, discards all
    stored ``(s, y)`` pairs, and restarts with ``B_0 = diag(diag(B_k))``.
    This preserves per-variable curvature information across the reset,
    preventing the "everything is flat" effect that occurs when the
    scalar ``gamma`` becomes very small.

    The extracted diagonal is clipped to ``[diag_floor, diag_ceil]`` to
    keep the reset point aligned with the clipping used inside
    :func:`lbfgs_append`; previous versions used a hardcoded
    ``[1e-2, 1e6]`` that diverged from the append-path floor.

    Based on the SNOPT limited-memory reset strategy (Gill, Murray &
    Saunders, *SIAM Review*, 47(1), 2005, Section 3.3).
    """
    diag_B = lbfgs_compute_diagonal(history)

    diag_clipped = jnp.clip(diag_B, diag_floor, diag_ceil)

    # Ensure all values are finite; fall back to 1.0 otherwise
    diag_safe = jnp.where(jnp.isfinite(diag_clipped), diag_clipped, 1.0)

    gamma_new = jnp.median(diag_safe)

    d_safe = jnp.maximum(diag_safe, 1e-30)
    eig_lo = jnp.min(1.0 / d_safe)
    eig_hi = jnp.max(1.0 / d_safe)

    k, n = history.s_history.shape
    return LBFGSHistory(  # ty: ignore[invalid-return-type]  # equinox @dataclass_transform
        s_history=jnp.zeros((k, n)),
        y_history=jnp.zeros((k, n)),
        gamma=gamma_new,
        diagonal=diag_safe,
        count=jnp.array(0),
        next_idx=jnp.array(0),
        eig_lower=eig_lo,
        eig_upper=eig_hi,
    )


@jaxtyped(typechecker=beartype)
def lbfgs_identity_reset(
    history: LBFGSHistory,
) -> LBFGSHistory:
    """Hard reset of L-BFGS history to the identity Hessian.

    Unlike :func:`lbfgs_reset` which preserves per-variable curvature,
    this discards everything and restarts with ``B_0 = I``.  Used as an
    escalation when repeated SNOPT-style diagonal resets fail to break
    an ill-conditioning cycle (e.g. consecutive QP failures where the
    extracted diagonal perpetuates the same problematic scaling).
    """
    k, n = history.s_history.shape
    return LBFGSHistory(  # ty: ignore[invalid-return-type]  # equinox @dataclass_transform
        s_history=jnp.zeros((k, n)),
        y_history=jnp.zeros((k, n)),
        gamma=jnp.array(1.0),
        diagonal=jnp.ones(n),
        count=jnp.array(0),
        next_idx=jnp.array(0),
        eig_lower=jnp.array(1.0),
        eig_upper=jnp.array(1.0),
    )


@jaxtyped(typechecker=beartype)
def compute_partial_lagrangian_gradient(
    grad_f: Vector,
    eq_jac: Float[Array, "m_eq n"],
    multipliers_eq: Float[Array, " m_eq"],
    gen_jac: Float[Array, "m_gen n"],
    multipliers_gen: Float[Array, " m_gen"],
) -> Vector:
    """Compute the partial Lagrangian gradient (without the bound block).

    Returns ``nabla f(x) - J_eq(x)^T lambda - J_gen(x)^T mu_gen``, where
    ``J_gen`` is the Jacobian of the *general* (nonlinear) inequality
    constraints — i.e. the inequality block excluding the constant
    identity-style rows that come from box bounds.

    This is the core helper used by both ``compute_lagrangian_gradient``
    (which adds the bound contribution on top) and the post-line-search
    NLP-level bound-multiplier recovery in ``solver.py``: by reading off
    the partial Lagrangian gradient at indices that are at a bound at
    ``x_{k+1}``, the recovery picks the bound multiplier that exactly
    zeros the corresponding component of the full Lagrangian gradient.

    Args:
        grad_f: Gradient of the objective function ``nabla f(x)``.
        eq_jac: Jacobian of equality constraints (m_eq x n).
        multipliers_eq: Lagrange multipliers for equality constraints.
        gen_jac: Jacobian of *general* inequality constraints
            (m_gen x n).  Must exclude bound rows.
        multipliers_gen: Lagrange multipliers for general inequality
            constraints.

    Returns:
        Partial Lagrangian gradient with no bound contribution.
    """
    grad_L = grad_f

    if eq_jac.shape[0] > 0:
        grad_L = grad_L - eq_jac.T @ multipliers_eq

    if gen_jac.shape[0] > 0:
        grad_L = grad_L - gen_jac.T @ multipliers_gen

    return grad_L


@jaxtyped(typechecker=beartype)
def compute_lagrangian_gradient(
    grad_f: Vector,
    eq_jac: Float[Array, "m_eq n"],
    ineq_jac: Float[Array, "m_ineq n"],
    multipliers_eq: Float[Array, " m_eq"],
    multipliers_ineq: Float[Array, " m_ineq"],
) -> Vector:
    """Compute the gradient of the Lagrangian function.

    The Lagrangian is:

        L(x, lambda, mu) = f(x) - lambda^T c_eq(x) - mu^T c_ineq(x)

    Its gradient with respect to x is:

        nabla_x L = nabla f(x) - J_eq^T lambda - J_ineq^T mu

    ``ineq_jac`` and ``multipliers_ineq`` must include the bound block
    when bounds are present (the ineq layout used by ``solver.py`` is
    ``[general; lower_bound; upper_bound]``).  This applies the same
    formula as :func:`compute_partial_lagrangian_gradient`, with the
    full inequality block in place of the general block, so the
    bound-multiplier recovery in the outer solver stays consistent with
    it for the non-bound contribution.

    Args:
        grad_f: Gradient of objective function nabla f(x).
        eq_jac: Jacobian of equality constraints (m_eq x n).
        ineq_jac: Jacobian of inequality constraints (m_ineq x n),
            including bound rows.
        multipliers_eq: Lagrange multipliers for equality constraints.
        multipliers_ineq: Lagrange multipliers for inequality
            constraints, including bound multipliers.

    Returns:
        Gradient of Lagrangian nabla_x L.
    """
    grad_L = grad_f

    if eq_jac.shape[0] > 0:
        grad_L = grad_L - eq_jac.T @ multipliers_eq

    if ineq_jac.shape[0] > 0:
        grad_L = grad_L - ineq_jac.T @ multipliers_ineq

    return grad_L
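
A worked example of the stationarity formula, on a toy problem chosen for illustration only (it does not come from this module): minimize `x1^2 + x2^2` subject to `x1 + x2 - 1 = 0`. The sketch uses plain Python lists, with each Jacobian given as a list of rows, and a hypothetical helper name.

```python
def lagrangian_gradient(grad_f, eq_jac, lam, ineq_jac, mu):
    """nabla_x L = nabla f - J_eq^T lam - J_ineq^T mu (Jacobians as row lists)."""
    g = list(grad_f)
    for row, mult in list(zip(eq_jac, lam)) + list(zip(ineq_jac, mu)):
        g = [gi - mult * ji for gi, ji in zip(g, row)]
    return g


# Toy problem: minimize x1^2 + x2^2 subject to x1 + x2 - 1 = 0.
# The constrained minimizer is x* = (0.5, 0.5).
x_star = [0.5, 0.5]
grad_f = [2.0 * xi for xi in x_star]  # (1, 1)
eq_jac = [[1.0, 1.0]]  # Jacobian of c_eq(x) = x1 + x2 - 1
lam = [1.0]  # multiplier that satisfies stationarity
grad_L = lagrangian_gradient(grad_f, eq_jac, lam, [], [])

# With a wrong multiplier the gradient no longer vanishes.
grad_L_wrong = lagrangian_gradient(grad_f, eq_jac, [0.0], [], [])
```

At the solution with the correct multiplier the Lagrangian gradient is zero, which is the first-order KKT stationarity condition that the `y` vectors fed to L-BFGS are differences of.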


@jaxtyped(typechecker=beartype)
def estimate_hessian_diagonal(
    hvp_fn: Callable[[Vector], Vector],
    n: int,
    key: jax.Array,
    n_probes: int = 20,
) -> Vector:
    """Estimate the diagonal of a matrix given only its HVP.

    Uses the Bekas-Kokiopoulou-Saad (2007) stochastic estimator: for
    Rademacher random vectors z (entries i.i.d. from {-1, +1}),

        E[z * (H z)] = diag(H).

    Averaging over ``n_probes`` samples gives an unbiased estimate with
    variance O(||off-diag||^2 / n_probes).  The estimate is accurate
    when the matrix is diagonally dominant, which is common for
    Hessians of smooth functions in the natural coordinate basis.

    Args:
        hvp_fn: Function v -> H @ v (Hessian-vector product).
        n: Dimension of the matrix.
        key: JAX PRNG key for generating random probes.
        n_probes: Number of Rademacher probes (default 20).  More probes
            reduce variance at the cost of more HVP calls.  Each probe
            costs one forward-over-reverse AD pass.

    Returns:
        Estimated diagonal of H, shape (n,).
    """
    # Map Bernoulli {0, 1} samples to Rademacher {-1, +1} probes.
    # float64 requires ``jax_enable_x64``; otherwise JAX falls back to float32.
    Z = (
        2.0 * jax.random.bernoulli(key, shape=(n_probes, n)).astype(jnp.float64)
        - 1.0
    )
    W = jax.vmap(hvp_fn)(Z)
    return jnp.mean(Z * W, axis=0)
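
The Rademacher estimator can be checked on a 2 x 2 matrix without JAX. The sketch below is plain Python with the standard `random` module; `estimate_diagonal` and `make_matvec` are hypothetical helpers. For a purely diagonal matrix each probe already returns the exact diagonal, because `z_i * (h_ii * z_i) = h_ii`; off-diagonal entries add zero-mean noise that averages out as the number of probes grows.

```python
import random


def estimate_diagonal(matvec, n, n_probes, seed=0):
    """Average z * (H z) over Rademacher probes z with entries in {-1, +1}."""
    rng = random.Random(seed)
    total = [0.0] * n
    for _ in range(n_probes):
        z = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        hz = matvec(z)
        total = [t + zi * hi for t, zi, hi in zip(total, z, hz)]
    return [t / n_probes for t in total]


def make_matvec(matrix):
    """Wrap a dense matrix (list of rows) as a function v -> H @ v."""
    return lambda v: [sum(a * x for a, x in zip(row, v)) for row in matrix]


diag_only = [[2.0, 0.0], [0.0, 5.0]]
with_off_diag = [[2.0, 1.0], [1.0, 5.0]]

exact = estimate_diagonal(make_matvec(diag_only), n=2, n_probes=1)
noisy = estimate_diagonal(make_matvec(with_off_diag), n=2, n_probes=4000)
```

A single probe recovers `diag_only` exactly, while the estimate for `with_off_diag` is close to `(2, 5)` with an error that shrinks roughly like `1 / sqrt(n_probes)`.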