slsqp_jax.utils

Small helpers shared across the package.

slsqp_jax.utils.args_closure(fn, args)
Return type:

Callable[[Array], Array]

Parameters:
slsqp_jax.utils.to_scalar(x)

Coerce a (possibly non-0-d, size-1) array to a true 0-d scalar.

This guards the SLSQP internals against user objective functions that return, for example, shape (1,) instead of a true 0-d scalar. When the input has size 1 (in any shape), it is reshaped to (). An input of any other size fails at trace time with a clear shape error from JAX, which is the desired behaviour: the objective is contractually scalar-valued.
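The coercion described above can be sketched as follows. This is a minimal illustration rather than the package's actual source: the name to_scalar_sketch and the example objective are hypothetical, and the sketch assumes the reshape is done with jnp.reshape.

```python
import jax.numpy as jnp


def to_scalar_sketch(x):
    """Hypothetical stand-in for to_scalar: reshape a size-1 input to shape ()."""
    # jnp.reshape raises a shape error if x does not hold exactly one element.
    return jnp.reshape(jnp.asarray(x), ())


def objective(x):
    # A user objective that returns shape (1,) instead of a 0-d scalar.
    return jnp.sum(x ** 2, keepdims=True)


x0 = jnp.array([1.0, 2.0])
raw = objective(x0)          # shape (1,)
val = to_scalar_sketch(raw)  # shape ()
```

Python scalars pass through the same path, since jnp.asarray turns them into 0-d arrays before the reshape.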

Applying this at the boundaries (init, line search) keeps a (1,) shape from propagating into f_val / lagrangian_val. Without it, the boolean done predicate fed to jax.lax.cond in terminate becomes non-scalar, which raises TypeError: Pred must be a scalar deep inside JAX rather than at the call site.
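The failure mode can be reproduced in isolation. The sketch below is illustrative only: step_sketch and its tolerance are hypothetical and not part of slsqp_jax. It assumes the predicate is derived from a comparison on the objective value, so it inherits that value's shape.

```python
import jax
import jax.numpy as jnp


def step_sketch(f_val, tol=1e-6):
    # Hypothetical termination check: `done` has the same shape as f_val.
    done = jnp.abs(f_val) < tol
    return jax.lax.cond(done, lambda: jnp.int32(1), lambda: jnp.int32(0))


bad = jnp.array([0.0])  # shape (1,), as a misbehaving objective might return
try:
    step_sketch(bad)
    error = None
except TypeError as exc:
    # jax.lax.cond rejects the shape-(1,) predicate.
    error = str(exc)

good = jnp.reshape(bad, ())  # coerce at the boundary, as to_scalar does
result = step_sketch(good)   # the predicate is now 0-d, so cond accepts it
```

Coercing once at the boundary means every downstream comparison yields a 0-d predicate, so the error never appears far from its cause.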

Return type:

Array

Parameters:

x (Array | int | float | bool)