slsqp_jax.utils
Small helpers shared across the package.
- slsqp_jax.utils.to_scalar(x)
Coerce a (possibly non-0-d, size-1) array to a true 0-d scalar.
This guards the SLSQP internals against user objective functions that return e.g. shape (1,) instead of a true 0-d scalar. When the input already has size 1 (any shape), it is reshaped to (). Any other shape will fail at trace time with a clear shape error from JAX, which is the desired behaviour: the objective is contractually scalar-valued.
Using this at the boundaries (init, line search) prevents a (1,) shape from propagating into f_val / lagrangian_val and turning the boolean done predicate fed to jax.lax.cond in terminate into a non-scalar (which raises TypeError: Pred must be a scalar deep inside JAX rather than at the call site).
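The behaviour described above can be illustrated with a minimal sketch. The package's actual implementation is not shown here, so to_scalar_sketch, objective, and step are hypothetical names, and using jnp.reshape is an assumption about how such a helper might be written.

```python
import jax
import jax.numpy as jnp


def to_scalar_sketch(x):
    """Hypothetical stand-in for to_scalar: reshape a size-1 array to shape ()."""
    # jnp.reshape raises a shape error unless x has exactly one element.
    return jnp.reshape(jnp.asarray(x), ())


def objective(x):
    # Illustrative user objective that returns shape (1,) instead of ().
    return jnp.sum(x ** 2, keepdims=True)


@jax.jit
def step(x):
    # Coerce at the boundary so everything downstream sees a 0-d value.
    f_val = to_scalar_sketch(objective(x))
    done = f_val < 1e-6  # 0-d boolean, a valid predicate for jax.lax.cond
    return jax.lax.cond(done, lambda v: v, lambda v: 0.5 * v, x)


x_next = step(jnp.array([1.0, 2.0]))
```

Without the coercion in step, done would have shape (1,) and jax.lax.cond would reject it as a non-scalar predicate. With it, a genuinely non-scalar objective (for example one returning shape (2,)) fails at the reshape, close to the call site.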