Source code for slsqp_jax.inner.masking
"""Shared active-row masking + bound-fix helper for inner solvers.
The three projector-based inner solvers (``ProjectedCGCholesky``,
``ProjectedCGCraig``, ``MinresQLPSolver``) all start from the same
five-line preamble: mask ``A`` and ``b`` to the active rows, optionally
project away the bound-fixed columns, and build a working HVP and
effective gradient that hide the fixed coordinates from the iteration.
Before this module that preamble was copy-pasted in three places. This
module hosts the single shared implementation.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import NamedTuple
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float
from slsqp_jax.types import Vector
[docs]
class ActiveSubproblem(NamedTuple):
    """Masked problem data shared by the projector-based inner solvers.

    An instance bundles the constraint data restricted to the active rows
    together with the Hessian-vector product and gradient restricted to the
    free (non-bound-fixed) coordinates.

    Attributes:
        A_work: Constraint matrix ``A`` whose inactive rows are zero; when
            bound fixing applies, the fixed columns are zero as well.
        b_work: Right-hand side ``b`` whose inactive entries are zero, minus
            the fixed-column contribution ``A_masked @ d_fixed``.
        free_mask: Boolean mask selecting the free variables; ``ones(n)``
            when there is no bound fixing.
        d_fixed: Values of the fixed variables on the bound-active
            coordinates and zero on all others; all zeros when there is no
            bound fixing.
        has_fixed: ``True`` iff any coordinate is bound-fixed.
        hvp_work: HVP on the working subspace. It is ``hvp_fn`` itself when
            there is no bound fixing, and
            ``v -> free_mask * hvp_fn(free_mask * v)`` otherwise, so that
            the iteration only ever sees the free coordinates.
        g_eff: Effective gradient. It is ``g`` when there is no bound
            fixing, and ``free_mask * (g + hvp_fn(d_fixed))`` otherwise,
            which folds the fixed-column cross-coupling into the linear
            term.
    """

    # Constraint data restricted to the active rows / free columns.
    A_work: Float[Array, "m n"]
    b_work: Float[Array, " m"]

    # Bound-fixing bookkeeping.
    free_mask: Bool[Array, " n"]
    d_fixed: Vector
    has_fixed: bool

    # Objective data as seen by the inner iteration.
    hvp_work: Callable[[Vector], Vector]
    g_eff: Vector
[docs]
def make_active_subproblem(
    hvp_fn: Callable[[Vector], Vector],
    g: Vector,
    A: Float[Array, "m n"],
    b: Float[Array, " m"],
    active_mask: Bool[Array, " m"],
    free_mask: Bool[Array, " n"] | None = None,
    d_fixed: Vector | None = None,
) -> ActiveSubproblem:
    """Build the masked subproblem consumed by every projector-based solver.

    Implements the shared preamble:

    1. ``A_masked = A`` with inactive rows zeroed.
    2. ``b_masked = b`` with inactive entries zeroed.
    3. If bound-fixing is in effect (``free_mask`` and ``d_fixed`` both
       provided), zero the fixed columns of ``A_masked`` and absorb the
       fixed-column contribution into ``b_work``.
    4. Build ``hvp_work`` and ``g_eff`` so the iteration only sees the
       free coordinates.

    Args:
        hvp_fn: Hessian-vector product ``v -> H @ v`` on the full space.
        g: Gradient (linear term) on the full space.
        A: Constraint matrix, one row per constraint.
        b: Constraint right-hand side, one entry per row of ``A``.
        active_mask: Boolean mask over the rows of ``A``; rows where it is
            ``False`` are zeroed in ``A`` and ``b``.
        free_mask: Optional boolean mask over the columns of ``A`` marking
            the free (non-bound-fixed) variables.
        d_fixed: Optional vector holding the fixed-variable values.

    Returns:
        The :class:`ActiveSubproblem` describing the masked problem.

    Note:
        Bound fixing is in effect only when *both* ``free_mask`` and
        ``d_fixed`` are given. If either is missing, the other is ignored
        and the result is the unrestricted subproblem (``free_mask`` all
        ``True``, ``d_fixed`` all zero, ``has_fixed=False``), so that the
        returned fields never describe a restriction that was not applied
        to ``A_work``, ``b_work``, ``hvp_work`` and ``g_eff``.
    """
    n = A.shape[1]

    # Steps 1-2: hide the inactive constraints.
    A_masked = jnp.where(active_mask[:, None], A, 0.0)
    b_masked = jnp.where(active_mask, b, 0.0)

    # No (complete) bound-fixing information: return the unrestricted
    # problem with the canonical "nothing fixed" mask and values.
    if free_mask is None or d_fixed is None:
        return ActiveSubproblem(
            A_work=A_masked,
            b_work=b_masked,
            free_mask=jnp.ones(n, dtype=bool),
            d_fixed=jnp.zeros(n),
            has_fixed=False,
            hvp_work=hvp_fn,
            g_eff=g,
        )

    # Step 3: drop the fixed columns and move their contribution to the
    # right-hand side.
    A_work = A_masked * free_mask[None, :]
    b_work = b_masked - A_masked @ d_fixed

    # Step 4: restrict the quadratic model to the free coordinates.
    def hvp_work(v: Vector) -> Vector:
        # Mask on the way in and on the way out so fixed coordinates
        # neither feed into nor receive anything from the HVP.
        return free_mask * hvp_fn(free_mask * v)

    # The fixed values couple into the free block through H; fold that
    # cross term into the linear part.
    g_eff = free_mask * (g + hvp_fn(d_fixed))

    return ActiveSubproblem(
        A_work=A_work,
        b_work=b_work,
        free_mask=free_mask,
        d_fixed=d_fixed,
        has_fixed=True,
        hvp_work=hvp_work,
        g_eff=g_eff,
    )
# Names exported by ``from slsqp_jax.inner.masking import *``.
__all__ = ["ActiveSubproblem", "make_active_subproblem"]