# Source code for slsqp_jax.slsqp.bounds

"""NLP-level box-bound machinery for the SLSQP outer loop.

These helpers cover the places where bounds enter the *outer*
SQP iteration (the reduced-space QP-level bound-fixing loop lives in
:mod:`slsqp_jax.qp.bound_fixing` instead):

* :func:`clip_to_bounds` — defensively project an iterate onto the
  feasible box (used in ``init`` and after every line search).
* :func:`compute_bound_constraint_values` — evaluate ``c(x) = x − lb``
  / ``ub − x`` for the finite bound rows; the result is appended to
  the user inequality vector.
* :func:`build_bound_jacobian` — build the constant ``[I; −I]``
  Jacobian rows for the bound constraints.  Computed once during
  ``init`` and stored on ``SLSQPState``.
* :func:`recover_bound_multipliers` — post-line-search refresh of the
  bound multipliers from the *partial* Lagrangian gradient at
  ``x_{k+1}``.  See ``AGENTS.md`` for the rationale.

All functions are pure and free of solver-class state; they take the
precomputed ``lower_indices`` / ``upper_indices`` tuples and the
``bounds`` array directly so they can be reused outside :class:`SLSQP`
(e.g. by the eventual standalone QP solver).
"""

from __future__ import annotations

import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float

from slsqp_jax.hessian import compute_partial_lagrangian_gradient
from slsqp_jax.types import Vector


def clip_to_bounds(
    y: Vector,
    bounds: Float[Array, "n 2"] | None,
) -> Vector:
    """Clamp ``y`` elementwise into the box ``[bounds[:, 0], bounds[:, 1]]``.

    Column 0 of ``bounds`` holds the lower limits and column 1 the upper
    limits.  With ``bounds=None`` there is no box, and the very same
    ``y`` object is handed back untouched.
    """
    if bounds is not None:
        lower, upper = bounds[:, 0], bounds[:, 1]
        return jnp.clip(y, lower, upper)
    return y
def compute_bound_constraint_values(
    y: Vector,
    bounds: Float[Array, "n 2"] | None,
    lower_indices: tuple[int, ...],
    upper_indices: tuple[int, ...],
) -> Float[Array, " m_bounds"]:
    """Compute bound constraint values ``c(x) >= 0`` for finite bounds.

    Lower-bound rows ``y[i] - bounds[i, 0]`` (``i`` in ``lower_indices``)
    come first, followed by upper-bound rows ``bounds[i, 1] - y[i]``
    (``i`` in ``upper_indices``).  This is the same row order used by
    :func:`build_bound_jacobian` (``+I`` rows, then ``-I`` rows).

    Args:
        y: Current iterate.
        bounds: ``(n, 2)`` array whose columns are ``[lower, upper]``,
            or ``None`` when the problem has no box.
        lower_indices: Coordinates whose lower bound is included.
        upper_indices: Coordinates whose upper bound is included.

    Returns:
        Vector of length ``len(lower_indices) + len(upper_indices)``, or
        an empty vector when ``bounds`` is ``None`` or both index tuples
        are empty.  Empty pieces carry ``y``'s dtype (matching
        :func:`recover_bound_multipliers`) so they do not promote the
        user inequality vector they are appended to.
    """
    empty = jnp.zeros((0,), dtype=y.dtype)
    if bounds is None or (len(lower_indices) == 0 and len(upper_indices) == 0):
        return empty

    # Explicit integer dtype: ``np.array(())`` on its own is float64,
    # which is not a valid fancy-index array.
    lower_idx = np.asarray(lower_indices, dtype=np.intp)
    upper_idx = np.asarray(upper_indices, dtype=np.intp)

    lower_vals = y[lower_idx] - bounds[lower_idx, 0] if lower_idx.size else empty
    upper_vals = bounds[upper_idx, 1] - y[upper_idx] if upper_idx.size else empty
    return jnp.concatenate([lower_vals, upper_vals])
def build_bound_jacobian(
    n: int,
    bounds: Float[Array, "n 2"] | None,
    lower_indices: tuple[int, ...],
    upper_indices: tuple[int, ...],
) -> Float[Array, "m_bounds n"]:
    """Return the (constant) Jacobian of the bound constraints.

    The result stacks one unit row ``e_i`` per entry of
    ``lower_indices`` (gradient of ``x_i - lb_i``), then one negated
    unit row ``-e_i`` per entry of ``upper_indices`` (gradient of
    ``ub_i - x_i``).  A ``(0, n)`` matrix is returned when ``bounds`` is
    ``None`` or neither index tuple has entries.
    """
    n_lo = len(lower_indices)
    n_hi = len(upper_indices)
    if bounds is None or n_lo + n_hi == 0:
        return jnp.zeros((0, n))

    eye = jnp.eye(n)
    rows_lo = eye[np.array(lower_indices)] if n_lo else jnp.zeros((0, n))
    rows_hi = -eye[np.array(upper_indices)] if n_hi else jnp.zeros((0, n))
    return jnp.concatenate([rows_lo, rows_hi], axis=0)
def recover_bound_multipliers(
    *,
    y_new: Vector,
    grad_new: Vector,
    eq_jac_new: Float[Array, "m_eq n"],
    ineq_jac_new: Float[Array, "m_ineq n"],
    mult_eq: Float[Array, " m_eq"],
    mult_ineq_general: Float[Array, " m_general"],
    bounds: Float[Array, "n 2"] | None,
    lower_indices: tuple[int, ...],
    upper_indices: tuple[int, ...],
    m_general: int,
) -> tuple[Float[Array, " n_lower"], Float[Array, " n_upper"]]:
    """Refresh the NLP-level bound multipliers after the line search.

    The partial Lagrangian gradient at ``x_{k+1}`` (objective gradient
    minus the equality and *general* inequality contributions, bounds
    excluded) is read off at every bound that is active at ``y_new``.
    The signs follow the rows of :func:`build_bound_jacobian`
    (``+I`` for lower, ``-I`` for upper):

    * ``mu_lower = +partial_grad_L[i]`` for active lower bounds,
    * ``mu_upper = -partial_grad_L[i]`` for active upper bounds,

    and both are clamped to ``>= 0`` for dual feasibility.  Inactive
    bounds get a zero multiplier.

    Motivation (full discussion in ``AGENTS.md``): the QP-level bound
    multipliers come from ``B d + g - Aᵀ λ`` at ``x_k`` using the L-BFGS
    HVP and the QP active set, so they carry an
    ``O(L-BFGS) + O(line-search) + O(active-set)`` error.  On bound-heavy
    problems that error can hold ``||∇L|| / |L|`` above ``rtol`` even at
    a true KKT point.  Recovering from the partial gradient zeros the
    residual exactly at the bound-active indices by construction.

    Returns:
        ``(mu_lower, mu_upper)`` with lengths ``len(lower_indices)`` and
        ``len(upper_indices)``; both empty (``y_new``'s dtype) when
        ``bounds`` is ``None`` or no bound indices are given.
    """
    dtype = y_new.dtype
    n_lower, n_upper = len(lower_indices), len(upper_indices)
    if bounds is None or not (n_lower or n_upper):
        return jnp.zeros((0,), dtype=dtype), jnp.zeros((0,), dtype=dtype)

    # Only the first ``m_general`` inequality rows (the user constraints)
    # enter the partial Lagrangian; the rows after them are the appended
    # bound rows whose multipliers are being recovered here.
    if m_general > 0:
        gen_jac_new = ineq_jac_new[:m_general]
    else:
        gen_jac_new = jnp.zeros((0, y_new.shape[0]), dtype=dtype)

    partial_grad_L = compute_partial_lagrangian_gradient(
        grad_new,
        eq_jac_new,
        mult_eq,
        gen_jac_new,
        mult_ineq_general,
    )

    # Active-bound predicate: per-variable tolerance with a 1e-12
    # absolute floor plus a relative ``eps * (1 + |y_new|)`` term, so the
    # test still fires where 1e-12 is below the local fp spacing.
    # Variables snapped onto a bound by ``clip_to_bounds`` pass exactly;
    # the relative term only matters for variables the line search left
    # within fp precision of a bound.  The masks are shared with the LS
    # multiplier recovery in :mod:`slsqp_jax.slsqp.multipliers` so both
    # use the identical predicate.  (Imported locally, after the early
    # return, as in the original implementation.)
    from slsqp_jax.slsqp.multipliers import (
        compute_at_lower_mask,
        compute_at_upper_mask,
    )

    def _clamped_multiplier(mask_fn, indices, negate):
        # One side (lower or upper) of the recovery; empty when that
        # side has no finite bounds.
        if len(indices) == 0:
            return jnp.zeros((0,), dtype=dtype)
        idx = jnp.asarray(indices, dtype=jnp.int32)
        # ``mask_fn`` yields a full-coordinate mask; gather the bound rows.
        active = mask_fn(y_new, bounds, indices)[idx]
        grad_at = partial_grad_L[idx]
        signed = -grad_at if negate else grad_at
        return jnp.maximum(jnp.where(active, signed, 0.0), 0.0)

    mu_lower_corr = _clamped_multiplier(compute_at_lower_mask, lower_indices, False)
    mu_upper_corr = _clamped_multiplier(compute_at_upper_mask, upper_indices, True)
    return mu_lower_corr, mu_upper_corr
# Names exported by ``from slsqp_jax.slsqp.bounds import *``.
__all__ = [
    "build_bound_jacobian",
    "clip_to_bounds",
    "compute_bound_constraint_values",
    "recover_bound_multipliers",
]