Source code for slsqp_jax.utils
from collections.abc import Callable
from typing import TypeVar, Union
import jax
import jax.numpy as jnp
T = TypeVar("T")
[docs]
def args_closure(
fn: Callable[[jax.Array, T], jax.Array], args: T
) -> Callable[[jax.Array], jax.Array]:
def wrapped(x: jax.Array) -> jax.Array:
return fn(x, args)
return wrapped
[docs]
def to_scalar(x: Union[jax.Array, int, float, bool]) -> jax.Array:
"""Coerce a (possibly non-0-d, size-1) array to a true 0-d scalar.
This guards the SLSQP internals against user objective functions that
return e.g. shape ``(1,)`` instead of a true 0-d scalar. When the input
already has size 1 (any shape), it is reshaped to ``()``. Any other
shape will fail at trace time with a clear shape error from JAX, which
is the desired behaviour: the objective is contractually scalar-valued.
Using this at the boundaries (init, line search) prevents a ``(1,)``
shape from propagating into ``f_val`` / ``lagrangian_val`` and turning
the boolean ``done`` predicate fed to ``jax.lax.cond`` in
``terminate`` into a non-scalar (which raises ``TypeError: Pred must
be a scalar`` deep inside JAX rather than at the call site).
"""
return jnp.asarray(x).reshape(())