slsqp_jax.compat
SciPy-style entry point and constraint parser. See minimize_like_scipy for the recommended migration path from scipy.optimize.minimize(method='SLSQP').
SciPy compatibility layer for SLSQP-JAX.
Provides utilities to convert SciPy-style constraint specifications
(dicts, LinearConstraint, NonlinearConstraint) into the function/Jacobian/HVP
signatures expected by the SLSQP solver, and a convenience
minimize_like_scipy entry point.
Non-standard NonlinearConstraint.hessp extension
scipy.optimize.NonlinearConstraint does not ship a hessp
attribute, does not accept one in __init__, and SciPy’s own
solvers never read one. SLSQP-JAX’s compat layer nevertheless honours
a user-attached hessp attribute on a NonlinearConstraint: if
the attribute is present and callable, it is used as the per-component
constraint Hessian-vector product with precedence over hess.
This is a deliberate, unorthodox extension. It exists so that users
can avoid forming a dense (n, n) constraint Hessian (which SciPy’s
hess(x, v) convention forces) when all SLSQP-JAX actually needs is
the HVP stack.
Expected signature:
hessp(x, p) -> Array of shape (m, n)
where x is the current iterate (shape (n,)), p is the
direction vector (shape (n,)), m is the number of components
of the constraint, and row i of the returned array equals
(d^2 c_i / dx^2)(x) @ p.
Usage pattern:
nlc = NonlinearConstraint(fun, lb, ub, jac=jac_fn)
nlc.hessp = my_hessp # non-standard; ignored by SciPy, consumed here
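A minimal sketch of a hessp with the expected signature, assuming the constraint function is written in JAX. The constraint fun and the test vectors are hypothetical; the HVP stack is obtained as the directional derivative of the (m, n) Jacobian, which is one possible construction and not necessarily what SLSQP-JAX does internally.

```python
import jax
import jax.numpy as jnp


def fun(x):
    # Hypothetical constraint with m = 2 components and n = 3 variables.
    return jnp.array([x[0] ** 2 + x[1] * x[2], jnp.sin(x[0]) * x[2]])


def my_hessp(x, p):
    # Directional derivative of the (m, n) Jacobian along p.
    # Row i equals (d^2 c_i / dx^2)(x) @ p; no dense (n, n) Hessian
    # is materialised per component.
    _, hvp_stack = jax.jvp(jax.jacrev(fun), (x,), (p,))
    return hvp_stack


x = jnp.array([0.5, -1.0, 2.0])
p = jnp.array([1.0, 0.0, -1.0])
stack = my_hessp(x, p)

# Cross-check against the dense per-component Hessians, shape (m, n, n).
dense = jax.hessian(fun)(x)
reference = jnp.einsum("ijk,k->ij", dense, p)
```

A function built this way can be attached as nlc.hessp exactly as in the usage pattern above.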
Precedence rules:
- If hessp is present and callable, it wins over hess.
- If hessp is present but not callable (e.g. a sentinel string like "2-point"), it is ignored and hess is used if callable – identical to the existing behaviour for hess.
- Validation is limited to positional-parameter arity via inspect.signature; shape/dtype mismatches surface as JAX errors on first use. Callables whose signature cannot be introspected (e.g. some C-level builtins) are accepted silently.
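The precedence rules can be sketched in plain Python as follows. This is an illustration, not the library's implementation: the helper names count_positional and select_constraint_hvp are hypothetical, and the assumption that exactly two positional parameters are required is ours, since the text only says that arity is checked.

```python
import inspect
from types import SimpleNamespace


def count_positional(fn):
    """Number of positional parameters, or None if the signature is opaque."""
    try:
        params = inspect.signature(fn).parameters.values()
    except (TypeError, ValueError):
        return None  # not introspectable: accepted silently
    kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return sum(1 for p in params if p.kind in kinds)


def select_constraint_hvp(constraint):
    """Return (name, callable) following hessp-before-hess precedence."""
    for name in ("hessp", "hess"):
        candidate = getattr(constraint, name, None)
        if not callable(candidate):
            continue  # absent, or a sentinel such as "2-point"
        arity = count_positional(candidate)
        # Assumption: both hessp(x, p) and hess(x, v) take two positionals.
        if arity is not None and arity != 2:
            raise TypeError(f"{name} must take 2 positional parameters, got {arity}")
        return name, candidate
    return None, None


# Stand-ins for NonlinearConstraint objects (hypothetical test fixtures).
both = SimpleNamespace(hess=lambda x, v: 0, hessp=lambda x, p: 1)
sentinel = SimpleNamespace(hess=lambda x, v: 0, hessp="2-point")
neither = SimpleNamespace(hess="2-point")

chosen_both = select_constraint_hvp(both)[0]          # "hessp"
chosen_sentinel = select_constraint_hvp(sentinel)[0]  # "hess"
chosen_neither = select_constraint_hvp(neither)       # (None, None)
```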
- class slsqp_jax.compat.ParsedConstraints
Bases: object
Result of converting SciPy-style constraints for use with SLSQP.
All fields map directly to the corresponding SLSQP constructor arguments.
- eq_hvp_fn: Callable[[Float[Array, 'n'], Float[Array, 'n'], Any], Float[Array, 'm n']] | None = None
- ineq_hvp_fn: Callable[[Float[Array, 'n'], Float[Array, 'n'], Any], Float[Array, 'm n']] | None = None
- __init__(eq_constraint_fn=None, ineq_constraint_fn=None, n_eq_constraints=0, n_ineq_constraints=0, eq_jac_fn=None, ineq_jac_fn=None, eq_hvp_fn=None, ineq_hvp_fn=None)
- Parameters:
eq_constraint_fn (Callable[[Float[Array, 'n'], Any], Float[Array, 'm']] | None)
ineq_constraint_fn (Callable[[Float[Array, 'n'], Any], Float[Array, 'm']] | None)
n_eq_constraints (int)
n_ineq_constraints (int)
eq_jac_fn (Callable[[Float[Array, 'n'], Any], Float[Array, 'm n']] | None)
ineq_jac_fn (Callable[[Float[Array, 'n'], Any], Float[Array, 'm n']] | None)
eq_hvp_fn (Callable[[Float[Array, 'n'], Float[Array, 'n'], Any], Float[Array, 'm n']] | None)
ineq_hvp_fn (Callable[[Float[Array, 'n'], Float[Array, 'n'], Any], Float[Array, 'm n']] | None)
- Return type:
None
- slsqp_jax.compat.parse_constraints(constraints, x0)
Convert SciPy-style constraints into SLSQP-JAX constraint functions.
- Parameters:
constraints (dict | list | LinearConstraint | NonlinearConstraint | tuple) – Any form accepted by scipy.optimize.minimize: a dict, list of dicts, LinearConstraint, NonlinearConstraint, or a list/tuple mixing those types. An empty tuple/list means “no constraints”.
x0 (Array) – Initial guess – used to evaluate dict constraint functions once to determine their output size.
- Returns:
Dataclass whose fields map to SLSQP constructor arguments.
- Return type:
ParsedConstraints
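To illustrate why x0 is needed, here is a minimal NumPy sketch of how the number of equality and inequality rows could be determined from SciPy-style dict constraints by evaluating each fun once at the initial guess. The helper count_dict_constraints is hypothetical and covers only the dict form; it is not the code behind parse_constraints.

```python
import numpy as np


def count_dict_constraints(constraints, x0):
    """Count equality / inequality rows by calling each dict's fun once at x0."""
    if isinstance(constraints, dict):
        constraints = [constraints]
    n_eq = n_ineq = 0
    for con in constraints:
        args = con.get("args", ())
        # A scalar result counts as one row; a vector counts one row per entry.
        size = np.atleast_1d(con["fun"](x0, *args)).size
        if con["type"] == "eq":
            n_eq += size
        elif con["type"] == "ineq":
            n_ineq += size
        else:
            raise ValueError(f"unknown constraint type: {con['type']!r}")
    return n_eq, n_ineq


x0 = np.array([1.0, 2.0, 3.0])
cons = [
    {"type": "eq", "fun": lambda x: x.sum() - 6.0},                  # 1 row
    {"type": "ineq", "fun": lambda x: x[:2]},                        # 2 rows
    {"type": "ineq", "fun": lambda x, k: x[2] - k, "args": (1.0,)},  # 1 row
]
n_eq, n_ineq = count_dict_constraints(cons, x0)
```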
- slsqp_jax.compat.minimize_like_scipy(fun, x0, args=(), *, jac=None, hessp=None, bounds=None, constraints=(), options=None, has_aux=False, throw=False, verbose=False, auto_scale=True, auto_scale_target_gradient=None, auto_scale_max_factor=None)
Minimise a function using SLSQP with a SciPy-like interface.
This is a convenience wrapper that accepts SciPy-style arguments, converts them for the SLSQP solver, and delegates to optimistix.minimise.
- Parameters:
fun (Callable) – Objective function. Signature (x, *args) -> scalar or, when has_aux is True, (x, *args) -> (scalar, aux).
x0 (Any) – Initial guess (array-like).
args (tuple, default: ()) – Extra positional arguments forwarded to fun (unpacked).
jac (Callable | bool | None, default: None) – Gradient of fun. A callable (x, *args) -> array or True to indicate that fun returns (f, g) (or ((f, g), aux) when has_aux is set).
hessp (Callable | None, default: None) – Hessian-vector product (x, p, *args) -> array.
bounds (Bounds | list | tuple | None, default: None) – Variable bounds – None, Bounds, or sequence of (min, max) pairs.
constraints (dict | list | LinearConstraint | NonlinearConstraint | tuple, default: ()) – SciPy-style constraints (dict / list-of-dicts / LinearConstraint / NonlinearConstraint). A NonlinearConstraint may carry a user-attached hessp attribute (non-standard; not part of SciPy’s API) that, if callable, is used as the per-component constraint HVP with precedence over hess. See the module-level docstring for the full contract.
options (dict[str, Any] | None, default: None) – Solver options dict. The following keys are popped with the listed defaults (which match the SLSQP constructor defaults):
- rtol (1e-6) – relative tolerance for stationarity.
- atol (1e-6) – absolute tolerance for feasibility.
- max_steps or maxiter (100) – maximum outer iterations.
- min_steps (1) – minimum iterations before convergence is allowed.
- lbfgs_memory (10) – number of L-BFGS pairs.
- line_search_max_steps (20) – backtracking steps.
- armijo_c1 (1e-4) – Armijo sufficient decrease.
- qp_max_iter (100) – active-set iteration budget.
- qp_max_cg_iter (50) – CG iterations per QP step.
Any remaining keys are forwarded as **kwargs to the SLSQP constructor, so any SLSQP attribute can be set here (e.g. proximal_tau, proximal_mu_min, proximal_mu_max, use_preconditioner, adaptive_cg_tol, cg_regularization, stagnation_tol).
has_aux (bool, default: False) – If True, fun returns (value, aux).
throw (bool, default: False) – Whether to raise on solver failure.
verbose (bool | Callable[..., None], default: False) – Passed to the SLSQP constructor. False (default) for silent, True to print all diagnostics, or a custom callable. When auto_scale is on, the built-in printer is wrapped to show user-unit values for f/|c|/|grad_f|/|grad_L|/|d|; merit / rho / gamma / L-BFGS internals keep a (s) suffix on their label to flag scaled units. See slsqp_jax.wrap_verbose_for_scaling().
auto_scale (bool | str, default: True) – Automatic problem scaling at the initial point (gradient-based, IPOPT/KNITRO-style). On by default as of this release.
- True (default) -> "uniform" (target_gradient=1.0, max_factor=1e3, uniform=True). A single shared scalar s_c is applied to every constraint row (equality + general inequality) and a separate s_f to the objective, both symmetrically clipped to [1/max_factor, max_factor]. Preserves inter-row magnitude ratios (the right default for budget-style problems where one constraint is intentionally orders of magnitude larger than the others); fully fixes the documented ||J_eq|| >> ||grad_f|| divergence cascade. Note that atol_internal = s_c * atol_user (no min(., 1.0) cap, so the feasibility tolerance handed to the inner solver can exceed atol_user when s_c > 1).
- "balanced" -> target_gradient=1.0, max_factor=1e3, uniform=False. The legacy per-row default. Each constraint row gets its own factor driving ||grad c_i||_inf -> 1. Flattens inter-row magnitudes; opt-in when one row’s gradient is vastly out of band and that’s not a meaningful spread.
- False -> no wrapping (pre-feature behaviour).
- "knitro" -> target=1.0, max_factor=1.0 (strict shrink-only per-row; opt-in for users who want zero amplification).
- "ipopt" -> target=100.0, max_factor=1.0 (very conservative per-row; may not fix all cascades).
- "aggressive" -> target=1.0, max_factor=1e6 (per-row, pushes amplification to the noise-floor limit).
When scaling is applied, sol.stats carries a scale_factors entry plus _user-suffixed copies of the multiplier vectors and the Lagrangian gradient norm. atol is auto-compensated so the user-perceived feasibility tolerance is preserved (uniform mode does this via atol_internal = s_c * atol_user; per-row modes via atol_internal = atol_user * min(min(s_eq), min(s_ineq), 1.0)).
auto_scale_target_gradient (float | None, default: None) – Optional explicit override of the mode’s target_gradient. Under uniform mode this value is consumed by both the s_f derivation (against ||grad_f||_inf) and the s_c derivation (against the cross-row max max_i ||grad c_i||_inf); under per-row modes it drives every row’s individual factor.
auto_scale_max_factor (float | None, default: None) – Optional explicit override of the mode’s max_factor. Under uniform mode the bound is symmetric so the scale factor lives in [1/max_factor, max_factor] and the value must satisfy max_factor >= 1.0 (smaller raises ValueError; max_factor == 1.0 is legal but emits a UserWarning because it disables scaling). Under per-row modes the bound is one-sided (s in [eps, max_factor]); max_factor == 1.0 means shrink-only.
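The uniform scaling mode described for auto_scale can be sketched numerically as below. This is a hedged illustration under a stated assumption: the text says s_f is derived against ||grad_f||_inf and s_c against max_i ||grad c_i||_inf, but does not spell out the formula, so the ratio target_gradient / norm used here is our guess. The function name uniform_scale_factors and the sample gradients are hypothetical, and nonzero gradients are assumed.

```python
import numpy as np


def uniform_scale_factors(grad_f, jac_c, target_gradient=1.0, max_factor=1e3):
    """Assumed rule: s = target / inf-norm, clipped to [1/max_factor, max_factor]."""
    if max_factor < 1.0:
        raise ValueError("max_factor must be >= 1.0 in uniform mode")
    lo, hi = 1.0 / max_factor, max_factor
    g_norm = np.max(np.abs(grad_f))  # ||grad_f||_inf
    c_norm = np.max(np.abs(jac_c))   # max_i ||grad c_i||_inf (cross-row max)
    s_f = float(np.clip(target_gradient / g_norm, lo, hi))
    s_c = float(np.clip(target_gradient / c_norm, lo, hi))
    return s_f, s_c


grad_f = np.array([0.5, -2.0])
jac_c = np.array([[1e4, 0.0],   # one intentionally large constraint row
                  [3.0, 4.0]])
s_f, s_c = uniform_scale_factors(grad_f, jac_c)

# Feasibility tolerance handed to the inner solver in uniform mode.
atol_user = 1e-6
atol_internal = s_c * atol_user
```

With these inputs the unclipped constraint factor would be 1e-4, so the symmetric clip at 1/max_factor is what determines s_c; both constraint rows share that single factor, which keeps their relative magnitudes intact.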
- Return type: