Source code for slsqp_jax.merit

"""L1 Merit Function and Line Search for SLSQP.

This module implements the Han-Powell L1 exact-penalty merit function
and backtracking line search used to globalize the SLSQP algorithm.

The merit function is:
    φ(x; ρ) = f(x) + ρ * (‖c_eq(x)‖_1 + ‖max(0, -c_ineq(x))‖_1)

where ρ is the penalty parameter, chosen large enough to ensure descent
(it is kept at least a safety factor times the largest absolute Lagrange
multiplier, and never decreased). Inequality constraints follow the
convention c_ineq(x) >= 0, so only negative components of c_ineq
contribute to the penalty term.
"""

from collections.abc import Callable
from typing import Any, NamedTuple, cast

import jax
import jax.numpy as jnp
import numpy as np_cpu
from beartype import beartype
from jaxtyping import Array, Bool, Float, Int, jaxtyped

from slsqp_jax.types import Scalar, Vector
from slsqp_jax.utils import to_scalar


class LineSearchResult(NamedTuple):
    """Values produced by a single run of the merit-function line search.

    Alongside the accepted step length, the objective and constraint
    values evaluated at the new point are carried back, together with a
    success flag and a count of how many function evaluations were spent.
    """

    # Step size chosen by the search.
    alpha: Scalar
    # Objective value f at the new point.
    f_val: Scalar
    # Equality-constraint values c_eq at the new point.
    eq_val: Float[Array, " m_eq"]
    # Inequality-constraint values c_ineq at the new point.
    ineq_val: Float[Array, " m_ineq"]
    # Whether the line search succeeded.
    success: Bool[Array, ""]
    # Number of function evaluations performed.
    n_evals: Int[Array, ""]
@jaxtyped(typechecker=beartype)
def compute_merit(
    f_val: Scalar,
    eq_val: Float[Array, " m_eq"],
    ineq_val: Float[Array, " m_ineq"],
    penalty: Scalar,
) -> Scalar:
    """Evaluate the Han-Powell L1 exact-penalty merit function.

    Computes

        φ(x; ρ) = f(x) + ρ * (‖c_eq(x)‖_1 + ‖max(0, -c_ineq(x))‖_1)

    from precomputed objective and constraint values.

    Args:
        f_val: Objective value f(x).
        eq_val: Equality-constraint residuals c_eq(x); any nonzero entry
            counts as a violation.
        ineq_val: Inequality-constraint values c_ineq(x), with feasibility
            meaning c_ineq(x) >= 0.
        penalty: Penalty weight ρ applied to the total violation.

    Returns:
        The merit value φ(x; ρ).
    """
    # Equalities contribute |c_eq|; inequalities (feasible when >= 0)
    # contribute only their negative part, max(0, -c_ineq).
    total_violation = (
        jnp.abs(eq_val).sum() + jnp.maximum(0.0, -ineq_val).sum()
    )
    return f_val + penalty * total_violation
@jaxtyped(typechecker=beartype)
def update_penalty_parameter(
    current_penalty: Scalar,
    multipliers_eq: Float[Array, " m_eq"],
    multipliers_ineq: Float[Array, " m_ineq"],
    margin: float = 1.1,
    min_penalty: float = 1.0,
) -> Scalar:
    """Update the penalty parameter based on Lagrange multipliers.

    The penalty should be larger than the maximum absolute multiplier to
    ensure the merit function provides a descent direction. The update is

        ρ_new = max(ρ_current, margin * max(|λ_i|, |μ_j|), min_penalty)

    so the penalty never decreases, always exceeds the largest absolute
    multiplier by the factor ``margin``, and is bounded below by
    ``min_penalty``.

    Args:
        current_penalty: Current penalty parameter.
        multipliers_eq: Lagrange multipliers for equality constraints.
        multipliers_ineq: Lagrange multipliers for inequality constraints.
        margin: Multiplicative safety factor applied to the largest
            absolute multiplier (default 1.1).
        min_penalty: Lower bound on the returned penalty (default 1.0).

    Returns:
        Updated penalty parameter.
    """
    all_multipliers = jnp.concatenate([multipliers_eq, multipliers_ineq])
    # ``initial=0.0`` keeps the reduction well-defined when there are no
    # constraints at all (empty array) and clamps the result at zero,
    # matching a running maximum that starts from 0.
    max_mult = jnp.max(jnp.abs(all_multipliers), initial=0.0)

    # Monotone update: never shrink the penalty once it has grown.
    new_penalty = jnp.maximum(current_penalty, margin * max_mult)
    return jnp.maximum(new_penalty, min_penalty)