Source code for slsqp_jax.slsqp.derivatives

"""Gradient / Jacobian / HVP closure factories.

Centralises the dispatch between user-supplied derivative callables
and the AD fallbacks so the SLSQP class can remain agnostic to which
side provided each derivative.

Four helpers are public; the two constraint-side ones are:

* :func:`build_jacobian_impl` — returns a closure ``(y, args) -> J``
  for a single constraint slot, parameterised by the user-supplied
  ``user_jac`` (or ``None``) and the user-supplied constraint function.
* :func:`build_hvp_contrib_impl` — returns a closure
  ``(y, v, args, multipliers) -> Σ μ_i H_{c_i} v`` for a single
  constraint slot, parameterised analogously.

The single-constraint-slot helpers eliminate the EQ/INEQ duplication
in the legacy ``__check_init__``: each side calls the same factory
with its own slot's user-supplied callables and ``n_constraints``.
"""

from __future__ import annotations

from collections.abc import Callable

import jax
import jax.numpy as jnp

from slsqp_jax.utils import args_closure


def build_grad_impl(user_grad_fn: Callable | None) -> Callable:
    """Build the objective-gradient closure ``(fn, y, args) -> ∇f(y)``.

    Args:
        user_grad_fn: Optional callable invoked as ``user_grad_fn(y, args)``.
            When supplied, the returned closure forwards to it and never
            calls ``fn``.

    Returns:
        A closure with signature ``(fn, y, args)``. Without a user gradient
        it applies ``jax.grad`` to element ``0`` of ``fn(x, args)``,
        differentiating with respect to ``x`` at ``y``.
    """
    if user_grad_fn is None:

        def grad_impl(fn, y, args):
            def objective_value(x):
                # Only the first element of ``fn``'s result is differentiated.
                return fn(x, args)[0]

            return jax.grad(objective_value)(y)

        return grad_impl

    def grad_impl(fn, y, args):
        # ``fn`` is accepted only to keep both closures call-compatible.
        return user_grad_fn(y, args)

    return grad_impl
def build_jacobian_impl(
    *,
    user_jac: Callable | None,
    constraint_fn: Callable | None,
    n_constraints: int,
) -> Callable:
    """Build the closure ``(y, args) -> J(y)`` for one constraint slot.

    Args:
        user_jac: Optional callable invoked as ``user_jac(y, args)``. Used
            in preference to AD when the slot is active.
        constraint_fn: Constraint function for this slot, or ``None`` when
            the slot has no constraints.
        n_constraints: Number of constraints in the slot; it is the row
            count of the zero-Jacobian fallback.

    Returns:
        A closure ``(y, args) -> J``. If ``constraint_fn`` is ``None`` or
        ``n_constraints`` is not positive, the closure returns zeros of
        shape ``(n_constraints, y.shape[0])``. Otherwise it returns the
        user-supplied Jacobian, or ``jax.jacrev`` applied to
        ``args_closure(constraint_fn, args)`` evaluated at ``y``.
    """
    if constraint_fn is not None and n_constraints > 0:
        if user_jac is not None:

            def jac_impl(y, args):
                return user_jac(y, args)

            return jac_impl

        def jac_impl(y, args):
            return jax.jacrev(args_closure(constraint_fn, args))(y)

        return jac_impl

    def jac_impl(y, _args):
        # Follow ``y``'s floating/complex dtype so the placeholder does not
        # force a dtype promotion downstream (the HVP fallback likewise
        # preserves dtype through ``zeros_like``). Integer or boolean ``y``
        # keeps the default float dtype, as before.
        dtype = y.dtype if jnp.issubdtype(y.dtype, jnp.inexact) else None
        return jnp.zeros((n_constraints, y.shape[0]), dtype=dtype)

    return jac_impl
def build_hvp_contrib_impl(
    *,
    user_hvp: Callable | None,
    constraint_fn: Callable | None,
    n_constraints: int,
) -> Callable:
    """Build the closure ``(y, v, args, μ) -> Σ μ_i H_{c_i}(y) v`` for one slot.

    Args:
        user_hvp: Optional callable invoked as ``user_hvp(y, v, args)``; its
            result is left-multiplied by the multipliers (presumably one row
            per constraint — confirm against callers).
        constraint_fn: Constraint function for this slot, or ``None``.
        n_constraints: Number of constraints in the slot.

    Returns:
        A closure ``(y, v, args, multipliers)``. If ``constraint_fn`` is
        ``None`` or ``n_constraints`` is not positive, it returns
        ``jnp.zeros_like(v)``. Otherwise it uses ``user_hvp`` when given, or
        forward-over-reverse AD (``jax.jvp`` of ``jax.grad``) on the
        multiplier-weighted constraint sum.
    """
    slot_is_empty = constraint_fn is None or not n_constraints > 0
    if slot_is_empty:

        def hvp_contrib(_y, v, _args, _multipliers):
            return jnp.zeros_like(v)

        return hvp_contrib

    if user_hvp is None:

        def hvp_contrib(y, v, args, multipliers):
            def weighted(x):
                # Scalar μ · c(x); its Hessian is Σ μ_i H_{c_i}.
                return jnp.dot(multipliers, constraint_fn(x, args))

            tangent_out = jax.jvp(jax.grad(weighted), (y,), (v,))[1]
            return tangent_out

        return hvp_contrib

    def hvp_contrib(y, v, args, multipliers):
        per_constraint = user_hvp(y, v, args)
        return multipliers @ per_constraint

    return hvp_contrib
def build_obj_hvp_impl(
    *,
    user_obj_hvp: Callable | None,
    use_exact_hvp_in_qp: bool,
) -> Callable | None:
    """Build the optional objective HVP closure ``(fn, y, v, args) -> H_f v``.

    Args:
        user_obj_hvp: Optional callable invoked as ``user_obj_hvp(y, v, args)``.
            Takes precedence over the AD path whenever it is supplied.
        use_exact_hvp_in_qp: When truthy and no user HVP is given, an AD
            closure (``jax.jvp`` of ``jax.grad`` on element ``0`` of
            ``fn(x, args)``) is returned.

    Returns:
        The HVP closure, or ``None`` when ``user_obj_hvp`` is ``None`` and
        ``use_exact_hvp_in_qp`` is falsy. The ``None`` result mirrors the
        legacy ``self._obj_hvp_impl is None`` sentinel that ``step()`` uses
        to decide whether to probe the Hessian for the L-BFGS secant pair.
    """
    if user_obj_hvp is None and not use_exact_hvp_in_qp:
        # No exact HVP source: signal it with the sentinel.
        return None

    if user_obj_hvp is None:

        def obj_hvp_impl(fn, y, v, args):
            def objective_value(x):
                return fn(x, args)[0]

            # Forward-over-reverse: the tangent of the gradient is H_f v.
            return jax.jvp(jax.grad(objective_value), (y,), (v,))[1]

        return obj_hvp_impl

    def obj_hvp_impl(_fn, y, v, args):
        return user_obj_hvp(y, v, args)

    return obj_hvp_impl
# Public API of this module: the four closure factories defined above.
__all__ = [
    "build_grad_impl",
    "build_hvp_contrib_impl",
    "build_jacobian_impl",
    "build_obj_hvp_impl",
]