slsqp_jax.inner.projected_cg

Shared projected-CG driver run_projected_pcg consumed by ProjectedCGCholesky and ProjectedCGCraig.


Shared projected-PCG driver for Cholesky and CRAIG inner solvers.

ProjectedCGCholesky and ProjectedCGCraig historically each carried their own copy of the same projected-PCG outer loop: build r0 = project(-(g_eff + B d_p)), prime z0 = project(precond(r0)), run build_cg_step() inside jax.lax.fori_loop(), then recover the multipliers from the full HVP via the projection context. The two copies differed only in how the constraint preconditioner was constructed, and each solver now supplies that preconditioner externally.
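The quantities in this loop can be read against the standard null-space formulation of an equality-constrained QP. The summary below is a sketch under stated assumptions: it ignores the masking that separates ctx.hvp_work and g_eff from the full HVP and gradient (so g_eff = g), writes A for A_work and b for the constraint right-hand side, treats d_p as a particular solution of A d = b, writes M^{-1} for precond, and takes project to be the orthogonal projector P onto the null space of A.

```latex
\begin{aligned}
&\min_{d}\; \tfrac{1}{2} d^\top B d + g^\top d \quad \text{s.t.} \quad A d = b, \\
&P = I - A^\top (A A^\top)^{-1} A, \qquad d = d_p + x, \quad A d_p = b, \quad A x = 0, \\
&x \text{ minimises } \tfrac{1}{2} x^\top B x + (g + B d_p)^\top x \text{ over } \operatorname{null}(A), \\
&r_0 = P\bigl(-(g + B d_p)\bigr), \qquad z_0 = P\, M^{-1} r_0, \\
&\lambda = (A A^\top)^{-1} A\,(B d + g), \quad \text{so that } A^\top \lambda = B d + g \text{ at a KKT point.}
\end{aligned}
```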

This module hosts the shared driver run_projected_pcg(). Each concrete solver builds its own ProjectionContext and (optionally) its own constraint-preconditioner factory and delegates the rest here.
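To make this division of labour concrete, here is a minimal, hypothetical context builder. ToyProjectionContext and make_toy_cholesky_context are illustrative names, not the slsqp_jax API; only the field names A_work, project, hvp_work, g_eff and d_p come from this page, and recover_multipliers mirrors the ctx.recover_multipliers call the driver makes. The sketch assumes a dense, full-row-rank A, factors A Aᵀ once with a Cholesky decomposition, and sets g_eff = g and hvp_work = hvp with no masking.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


class ToyProjectionContext(NamedTuple):
    """Hypothetical stand-in for ProjectionContext (field names from the docs)."""
    A_work: jax.Array
    project: Callable[[jax.Array], jax.Array]
    hvp_work: Callable[[jax.Array], jax.Array]
    g_eff: jax.Array
    d_p: jax.Array
    recover_multipliers: Callable[[jax.Array], jax.Array]


def make_toy_cholesky_context(A, b, hvp, g):
    """Hypothetical builder: factor A A^T once and close over the factor."""
    factor = (jnp.linalg.cholesky(A @ A.T), True)  # lower-triangular factor

    def solve_aat(y):
        # Solve (A A^T) u = y using the cached Cholesky factor.
        return cho_solve(factor, y)

    def project(v):
        # v - A^T (A A^T)^{-1} A v: orthogonal projection onto null(A).
        return v - A.T @ solve_aat(A @ v)

    def recover_multipliers(w):
        # Least-squares multipliers: argmin_lam ||A^T lam - w||.
        return solve_aat(A @ w)

    d_p = A.T @ solve_aat(b)  # minimum-norm solution of A d = b
    # Assumption: no masking, so hvp_work = hvp and g_eff = g.
    return ToyProjectionContext(A, project, hvp, g, d_p, recover_multipliers)
```

A CRAIG-based builder would presumably expose the same fields but compute the projection iteratively rather than by factoring A Aᵀ; the driver only needs the shared interface.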

slsqp_jax.inner.projected_cg.run_projected_pcg(ctx, hvp_fn, g, max_cg_iter, cg_tol, effective_precond=None, cg_regularization=1e-06)

Run projected (preconditioned) CG against an existing context.

Performs the loop body of _solve_projected_cg_* that the Cholesky and CRAIG solvers share:

  1. Initialise r0 = project(-(g_eff + B d_p)).

  2. Optionally precondition: z0 = project(precond(r0)). As a positive-definiteness check, fall back to z0 = r0 if rz0 = r0 · z0 <= 0.

  3. Run build_cg_step() inside jax.lax.fori_loop.

  4. Recover Lagrange multipliers from the full unmasked HVP via ctx.recover_multipliers(B d + g).
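The four steps can be sketched end to end as follows. This is a hypothetical, self-contained illustration, not the library's run_projected_pcg: projected_pcg_sketch inlines a dense Cholesky-based projection instead of taking a ProjectionContext, uses the same unmasked hvp inside the loop and for multiplier recovery, and its loop body is a plain projected-PCG step standing in for build_cg_step(). Two behaviours are assumptions of the sketch: the curvature guard freezes the iterate when pᵀBp <= cg_regularization·‖p‖², and the step-2 fallback disables the preconditioner for the whole solve.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def projected_pcg_sketch(A, hvp, g, b, max_cg_iter, cg_tol,
                         precond=None, cg_regularization=1e-6):
    """Hypothetical projected PCG for min 0.5 d'Bd + g'd s.t. A d = b."""
    factor = (jnp.linalg.cholesky(A @ A.T), True)

    def solve_aat(y):
        # Solve (A A^T) u = y with the cached Cholesky factor.
        return cho_solve(factor, y)

    def project(v):
        # Orthogonal projection onto null(A).
        return v - A.T @ solve_aat(A @ v)

    if precond is None:
        precond = lambda v: v  # identity: plain projected CG

    # Minimum-norm particular solution, so that A d_p = b.
    d_p = A.T @ solve_aat(b)

    # Step 1: projected initial residual.
    r0 = project(-(g + hvp(d_p)))
    # Step 2: projected preconditioned residual; if r0 . z0 <= 0, drop the
    # preconditioner (z0 = r0) for the whole solve (assumption of this sketch).
    z0 = project(precond(r0))
    use_precond = (r0 @ z0) > 0
    z0 = jnp.where(use_precond, z0, r0)
    rz0 = r0 @ z0

    # Step 3: fixed-trip-count loop; the state is frozen once CG stops.
    def body(_, state):
        x, r, p, rz, done = state
        Bp = hvp(p)
        pBp = p @ Bp
        # Curvature guard (assumed meaning of cg_regularization).
        bad_curv = pBp <= cg_regularization * (p @ p)
        alpha = rz / jnp.where(bad_curv, 1.0, pBp)
        x_new = x + alpha * p
        r_new = r - alpha * project(Bp)  # keep the residual in null(A)
        z_new = jnp.where(use_precond, project(precond(r_new)), r_new)
        rz_new = r_new @ z_new
        p_new = z_new + (rz_new / rz) * p
        freeze = done | bad_curv

        def keep(old, new):
            return jnp.where(freeze, old, new)

        now_done = freeze | (jnp.linalg.norm(r_new) < cg_tol)
        return (keep(x, x_new), keep(r, r_new), keep(p, p_new),
                keep(rz, rz_new), now_done)

    init = (jnp.zeros_like(g), r0, z0, rz0, jnp.linalg.norm(r0) < cg_tol)
    x, r, _, _, _ = jax.lax.fori_loop(0, max_cg_iter, body, init)

    # Step 4: multipliers from the full HVP, least squares A^T lam ~ B d + g.
    d = d_p + x
    multipliers = solve_aat(A @ (hvp(d) + g))
    return d, multipliers, jnp.linalg.norm(r) < cg_tol
```

Because jax.lax.fori_loop has a fixed trip count, convergence is handled by freezing the carried state with jnp.where rather than breaking out early. Since the preconditioned residual is projected again, the operator acting on the null space is P M^{-1} P, which stays symmetric positive definite there whenever M^{-1} is (for example a Jacobi preconditioner v / diag(B)).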

Args:
  • ctx: Projection context built by the caller (Cholesky, CRAIG, or a composed strategy). Provides A_work, project, hvp_work, g_eff and d_p.

  • hvp_fn: The full unmasked HVP, used only for multiplier recovery; ctx.hvp_work is used inside the CG loop.

  • g: Full unmasked gradient, used for multiplier recovery.

  • max_cg_iter: CG iteration budget.

  • cg_tol: CG residual-norm tolerance.

  • effective_precond: Already-wrapped preconditioner (e.g. from a constraint-preconditioner factory), or None.

  • cg_regularization: Curvature-guard threshold.

Returns:

(d, multipliers, cg_converged): the QP step d, the recovered Lagrange multipliers, and a flag indicating whether CG converged.

Return type:

tuple[Float[Array, 'n'], Float[Array, 'm'], Bool[Array, '']]

Parameters:
  • ctx (ProjectionContext)

  • hvp_fn (Callable[[Float[Array, 'n']], Float[Array, 'n']])

  • g (Float[Array, 'n'])

  • max_cg_iter (int)

  • cg_tol (Float[Array, ''] | float)

  • effective_precond (Callable[[Float[Array, 'n']], Float[Array, 'n']] | None)

  • cg_regularization (float)