"""SciPy compatibility layer for SLSQP-JAX.

Provides utilities to convert SciPy-style constraint specifications
(dicts, LinearConstraint, NonlinearConstraint) into the function/Jacobian/HVP
signatures expected by the SLSQP solver, and a convenience
``minimize_like_scipy`` entry point.

Non-standard ``NonlinearConstraint.hessp`` extension
----------------------------------------------------

``scipy.optimize.NonlinearConstraint`` does **not** ship a ``hessp``
attribute, does **not** accept one in ``__init__``, and SciPy's own
solvers never read one. SLSQP-JAX's compat layer nevertheless honours
a user-attached ``hessp`` attribute on a ``NonlinearConstraint``: if
the attribute is present and callable, it is used as the per-component
constraint Hessian-vector product with **precedence over** ``hess``.

This is a deliberate, unorthodox extension. It exists so that users
can avoid forming a dense ``(n, n)`` constraint Hessian (which SciPy's
``hess(x, v)`` convention forces) when all SLSQP-JAX actually needs is
the HVP stack.

Expected signature::

    hessp(x, p) -> Array of shape (m, n)

where ``x`` is the current iterate (shape ``(n,)``), ``p`` is the
direction vector (shape ``(n,)``), ``m`` is the number of components
of the constraint, and row ``i`` of the returned array equals
``(d^2 c_i / dx^2)(x) @ p``.

Usage pattern::

    nlc = NonlinearConstraint(fun, lb, ub, jac=jac_fn)
    nlc.hessp = my_hessp  # non-standard; ignored by SciPy, consumed here

Precedence rules:

1. If ``hessp`` is present and callable, it wins over ``hess``.
2. If ``hessp`` is present but not callable (e.g. a sentinel string
   like ``"2-point"``), it is ignored and ``hess`` is used if
   callable -- identical to the existing behaviour for ``hess``.
3. Validation is limited to positional-parameter arity via
   ``inspect.signature``; shape/dtype mismatches surface as JAX errors
   on first use. Callables whose signature cannot be introspected
   (e.g. some C-level builtins) are accepted silently.
"""
from __future__ import annotations
import inspect
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Optional
import jax.numpy as jnp
import numpy as np
import optimistix as optx
from jaxtyping import Array, Float
from scipy.optimize import Bounds, LinearConstraint, NonlinearConstraint
from slsqp_jax.slsqp import SLSQP
from slsqp_jax.types import (
ConstraintFn,
ConstraintHVPFn,
GradFn,
HVPFn,
JacobianFn,
)
# ---------------------------------------------------------------------------
# Caching
# ---------------------------------------------------------------------------
class _CachedEvaluator:
"""Identity-based cache that deduplicates calls sharing the same ``x``.
Within a single SLSQP iteration the solver passes the *same* Python
object ``y`` to both ``eq_constraint_fn`` and ``ineq_constraint_fn``.
When a single underlying constraint source contributes to both groups,
wrapping it in a ``_CachedEvaluator`` avoids evaluating the source
function twice.
Only one entry is stored (the most recent). Across iterations ``x``
is a new object so the cache auto-invalidates.
"""
def __init__(self, fn: Callable) -> None:
self._fn = fn
self._cache_key: int | None = None
self._cache_val: Any = None
def __call__(self, x: Any, args: Any) -> Any:
key = id(x)
if key != self._cache_key:
self._cache_val = self._fn(x, args)
self._cache_key = key
return self._cache_val
class _CachedEvaluator2:
"""Like ``_CachedEvaluator`` but keyed on two positional args (x, v)."""
def __init__(self, fn: Callable) -> None:
self._fn = fn
self._cache_key: tuple[int, int] | None = None
self._cache_val: Any = None
def __call__(self, x: Any, v: Any, args: Any) -> Any:
key = (id(x), id(v))
if key != self._cache_key:
self._cache_val = self._fn(x, v, args)
self._cache_key = key
return self._cache_val
# ---------------------------------------------------------------------------
# Non-standard `NonlinearConstraint.hessp` validation
# ---------------------------------------------------------------------------
def _validate_hessp_signature(fn: Callable) -> None:
"""Arity check for a user-attached ``NonlinearConstraint.hessp``.
``scipy.optimize.NonlinearConstraint`` does not ship a ``hessp``
attribute; this project treats one, if present and callable, as a
per-component constraint HVP with signature ``hessp(x, p) -> (m, n)``.
Only the positional-parameter arity is verified. Callables whose
signature cannot be introspected (e.g. C-level builtins, some
``functools.partial`` wrappers on exotic targets) are accepted
silently and any shape/dtype errors surface on first use.
"""
try:
sig = inspect.signature(fn)
except (TypeError, ValueError):
return
positional = [
p
for p in sig.parameters.values()
if p.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
and p.default is inspect.Parameter.empty
]
has_varargs = any(
p.kind is inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()
)
if has_varargs:
return
if len(positional) != 2:
raise TypeError(
"NonlinearConstraint.hessp must accept exactly two positional "
"arguments (x, p); got a callable with "
f"{len(positional)} required positional parameters. "
"Expected signature: hessp(x, p) -> Array of shape (m, n) whose "
"i-th row is (d^2 c_i / dx^2)(x) @ p."
)
# ---------------------------------------------------------------------------
# ParsedConstraints dataclass
# ---------------------------------------------------------------------------
# [docs]  (stray documentation-link text, likely left over from HTML extraction; not code)
@dataclass
class ParsedConstraints:
    """SLSQP-ready constraint callables produced from SciPy-style input.

    Each attribute is meant to be handed to the ``SLSQP`` constructor
    argument of the same name.  A freshly constructed instance (no
    arguments) describes a problem without constraints.
    """

    # Stacked equality / inequality constraint values, ``(x, args) -> (m,)``.
    eq_constraint_fn: ConstraintFn | None = None
    ineq_constraint_fn: ConstraintFn | None = None
    # Total number of rows produced by each of the two functions above.
    n_eq_constraints: int = 0
    n_ineq_constraints: int = 0
    # User-supplied Jacobians; ``None`` when not available for the group.
    eq_jac_fn: JacobianFn | None = None
    ineq_jac_fn: JacobianFn | None = None
    # Per-constraint Hessian-vector products; ``None`` when not available.
    eq_hvp_fn: ConstraintHVPFn | None = None
    ineq_hvp_fn: ConstraintHVPFn | None = None
# ---------------------------------------------------------------------------
# Internal helpers to decompose a single constraint source
# ---------------------------------------------------------------------------
@dataclass
class _ConstraintParts:
"""Intermediate representation of a single SciPy constraint source.
``eq_fns`` and ``ineq_fns`` are lists of callables
``(x, args) -> Float[Array, " m_i"]``. Similarly for Jacobians
(``(x, args) -> Float[Array, "m_i n"]``) and HVPs
(``(x, v, args) -> Float[Array, "m_i n"]``).
"""
eq_fns: list[ConstraintFn] = field(default_factory=list)
ineq_fns: list[ConstraintFn] = field(default_factory=list)
n_eq: int = 0
n_ineq: int = 0
eq_jac_fns: list[JacobianFn | None] = field(default_factory=list)
ineq_jac_fns: list[JacobianFn | None] = field(default_factory=list)
eq_hvp_fns: list[ConstraintHVPFn | None] = field(default_factory=list)
ineq_hvp_fns: list[ConstraintHVPFn | None] = field(default_factory=list)
def _parse_dict_constraint(con: dict, x0: Array) -> _ConstraintParts:
    """Parse a single SciPy dict constraint.

    Parameters
    ----------
    con
        Mapping with the keys ``"type"`` (``"eq"`` or ``"ineq"``, matched
        case-insensitively as SciPy's SLSQP front end does) and ``"fun"``,
        plus optional ``"jac"`` and ``"args"``.  ``fun`` and ``jac`` are
        called as ``fun(x, *args)`` / ``jac(x, *args)``.
    x0
        Initial point; ``fun`` is evaluated here once to determine how many
        rows the constraint contributes.

    Returns
    -------
    _ConstraintParts
        Parts with exactly one entry in either the equality or the
        inequality lists.  The HVP entry is always ``None`` because dict
        constraints carry no second-order information.

    Raises
    ------
    KeyError
        If ``"type"`` or ``"fun"`` is missing.
    ValueError
        If the type is neither ``"eq"`` nor ``"ineq"``.
    """
    raw_type = con["type"]
    raw_fun = con["fun"]
    raw_jac = con.get("jac", None)
    extra_args = con.get("args", ())

    # Validate before touching ``fun`` so a typo fails fast, and normalise
    # case so that e.g. ``"EQ"`` is accepted as it is by SciPy.
    ctype = raw_type.lower() if isinstance(raw_type, str) else raw_type
    if ctype not in ("eq", "ineq"):
        raise ValueError(
            f"Unknown constraint type '{raw_type}'; expected 'eq' or 'ineq'"
        )

    def wrapped_fn(x: Any, args: Any) -> Float[Array, " m"]:
        # Solver-level ``args`` are ignored; the dict carries its own.
        val = raw_fun(x, *extra_args)
        return jnp.atleast_1d(jnp.asarray(val))

    jac_fn: JacobianFn | None = None
    if callable(raw_jac):

        def jac_fn(x: Any, args: Any) -> Float[Array, "m n"]:
            val = raw_jac(x, *extra_args)
            # A scalar constraint may return a 1-D gradient; promote to (1, n).
            return jnp.atleast_2d(jnp.asarray(val))

    # One evaluation at x0 fixes the static number of rows.
    size = int(wrapped_fn(x0, None).shape[0])

    parts = _ConstraintParts()
    if ctype == "eq":
        parts.eq_fns.append(wrapped_fn)
        parts.n_eq = size
        parts.eq_jac_fns.append(jac_fn)
        parts.eq_hvp_fns.append(None)
    else:
        parts.ineq_fns.append(wrapped_fn)
        parts.n_ineq = size
        parts.ineq_jac_fns.append(jac_fn)
        parts.ineq_hvp_fns.append(None)
    return parts
def _parse_linear_constraint(con: LinearConstraint) -> _ConstraintParts:
    """Parse a ``scipy.optimize.LinearConstraint`` (``lb <= A @ x <= ub``).

    Rows are classified at Python time from the bounds:

    * ``lb[i] == ub[i]`` -> equality ``A[i] @ x - lb[i] = 0``;
    * otherwise a finite ``lb[i]`` gives ``A[i] @ x - lb[i] >= 0`` and a
      finite ``ub[i]`` gives ``ub[i] - A[i] @ x >= 0``.  All lower-bound
      rows precede all upper-bound rows in the inequality group.

    The classification is done on the float64 NumPy bounds *before* they
    are converted to JAX arrays.  Converting first would compare values
    that JAX may have down-cast to float32 (its default precision), so two
    distinct bounds could collapse into a spurious equality.

    Jacobians are the constant row blocks of ``A`` and every HVP is an
    all-zero array, since the constraint is linear.

    Parameters
    ----------
    con
        The linear constraint; ``A`` is promoted to 2-D and ``lb`` / ``ub``
        are broadcast to one value per row of ``A``.

    Returns
    -------
    _ConstraintParts
        Up to one equality entry and one inequality entry.
    """
    A_np = np.atleast_2d(np.asarray(con.A, dtype=float))
    m = A_np.shape[0]
    lb_np = np.broadcast_to(np.asarray(con.lb, dtype=float), (m,))
    ub_np = np.broadcast_to(np.asarray(con.ub, dtype=float), (m,))

    # Decide eq / ineq membership on the exact host-side values.
    eq_mask = lb_np == ub_np
    has_lower = np.isfinite(lb_np) & ~eq_mask
    has_upper = np.isfinite(ub_np) & ~eq_mask

    A = jnp.asarray(A_np)
    lb = jnp.asarray(lb_np)
    ub = jnp.asarray(ub_np)

    parts = _ConstraintParts()

    # Equality parts (lb == ub)
    eq_indices = np.where(eq_mask)[0]
    if len(eq_indices) > 0:
        A_eq = A[eq_indices]
        lb_eq = lb[eq_indices]
        n_eq = len(eq_indices)

        def eq_fn(x: Any, args: Any) -> Float[Array, " k"]:
            return A_eq @ x - lb_eq

        def eq_jac(x: Any, args: Any) -> Float[Array, "k n"]:
            return A_eq

        def eq_hvp(x: Any, v: Any, args: Any) -> Float[Array, "k n"]:
            # Linear constraints have identically zero Hessians.
            return jnp.zeros((n_eq, x.shape[0]))

        parts.eq_fns.append(eq_fn)
        parts.n_eq = n_eq
        parts.eq_jac_fns.append(eq_jac)
        parts.eq_hvp_fns.append(eq_hvp)

    # Inequality parts: lower bounds first, then upper bounds.
    lower_indices = np.where(has_lower)[0]
    upper_indices = np.where(has_upper)[0]
    n_lower = len(lower_indices)
    n_upper = len(upper_indices)
    n_ineq = n_lower + n_upper
    if n_ineq > 0:
        # Empty placeholders keep the concatenations below shape-consistent.
        A_lower = A[lower_indices] if n_lower > 0 else jnp.zeros((0, A.shape[1]))
        lb_lower = lb[lower_indices] if n_lower > 0 else jnp.zeros((0,))
        A_upper = A[upper_indices] if n_upper > 0 else jnp.zeros((0, A.shape[1]))
        ub_upper = ub[upper_indices] if n_upper > 0 else jnp.zeros((0,))

        def ineq_fn(x: Any, args: Any) -> Float[Array, " k"]:
            lower_vals = A_lower @ x - lb_lower
            upper_vals = ub_upper - A_upper @ x
            return jnp.concatenate([lower_vals, upper_vals])

        # Upper-bound rows are negated so that every row reads ``g(x) >= 0``.
        ineq_jac_matrix = jnp.concatenate([A_lower, -A_upper], axis=0)

        def ineq_jac(x: Any, args: Any) -> Float[Array, "k n"]:
            return ineq_jac_matrix

        def ineq_hvp(x: Any, v: Any, args: Any) -> Float[Array, "k n"]:
            return jnp.zeros((n_ineq, x.shape[0]))

        parts.ineq_fns.append(ineq_fn)
        parts.n_ineq = n_ineq
        parts.ineq_jac_fns.append(ineq_jac)
        parts.ineq_hvp_fns.append(ineq_hvp)

    return parts
def _parse_nonlinear_constraint(con: NonlinearConstraint) -> _ConstraintParts:
    """Split a ``scipy.optimize.NonlinearConstraint`` into eq / ineq parts.

    Component ``i`` of ``con.fun`` is classified at parse time from its
    bounds:

    * ``lb[i] == ub[i]`` -> equality ``c_i(x) - lb[i] = 0``;
    * otherwise a finite ``lb[i]`` yields ``c_i(x) - lb[i] >= 0`` and a
      finite ``ub[i]`` yields ``ub[i] - c_i(x) >= 0``.  In the inequality
      group all lower-bound rows come first, then all upper-bound rows.

    Derivatives:

    * a callable ``con.jac`` is used as the Jacobian (otherwise ``None``);
    * a callable, user-attached ``con.hessp`` (non-standard; signature
      ``hessp(x, p) -> (m, n)``) is used as the per-component HVP and takes
      precedence over ``con.hess``;
    * failing that, a callable ``con.hess`` is probed with unit vectors,
      ``hess(x, e_i) @ v``, one component at a time.

    When the constraint feeds *both* groups, the value, Jacobian and HVP
    evaluators are wrapped in identity caches so that the equality and
    inequality callables share one evaluation per input.

    NOTE(review): the component count ``m`` is derived from the lengths of
    ``lb`` / ``ub`` only -- ``fun`` is never evaluated here.  With scalar
    bounds and a vector-valued ``fun`` only component 0 is picked up;
    confirm whether callers rely on SciPy-style bound broadcasting.
    """
    raw_fun = con.fun

    # Bounds stay in NumPy: group membership must be known at Python time.
    lower_np = np.atleast_1d(np.asarray(con.lb, dtype=float))
    upper_np = np.atleast_1d(np.asarray(con.ub, dtype=float))
    m = max(lower_np.shape[0], upper_np.shape[0])
    lower_np = np.broadcast_to(lower_np, (m,))
    upper_np = np.broadcast_to(upper_np, (m,))
    lb_jnp = jnp.asarray(lower_np)
    ub_jnp = jnp.asarray(upper_np)

    is_eq = lower_np == upper_np
    eq_indices = np.where(is_eq)[0]
    lower_indices = np.where(np.isfinite(lower_np) & ~is_eq)[0]
    upper_indices = np.where(np.isfinite(upper_np) & ~is_eq)[0]
    n_lower = len(lower_indices)
    n_upper = len(upper_indices)
    needs_eq = len(eq_indices) > 0
    needs_ineq = n_lower + n_upper > 0
    # One source feeding both groups -> share evaluations through a cache.
    shared = needs_eq and needs_ineq

    # Derivative sources.  The non-standard ``hessp`` attribute, if callable,
    # wins over ``hess``; its arity is validated up front, shape errors
    # surface on first use.
    raw_jac = getattr(con, "jac", None)
    raw_hess = getattr(con, "hess", None)
    hessp_candidate = getattr(con, "hessp", None)
    hessp_fn: Callable | None = (
        hessp_candidate if callable(hessp_candidate) else None
    )
    if hessp_fn is not None:
        _validate_hessp_signature(hessp_fn)

    # --- Full (all m components) evaluators -------------------------------
    def _eval_fun(x: Any, args: Any) -> Float[Array, " m"]:
        return jnp.atleast_1d(jnp.asarray(raw_fun(x)))

    full_fun = _CachedEvaluator(_eval_fun) if shared else _eval_fun

    full_jac = None
    if callable(raw_jac):

        def _eval_jac(x: Any, args: Any) -> Float[Array, "m n"]:
            return jnp.atleast_2d(jnp.asarray(raw_jac(x)))

        full_jac = _CachedEvaluator(_eval_jac) if shared else _eval_jac

    def _hess_rows(x: Any, v: Any, indices: Any) -> list:
        """Return ``[hess(x, e_i) @ v for i in indices]`` (unit-vector probing)."""
        rows = []
        for i in indices:
            e_i = jnp.zeros((m,)).at[i].set(1.0)
            H_i = jnp.asarray(raw_hess(x, e_i))
            rows.append(H_i @ v)
        return rows

    # ``full_hvp`` yields the whole (m, n) HVP stack.  It exists whenever
    # ``hessp`` is available (one call returns every row) or when ``hess``
    # must serve both groups (all m rows computed once, then cached).  With
    # ``hess`` and a single group, rows are probed per index instead.
    full_hvp = None
    hess_per_row = False
    if hessp_fn is not None:

        def _eval_hvp(x: Any, v: Any, args: Any) -> Float[Array, "m n"]:
            return jnp.atleast_2d(jnp.asarray(hessp_fn(x, v)))

        full_hvp = _CachedEvaluator2(_eval_hvp) if shared else _eval_hvp
    elif callable(raw_hess):
        if shared:

            def _eval_hvp(x: Any, v: Any, args: Any) -> Float[Array, "m n"]:
                return jnp.stack(_hess_rows(x, v, range(m)))

            full_hvp = _CachedEvaluator2(_eval_hvp)
        else:
            hess_per_row = True

    parts = _ConstraintParts()

    # --- Equality portion ---------------------------------------------------
    if needs_eq:
        lb_eq = lb_jnp[eq_indices]

        def eq_fn(x: Any, args: Any) -> Float[Array, " k"]:
            return full_fun(x, args)[eq_indices] - lb_eq

        eq_jac_fn: JacobianFn | None = None
        if full_jac is not None:

            def eq_jac_fn(x: Any, args: Any) -> Float[Array, "k n"]:
                return full_jac(x, args)[eq_indices]

        eq_hvp_fn: ConstraintHVPFn | None = None
        if full_hvp is not None:

            def eq_hvp_fn(x: Any, v: Any, args: Any) -> Float[Array, "k n"]:
                return full_hvp(x, v, args)[eq_indices]

        elif hess_per_row:

            def eq_hvp_fn(x: Any, v: Any, args: Any) -> Float[Array, "k n"]:
                return jnp.stack(_hess_rows(x, v, eq_indices))

        parts.eq_fns.append(eq_fn)
        parts.n_eq = len(eq_indices)
        parts.eq_jac_fns.append(eq_jac_fn)
        parts.eq_hvp_fns.append(eq_hvp_fn)

    # --- Inequality portion -------------------------------------------------
    if needs_ineq:
        lb_lower = lb_jnp[lower_indices] if n_lower > 0 else jnp.zeros((0,))
        ub_upper = ub_jnp[upper_indices] if n_upper > 0 else jnp.zeros((0,))

        def _signed_rows(full: Any, x: Any) -> Float[Array, "k n"]:
            """Stack lower-bound rows and negated upper-bound rows of ``full``."""
            lower_block = (
                full[lower_indices] if n_lower > 0 else jnp.zeros((0, x.shape[0]))
            )
            upper_block = (
                -full[upper_indices] if n_upper > 0 else jnp.zeros((0, x.shape[0]))
            )
            return jnp.concatenate([lower_block, upper_block], axis=0)

        def ineq_fn(x: Any, args: Any) -> Float[Array, " k"]:
            vals = full_fun(x, args)
            lower_part = (
                vals[lower_indices] - lb_lower if n_lower > 0 else jnp.zeros((0,))
            )
            upper_part = (
                ub_upper - vals[upper_indices] if n_upper > 0 else jnp.zeros((0,))
            )
            return jnp.concatenate([lower_part, upper_part])

        ineq_jac_fn: JacobianFn | None = None
        if full_jac is not None:

            def ineq_jac_fn(x: Any, args: Any) -> Float[Array, "k n"]:
                return _signed_rows(full_jac(x, args), x)

        ineq_hvp_fn: ConstraintHVPFn | None = None
        if full_hvp is not None:

            def ineq_hvp_fn(x: Any, v: Any, args: Any) -> Float[Array, "k n"]:
                return _signed_rows(full_hvp(x, v, args), x)

        elif hess_per_row:

            def ineq_hvp_fn(x: Any, v: Any, args: Any) -> Float[Array, "k n"]:
                # Lower-bound rows keep their sign, upper-bound rows flip.
                rows = _hess_rows(x, v, lower_indices)
                rows.extend(-row for row in _hess_rows(x, v, upper_indices))
                return jnp.stack(rows) if rows else jnp.zeros((0, x.shape[0]))

        parts.ineq_fns.append(ineq_fn)
        parts.n_ineq = n_lower + n_upper
        parts.ineq_jac_fns.append(ineq_jac_fn)
        parts.ineq_hvp_fns.append(ineq_hvp_fn)

    return parts
# ---------------------------------------------------------------------------
# Combining helpers
# ---------------------------------------------------------------------------
def _combine_fns(
fns: list[ConstraintFn],
) -> ConstraintFn:
"""Concatenate outputs of multiple constraint functions."""
if len(fns) == 1:
return fns[0]
def combined(x: Any, args: Any) -> Float[Array, " m"]:
return jnp.concatenate([f(x, args) for f in fns])
return combined
def _combine_jac_fns(
jac_fns: list[JacobianFn | None],
) -> JacobianFn | None:
"""Vertically stack Jacobians. Returns ``None`` if any entry is None."""
if any(j is None for j in jac_fns):
return None
fns = [j for j in jac_fns if j is not None] # for type narrowing
if len(fns) == 1:
return fns[0]
def combined(x: Any, args: Any) -> Float[Array, "m n"]:
return jnp.concatenate([f(x, args) for f in fns], axis=0)
return combined
def _combine_hvp_fns(
hvp_fns: list[ConstraintHVPFn | None],
) -> ConstraintHVPFn | None:
"""Vertically stack per-constraint HVP outputs. ``None`` if any is None."""
if any(h is None for h in hvp_fns):
return None
fns = [h for h in hvp_fns if h is not None]
if len(fns) == 1:
return fns[0]
def combined(x: Any, v: Any, args: Any) -> Float[Array, "m n"]:
return jnp.concatenate([f(x, v, args) for f in fns], axis=0)
return combined
# ---------------------------------------------------------------------------
# parse_constraints (public API)
# ---------------------------------------------------------------------------
# [docs]  (stray documentation-link text, likely left over from HTML extraction; not code)
def parse_constraints(
    constraints: dict | list | LinearConstraint | NonlinearConstraint | tuple,
    x0: Array,
) -> ParsedConstraints:
    """Convert SciPy-style constraints into SLSQP-JAX constraint functions.

    Parameters
    ----------
    constraints
        Any form accepted by ``scipy.optimize.minimize``:
        a dict, list of dicts, ``LinearConstraint``, ``NonlinearConstraint``,
        or a list/tuple mixing those types.  An empty tuple/list means
        "no constraints".
    x0
        Initial guess -- used to evaluate dict constraint functions once to
        determine their output size.

    Returns
    -------
    ParsedConstraints
        Dataclass whose fields map to ``SLSQP`` constructor arguments.
        Within each group (equality / inequality) the rows follow the
        order in which the sources were given.

    Raises
    ------
    TypeError
        If ``constraints`` or one of its elements has an unsupported type.
    """
    # Normalise the input to a flat list of sources.
    if isinstance(constraints, (dict, LinearConstraint, NonlinearConstraint)):
        sources: list = [constraints]
    elif isinstance(constraints, (list, tuple)):
        sources = list(constraints)
    else:
        raise TypeError(  # pragma: no cover
            f"Unsupported constraints type {type(constraints)}. "
            "Expected a dict, list, LinearConstraint, or NonlinearConstraint."
        )

    if not sources:
        return ParsedConstraints()

    # Parse every source and fold its contribution into one accumulator.
    merged = _ConstraintParts()
    for con in sources:
        if isinstance(con, dict):
            piece = _parse_dict_constraint(con, x0)
        elif isinstance(con, LinearConstraint):
            piece = _parse_linear_constraint(con)
        elif isinstance(con, NonlinearConstraint):
            piece = _parse_nonlinear_constraint(con)
        else:
            raise TypeError(
                f"Unsupported constraint object type: {type(con)}"
            )  # pragma: no cover

        merged.eq_fns.extend(piece.eq_fns)
        merged.ineq_fns.extend(piece.ineq_fns)
        merged.n_eq += piece.n_eq
        merged.n_ineq += piece.n_ineq
        merged.eq_jac_fns.extend(piece.eq_jac_fns)
        merged.ineq_jac_fns.extend(piece.ineq_jac_fns)
        merged.eq_hvp_fns.extend(piece.eq_hvp_fns)
        merged.ineq_hvp_fns.extend(piece.ineq_hvp_fns)

    result = ParsedConstraints(
        n_eq_constraints=merged.n_eq,
        n_ineq_constraints=merged.n_ineq,
    )
    # A group's callables are only populated when it has at least one source.
    if merged.eq_fns:
        result.eq_constraint_fn = _combine_fns(merged.eq_fns)
        result.eq_jac_fn = _combine_jac_fns(merged.eq_jac_fns)
        result.eq_hvp_fn = _combine_hvp_fns(merged.eq_hvp_fns)
    if merged.ineq_fns:
        result.ineq_constraint_fn = _combine_fns(merged.ineq_fns)
        result.ineq_jac_fn = _combine_jac_fns(merged.ineq_jac_fns)
        result.ineq_hvp_fn = _combine_hvp_fns(merged.ineq_hvp_fns)
    return result
# ---------------------------------------------------------------------------
# Bounds conversion
# ---------------------------------------------------------------------------
def _convert_bounds(
bounds: Bounds | list | tuple | None,
n: int,
) -> Optional[Float[Array, "n 2"]]:
"""Convert SciPy-style bounds to the ``(n, 2)`` array used by SLSQP.
Parameters
----------
bounds
``None``, a ``scipy.optimize.Bounds`` instance, or a sequence of
``(min, max)`` pairs (with ``None`` meaning unbounded).
n
Number of variables (for validation).
Returns
-------
jax array of shape ``(n, 2)`` or ``None``.
"""
if bounds is None:
return None
if isinstance(bounds, Bounds):
lb = np.asarray(bounds.lb, dtype=float)
ub = np.asarray(bounds.ub, dtype=float)
lb = np.broadcast_to(lb, (n,))
ub = np.broadcast_to(ub, (n,))
arr = np.stack([lb, ub], axis=1)
# Replace np.nan with inf (Bounds uses nan for no-bound sometimes)
arr[np.isnan(arr[:, 0]), 0] = -np.inf
arr[np.isnan(arr[:, 1]), 1] = np.inf
return jnp.asarray(arr)
# Sequence of (min, max) pairs
bounds_list = list(bounds)
if len(bounds_list) != n:
raise ValueError(
f"bounds has {len(bounds_list)} entries but x0 has {n} elements"
)
arr = np.full((n, 2), [-np.inf, np.inf])
for i, (lo, hi) in enumerate(bounds_list):
if lo is not None:
arr[i, 0] = float(lo)
if hi is not None:
arr[i, 1] = float(hi)
return jnp.asarray(arr)
# ---------------------------------------------------------------------------
# minimize_like_scipy (public API)
# ---------------------------------------------------------------------------
# [docs]  (stray documentation-link text, likely left over from HTML extraction; not code)
def minimize_like_scipy(
fun: Callable,
x0: Any,
args: tuple = (),
*,
jac: Callable | bool | None = None,
hessp: Callable | None = None,
bounds: Bounds | list | tuple | None = None,
constraints: dict | list | LinearConstraint | NonlinearConstraint | tuple = (),
options: dict[str, Any] | None = None,
has_aux: bool = False,
throw: bool = False,
verbose: bool | Callable[..., None] = False,
auto_scale: bool | str = True,
auto_scale_target_gradient: float | None = None,
auto_scale_max_factor: float | None = None,
) -> optx.Solution:
"""Minimise a function using SLSQP with a SciPy-like interface.
This is a convenience wrapper that accepts SciPy-style arguments,
converts them for the SLSQP solver, and delegates to
``optimistix.minimise``.
Parameters
----------
fun
Objective function. Signature ``(x, *args) -> scalar`` or, when
*has_aux* is ``True``, ``(x, *args) -> (scalar, aux)``.
x0
Initial guess (array-like).
args
Extra positional arguments forwarded to *fun* (unpacked).
jac
Gradient of *fun*. A callable ``(x, *args) -> array`` or ``True``
to indicate that *fun* returns ``(f, g)`` (or ``((f, g), aux)``
when *has_aux* is set).
hessp
Hessian-vector product ``(x, p, *args) -> array``.
bounds
Variable bounds -- ``None``, ``Bounds``, or sequence of
``(min, max)`` pairs.
constraints
SciPy-style constraints (dict / list-of-dicts /
``LinearConstraint`` / ``NonlinearConstraint``). A
``NonlinearConstraint`` may carry a user-attached ``hessp``
attribute (non-standard; not part of SciPy's API) that, if
callable, is used as the per-component constraint HVP with
precedence over ``hess``. See the module-level docstring for
the full contract.
options
Solver options dict. The following keys are popped with
the listed defaults (which match the ``SLSQP`` constructor
defaults):
* ``rtol`` (``1e-6``) -- relative tolerance for stationarity.
* ``atol`` (``1e-6``) -- absolute tolerance for feasibility.
* ``max_steps`` or ``maxiter`` (``100``) -- maximum outer
iterations.
* ``min_steps`` (``1``) -- minimum iterations before
convergence is allowed.
* ``lbfgs_memory`` (``10``) -- number of L-BFGS pairs.
* ``line_search_max_steps`` (``20``) -- backtracking steps.
* ``armijo_c1`` (``1e-4``) -- Armijo sufficient decrease.
* ``qp_max_iter`` (``100``) -- active-set iteration budget.
* ``qp_max_cg_iter`` (``50``) -- CG iterations per QP step.
Any remaining keys are forwarded as ``**kwargs`` to the
``SLSQP`` constructor, so any ``SLSQP`` attribute can be
set here (e.g. ``proximal_tau``, ``proximal_mu_min``,
``proximal_mu_max``, ``use_preconditioner``, ``adaptive_cg_tol``,
``cg_regularization``, ``stagnation_tol``).
has_aux
If ``True``, *fun* returns ``(value, aux)``.
throw
Whether to raise on solver failure.
verbose
Passed to the ``SLSQP`` constructor. ``False`` (default) for
silent, ``True`` to print all diagnostics, or a custom callable.
When ``auto_scale`` is on the built-in printer is wrapped to
show user-unit values for ``f`` / ``|c|`` / ``|grad_f|`` /
``|grad_L|`` / ``|d|``; merit / rho / gamma / L-BFGS internals
keep a ``(s)`` suffix on their label to flag scaled units.
See :func:`slsqp_jax.wrap_verbose_for_scaling`.
auto_scale
Automatic problem scaling at the initial point (gradient-based,
IPOPT/KNITRO-style). **On by default** as of this release.
* ``True`` (default) -> ``"uniform"`` (``target_gradient=1.0,
max_factor=1e3, uniform=True``). A single shared scalar
``s_c`` is applied to *every* constraint row (equality +
general inequality) and a separate ``s_f`` to the objective,
both **symmetrically** clipped to
``[1/max_factor, max_factor]``. Preserves inter-row
magnitude ratios (the right default for budget-style
problems where one constraint is intentionally orders of
magnitude larger than the others); fully fixes the
documented ``||J_eq|| >> ||grad_f||`` divergence cascade.
Note that ``atol_internal = s_c * atol_user`` (no
``min(., 1.0)`` cap, so the feasibility tolerance handed
to the inner solver can *exceed* ``atol_user`` when
``s_c > 1``).
* ``"balanced"`` -> ``target_gradient=1.0, max_factor=1e3,
uniform=False``. The legacy per-row default. Each
constraint row gets its own factor driving
``||grad c_i||_inf -> 1``. Flattens inter-row magnitudes;
opt-in when one row's gradient is *vastly* out of band and
that's not a meaningful spread.
* ``False`` -> no wrapping (pre-feature behaviour).
* ``"knitro"`` -> ``target=1.0, max_factor=1.0`` (strict
shrink-only per-row; opt-in for users who want zero
amplification).
* ``"ipopt"`` -> ``target=100.0, max_factor=1.0`` (very
conservative per-row; may not fix all cascades).
* ``"aggressive"`` -> ``target=1.0, max_factor=1e6``
(per-row, pushes amplification to the noise-floor limit).
When scaling is applied, ``sol.stats`` carries a
``scale_factors`` entry plus ``_user``-suffixed copies of the
multiplier vectors and the Lagrangian gradient norm. ``atol``
is auto-compensated so the user-perceived feasibility
tolerance is preserved (uniform mode does this via
``atol_internal = s_c * atol_user``; per-row modes via
``atol_internal = atol_user * min(min(s_eq), min(s_ineq),
1.0)``).
auto_scale_target_gradient
Optional explicit override of the mode's ``target_gradient``.
Under ``uniform`` mode this value is consumed by both the
``s_f`` derivation (against ``||grad_f||_inf``) and the
``s_c`` derivation (against the cross-row max
``max_i ||grad c_i||_inf``); under per-row modes it drives
every row's individual factor.
auto_scale_max_factor
Optional explicit override of the mode's ``max_factor``.
Under ``uniform`` mode the bound is **symmetric** so the
scale factor lives in ``[1/max_factor, max_factor]`` and
the value must satisfy ``max_factor >= 1.0`` (smaller
raises ``ValueError``; ``max_factor == 1.0`` is legal but
emits a ``UserWarning`` because it disables scaling).
Under per-row modes the bound is one-sided
(``s in [eps, max_factor]``); ``max_factor == 1.0`` means
shrink-only.
Returns
-------
optimistix.Solution
"""
opts = dict(options) if options is not None else {}
x0 = jnp.asarray(x0, dtype=float)
n = x0.shape[0]
# --- Parse constraints ---
parsed = parse_constraints(constraints, x0)
# --- Convert bounds ---
jax_bounds = _convert_bounds(bounds, n)
# --- Wrap objective ---
obj_grad_fn: GradFn | None = None
if jac is True:
# fun returns (f, g) or ((f, g), aux)
if has_aux: # pragma: no cover
def wrapped_fn(x: Any, packed_args: Any) -> tuple:
(f, g), aux = fun(x, *packed_args)
return jnp.asarray(f), aux
else:
def wrapped_fn(x: Any, packed_args: Any) -> tuple:
f, g = fun(x, *packed_args)
return jnp.asarray(f), None
# Extract gradient
if has_aux: # pragma: no cover
def obj_grad_fn(x: Any, packed_args: Any) -> Any:
(f, g), _aux = fun(x, *packed_args)
return jnp.asarray(g)
else:
def obj_grad_fn(x: Any, packed_args: Any) -> Any:
_f, g = fun(x, *packed_args)
return jnp.asarray(g)
else:
if has_aux:
def wrapped_fn(x: Any, packed_args: Any) -> tuple:
f, aux = fun(x, *packed_args)
return jnp.asarray(f), aux
else:
def wrapped_fn(x: Any, packed_args: Any) -> tuple:
return jnp.asarray(fun(x, *packed_args)), None
# --- Wrap jac (if callable) ---
if callable(jac):
user_jac = jac
def obj_grad_fn(x: Any, packed_args: Any) -> Any:
return jnp.asarray(user_jac(x, *packed_args))
# --- Wrap hessp ---
obj_hvp_fn: HVPFn | None = None
if callable(hessp):
user_hessp = hessp
def obj_hvp_fn(x: Any, v: Any, packed_args: Any) -> Any:
return jnp.asarray(user_hessp(x, v, *packed_args))
# --- Build solver ---
# Translate the SciPy-style flat options dictionary into the new
# nested :class:`SLSQPConfig`. This is the only documented
# migration path for downstream users that still want the
# SciPy-flavoured ``options=`` dict — direct construction of
# :class:`SLSQP` requires the nested config explicitly.
from slsqp_jax.config import (
AdaptiveCGConfig,
LBFGSConfig,
LineSearchConfig,
LPECAConfig,
PreconditionerConfig,
ProximalConfig,
QPConfig,
SLSQPConfig,
ToleranceConfig,
)
max_steps = opts.pop("max_steps", opts.pop("maxiter", 100))
def _pop_section(section_cls: Any, mapping: dict[str, str]) -> Any:
    """Build one config section from the flat SciPy-style options.

    ``mapping`` pairs a flat option name (as it appears in the enclosing
    ``opts`` dictionary) with the keyword name ``section_cls`` expects.
    Every option name found in ``opts`` is removed from ``opts`` and
    forwarded under its mapped keyword; option names that are absent are
    not forwarded at all. Consuming the keys lets the caller detect
    unrecognised options by checking what is left in ``opts``.

    Returns whatever ``section_cls(**kwargs)`` returns.
    """
    section_kwargs = {
        field_name: opts.pop(option_name)
        for option_name, field_name in mapping.items()
        if option_name in opts
    }
    return section_cls(**section_kwargs)
tolerance = _pop_section(
ToleranceConfig,
{
"rtol": "rtol",
"atol": "atol",
"min_steps": "min_steps",
"stagnation_tol": "stagnation_tol",
"divergence_factor": "divergence_factor",
"divergence_patience": "divergence_patience",
},
)
tolerance = ToleranceConfig(
rtol=tolerance.rtol,
atol=tolerance.atol,
max_steps=max_steps,
min_steps=tolerance.min_steps,
stagnation_tol=tolerance.stagnation_tol,
divergence_factor=tolerance.divergence_factor,
divergence_patience=tolerance.divergence_patience,
)
lbfgs = _pop_section(
LBFGSConfig,
{
"lbfgs_memory": "memory",
"damping_threshold": "damping_threshold",
"lbfgs_diag_floor": "diag_floor",
"lbfgs_diag_ceil": "diag_ceil",
},
)
line_search = _pop_section(
LineSearchConfig,
{
"line_search_max_steps": "max_steps",
"armijo_c1": "armijo_c1",
"ls_failure_patience": "failure_patience",
},
)
qp = _pop_section(
QPConfig,
{
"qp_max_iter": "max_iter",
"qp_max_cg_iter": "max_cg_iter",
"qp_failure_patience": "failure_patience",
"zero_step_patience": "zero_step_patience",
"qp_ping_pong_threshold": "ping_pong_threshold",
"mult_drop_floor": "mult_drop_floor",
"cg_regularization": "cg_regularization",
"use_exact_hvp_in_qp": "use_exact_hvp",
},
)
proximal = _pop_section(
ProximalConfig,
{
"proximal_tau": "tau",
"proximal_mu_min": "mu_min",
"proximal_mu_max": "mu_max",
},
)
preconditioner = _pop_section(
PreconditionerConfig,
{
"use_preconditioner": "enabled",
"preconditioner_type": "type",
"diagonal_n_probes": "diagonal_n_probes",
},
)
lpeca = _pop_section(
LPECAConfig,
{
"active_set_method": "method",
"lpeca_sigma": "sigma",
"lpeca_beta": "beta",
"lpeca_use_lp": "use_lp",
"lpeca_trust_threshold": "trust_threshold",
"lpeca_warmup_steps": "warmup_steps",
"lpeca_predict_bounds": "predict_bounds",
},
)
adaptive_cg = _pop_section(
AdaptiveCGConfig,
{
"adaptive_cg_tol": "enabled",
"use_inexact_stationarity": "use_inexact_stationarity",
},
)
config = SLSQPConfig(
tolerance=tolerance,
lbfgs=lbfgs,
line_search=line_search,
qp=qp,
proximal=proximal,
preconditioner=preconditioner,
lpeca=lpeca,
adaptive_cg=adaptive_cg,
)
inner_solver = opts.pop("inner_solver", None)
if opts:
raise TypeError(
f"minimize_like_scipy: unrecognized option(s): {sorted(opts.keys())!r}"
)
# Auto-scaling: when on (default), wrap fn / constraints / Jacobians /
# HVPs by the gradient-based scale factors evaluated at x0, and
# compensate atol so the user-perceived feasibility tolerance is
# preserved. See slsqp_jax.scaling for the math + design rationale.
from slsqp_jax.scaling import (
auto_scaled_problem,
resolve_scaling_mode,
unscale_solution,
wrap_verbose_for_scaling,
)
scaling_cfg = resolve_scaling_mode(
auto_scale,
target_gradient=auto_scale_target_gradient,
max_factor=auto_scale_max_factor,
)
if scaling_cfg is not None:
# Snapshot the tolerance fields into local Python floats so the
# type checker can track them through the rebuild below.
from typing import cast as _cast
_tol = _cast(ToleranceConfig, tolerance)
prev_rtol = float(_tol.rtol)
prev_atol = float(_tol.atol)
prev_max_steps = int(_tol.max_steps)
prev_min_steps = int(_tol.min_steps)
prev_stag_tol = float(_tol.stagnation_tol)
prev_div_factor = float(_tol.divergence_factor)
prev_div_patience = int(_tol.divergence_patience)
scaled = auto_scaled_problem(
fn=wrapped_fn,
x0=x0,
args=args,
has_aux=True,
eq_constraint_fn=parsed.eq_constraint_fn,
ineq_constraint_fn=parsed.ineq_constraint_fn,
obj_grad_fn=obj_grad_fn,
eq_jac_fn=parsed.eq_jac_fn,
ineq_jac_fn=parsed.ineq_jac_fn,
obj_hvp_fn=obj_hvp_fn,
eq_hvp_fn=parsed.eq_hvp_fn,
ineq_hvp_fn=parsed.ineq_hvp_fn,
scaling_config=scaling_cfg,
atol_user=prev_atol,
)
# Override atol to atol_internal so the inner solver's
# convergence tests match the user-perceived feasibility
# tolerance. rtol is invariant to a uniform s_f rescaling
# so we leave it untouched.
tolerance = ToleranceConfig(
rtol=prev_rtol,
atol=scaled.factors.atol_internal,
max_steps=prev_max_steps,
min_steps=prev_min_steps,
stagnation_tol=prev_stag_tol,
divergence_factor=prev_div_factor,
divergence_patience=prev_div_patience,
)
# Pin the proximal ``mu_min`` floor to the *user's* pre-scaling
# ``atol`` rather than letting ``SLSQP.__init__`` resolve it
# against the (possibly much smaller) ``atol_internal``. Under
# auto-scaling ``atol_internal = atol_user * min(s_min, 1.0)``;
# without this override the proximal ``mu`` clip would float to
# the compensated tolerance, allowing ``(1/mu) A_eq^T A_eq`` to
# blow up by ``1 / min(s_min, 1.0)`` and ill-condition the
# proximal HVP. Only fire when the user did not specify
# ``mu_min`` explicitly (``None`` ≡ "resolve against atol").
if proximal.mu_min is None:
proximal_for_scaled = ProximalConfig(
tau=proximal.tau,
mu_min=prev_atol,
mu_max=proximal.mu_max,
)
else:
proximal_for_scaled = proximal
config = SLSQPConfig(
tolerance=tolerance,
lbfgs=lbfgs,
line_search=line_search,
qp=qp,
proximal=proximal_for_scaled,
preconditioner=preconditioner,
lpeca=lpeca,
adaptive_cg=adaptive_cg,
)
effective_fn = scaled.fn
effective_eq_fn = scaled.eq_constraint_fn
effective_ineq_fn = scaled.ineq_constraint_fn
effective_obj_grad = scaled.obj_grad_fn
effective_eq_jac = scaled.eq_jac_fn
effective_ineq_jac = scaled.ineq_jac_fn
effective_obj_hvp = scaled.obj_hvp_fn
effective_eq_hvp = scaled.eq_hvp_fn
effective_ineq_hvp = scaled.ineq_hvp_fn
# We pass verbose=False to SLSQP and patch in the scale-aware
# wrapper post-construction via ``object.__setattr__``. This bypasses
# SLSQP's ``__check_init__`` ``_strip_fmt`` adapter which would
# otherwise truncate the (label, value, fmt) 3-tuples to
# 2-tuples and degrade the printed formatting.
effective_verbose = False
else:
scaled = None
effective_fn = wrapped_fn
effective_eq_fn = parsed.eq_constraint_fn
effective_ineq_fn = parsed.ineq_constraint_fn
effective_obj_grad = obj_grad_fn
effective_eq_jac = parsed.eq_jac_fn
effective_ineq_jac = parsed.ineq_jac_fn
effective_obj_hvp = obj_hvp_fn
effective_eq_hvp = parsed.eq_hvp_fn
effective_ineq_hvp = parsed.ineq_hvp_fn
effective_verbose = verbose
solver = SLSQP(
config=config,
eq_constraint_fn=effective_eq_fn,
ineq_constraint_fn=effective_ineq_fn,
n_eq_constraints=parsed.n_eq_constraints,
n_ineq_constraints=parsed.n_ineq_constraints,
bounds=jax_bounds,
obj_grad_fn=effective_obj_grad,
eq_jac_fn=effective_eq_jac,
ineq_jac_fn=effective_ineq_jac,
obj_hvp_fn=effective_obj_hvp,
eq_hvp_fn=effective_eq_hvp,
ineq_hvp_fn=effective_ineq_hvp,
inner_solver=inner_solver,
verbose=effective_verbose, # type: ignore[arg-type]
)
if scaled is not None:
# Post-construction patch: install the scale-aware verbose
# wrapper directly, bypassing ``__check_init__``'s
# ``_strip_fmt`` adapter (which would otherwise drop the
# ``.6e`` / ``.3e`` format specifiers). We do this even
# when ``verbose=False`` so the diagnostics layer can pick
# the factors off ``solver.verbose._slsqp_scale_factors``
# for its Auto-scaling section. ``verbose`` is a static
# ``eqx.field`` so neither ``eqx.tree_at`` nor
# ``dataclasses.replace`` can swap it without retriggering
# ``__check_init__``; ``object.__setattr__`` is the
# idiomatic escape hatch for static-field overrides.
wrapped_verbose = wrap_verbose_for_scaling(verbose, scaled.factors)
object.__setattr__(solver, "verbose", wrapped_verbose)
sol = optx.minimise(
effective_fn,
solver,
x0,
args=args,
has_aux=True,
max_steps=max_steps,
throw=throw,
)
if scaled is not None:
sol = unscale_solution(sol, scaled.factors)
return sol