Skip to content

Commit

Permalink
sensitivities in solution
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoAiraldi committed Nov 18, 2024
1 parent fdea7fd commit 281c201
Show file tree
Hide file tree
Showing 2 changed files with 387 additions and 3 deletions.
30 changes: 27 additions & 3 deletions src/csnlp/core/solutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ def barrier_parameter(self) -> float:
"""Gets the IPOPT barrier parameter at the optimal solution"""
return self.stats["iterations"]["mu"][-1]

@property
def sensitivities(self) -> dict[str, cs.DM]:
"""The sensititivity information from the solution, if any."""

def value(
self, expr: Union[SymType, np.ndarray], eval: bool = True
) -> Union[SymType, cs.DM]:
Expand Down Expand Up @@ -359,7 +363,7 @@ def __init__(
dual_vars: dict[str, SymType],
dual_vals: dict[str, cs.DM],
stats: dict[str, _Any],
solver_plugin: str,
sensitivities: Optional[dict[str, cs.DM]] = None,
) -> None:
self._f = f

Expand All @@ -378,7 +382,9 @@ def __init__(
self._dual_vals = dual_vals

self._stats = stats
self._solver_plugin = solver_plugin
if sensitivities is None:
sensitivities = {}
self._sensitivities = sensitivities

@property
def f(self) -> float:
Expand Down Expand Up @@ -424,6 +430,10 @@ def vals(self) -> dict[str, cs.DM]:
def dual_vals(self) -> dict[str, cs.DM]:
return self._dual_vals

@property
def sensitivities(self) -> dict[str, cs.DM]:
return self._sensitivities

@staticmethod
def from_casadi_solution(
sol_with_stats: dict[str, _Any], nlp: "Nlp[SymType]"
Expand Down Expand Up @@ -459,6 +469,11 @@ def from_casadi_solution(
else:
raise RuntimeError(f"unknown dual variable type {n}")

sensitivities = {
k: sol[k]
for k in sol
if k.startswith("grad_") or k.startswith("jac_") or k.startswith("hess_")
}
return EagerSolution(
f,
nlp._p,
Expand All @@ -474,7 +489,7 @@ def from_casadi_solution(
dual_vars,
dual_vals,
stats,
nlp.unwrapped._solver_plugin,
sensitivities,
)


Expand Down Expand Up @@ -599,6 +614,15 @@ def dual_vals(self) -> dict[str, cs.DM]:
raise RuntimeError(f"unknown dual variable type `{n}`")
return dual_vals

@_cached_property
def sensitivities(self) -> dict[str, cs.DM]:
sol = self._sol
return {
k: sol[k]
for k in sol
if k.startswith("grad_") or k.startswith("jac_") or k.startswith("hess_")
}

@staticmethod
def from_casadi_solution(
sol_with_stats: dict[str, _Any], nlp: "Nlp[SymType]"
Expand Down
Loading

0 comments on commit 281c201

Please sign in to comment.