Skip to content

Commit

Permalink
[Bug] Fix LTune utility functions (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
ephoris authored May 10, 2024
1 parent e057d12 commit d3eec0a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 15 deletions.
27 changes: 13 additions & 14 deletions endure/ltune/util/ltune_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from endure.lcm.util import eval_lcm_impl
from endure.lsm.cost import EndureCost
from endure.lsm.types import LSMBounds, LSMDesign, System, Policy, STR_POLICY_DICT
from endure.lsm.types import LSMBounds, LSMDesign, System, Policy
from endure.ltune.data.generator import LTuneDataGenerator
from endure.ltune.loss import LearnedCostModelLoss
import endure.lsm.solver as Solver
Expand All @@ -17,9 +17,8 @@ def __init__(
self,
config: dict[str, Any],
model: torch.nn.Module,
design_type: str = "Level",
design_type: Policy,
) -> None:
self.policy = STR_POLICY_DICT.get(design_type, Policy.KHybrid)
self.bounds = LSMBounds()
self.gen = LTuneDataGenerator(self.bounds)
self.loss = LearnedCostModelLoss(
Expand Down Expand Up @@ -93,14 +92,14 @@ def get_solver_nominal_design(
w: float,
**kwargs,
) -> Tuple[LSMDesign, SciOpt.OptimizeResult]:
if self.design_type == "QLSM":
solver = Solver.QLSMSolver(self.config)
elif self.design_type == "KLSM":
solver = Solver.KLSMSolver(self.config)
elif self.design_type == "YZLSM":
solver = Solver.YZLSMSolver(self.config)
else: # design_type == "Classic"
solver = Solver.ClassicSolver(self.config)
if self.design_type == Policy.QFixed:
solver = Solver.QLSMSolver(self.bounds)
elif self.design_type == Policy.KHybrid:
solver = Solver.KLSMSolver(self.bounds)
elif self.design_type == Policy.YZHybrid:
solver = Solver.YZLSMSolver(self.bounds)
else: # design_type == Policy.Classic
solver = Solver.ClassicSolver(self.bounds)

design, sol = solver.get_nominal_design(
system,
Expand All @@ -114,11 +113,11 @@ def get_solver_nominal_design(
return design, sol

def convert_ltune_output(self, output: Tensor):
if self.design_type == "QLSM":
if self.design_type == Policy.QFixed:
design = self._qlsm_convert(output)
elif self.design_type == "KLSM":
elif self.design_type == Policy.KHybrid:
design = self._klsm_convert(output)
else:
else: # self.design_type == Policy.Classic
design = self._classic_convert(output)

return design
Expand Down

0 comments on commit d3eec0a

Please sign in to comment.