diff --git a/endure/ltune/util/ltune_eval.py b/endure/ltune/util/ltune_eval.py index 1c6d9b0..7d343a1 100644 --- a/endure/ltune/util/ltune_eval.py +++ b/endure/ltune/util/ltune_eval.py @@ -6,7 +6,7 @@ from endure.lcm.util import eval_lcm_impl from endure.lsm.cost import EndureCost -from endure.lsm.types import LSMBounds, LSMDesign, System, Policy, STR_POLICY_DICT +from endure.lsm.types import LSMBounds, LSMDesign, System, Policy from endure.ltune.data.generator import LTuneDataGenerator from endure.ltune.loss import LearnedCostModelLoss import endure.lsm.solver as Solver @@ -17,9 +17,8 @@ def __init__( self, config: dict[str, Any], model: torch.nn.Module, - design_type: str = "Level", + design_type: Policy, ) -> None: - self.policy = STR_POLICY_DICT.get(design_type, Policy.KHybrid) self.bounds = LSMBounds() self.gen = LTuneDataGenerator(self.bounds) self.loss = LearnedCostModelLoss( @@ -93,14 +92,14 @@ def get_solver_nominal_design( w: float, **kwargs, ) -> Tuple[LSMDesign, SciOpt.OptimizeResult]: - if self.design_type == "QLSM": - solver = Solver.QLSMSolver(self.config) - elif self.design_type == "KLSM": - solver = Solver.KLSMSolver(self.config) - elif self.design_type == "YZLSM": - solver = Solver.YZLSMSolver(self.config) - else: # design_type == "Classic" - solver = Solver.ClassicSolver(self.config) + if self.design_type == Policy.QFixed: + solver = Solver.QLSMSolver(self.bounds) + elif self.design_type == Policy.KHybrid: + solver = Solver.KLSMSolver(self.bounds) + elif self.design_type == Policy.YZHybrid: + solver = Solver.YZLSMSolver(self.bounds) + else: # design_type == Policy.Classic + solver = Solver.ClassicSolver(self.bounds) design, sol = solver.get_nominal_design( system, @@ -114,11 +113,11 @@ def get_solver_nominal_design( return design, sol def convert_ltune_output(self, output: Tensor): - if self.design_type == "QLSM": + if self.design_type == Policy.QFixed: design = self._qlsm_convert(output) - elif self.design_type == "KLSM": + elif self.design_type == Policy.KHybrid: design = self._klsm_convert(output) - else: + else: # self.design_type == Policy.Classic design = self._classic_convert(output) return design diff --git a/notebook b/notebook index 1bd4688..67e0c90 160000 --- a/notebook +++ b/notebook @@ -1 +1 @@ -Subproject commit 1bd46884804d9403597a046eb751a190bc5eb5e4 +Subproject commit 67e0c908c5b0d29da357e26cbab520dd90583e21