Skip to content

Commit

Permalink
[Feat] Change callback functions to return state
Browse files Browse the repository at this point in the history
  • Loading branch information
ephoris committed Aug 29, 2024
1 parent d251916 commit 7de96cb
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 23 deletions.
4 changes: 2 additions & 2 deletions endure.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, config: dict[str, Any]) -> None:
format=config["log"]["format"], datefmt=config["log"]["datefmt"]
)
self.log: logging.Logger = logging.getLogger(config["log"]["name"])
self.log.setLevel(logging.getLevelName(config["log"]["level"]))
self.log.setLevel(getattr(logging, config["log"]["level"]))
log_level = logging.getLevelName(self.log.getEffectiveLevel())
self.log.debug(f"Log level: {log_level}")

Expand All @@ -40,7 +40,7 @@ def run(self):
for job_name in jobs_list:
job = jobs.get(job_name, None)
if job is None:
self.log.warn(f"No job associated with {job_name}")
self.log.warning(f"No job associated with {job_name}")
continue
job = job(config)
_ = job.run()
Expand Down
4 changes: 2 additions & 2 deletions endure/util/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
model_test_kwargs: dict[str, Any] = {},
disable_tqdm: bool = False,
no_checkpoint: bool = False,
train_callback: Optional[Callable[[dict], None]] = None,
train_callback: Optional[Callable[[dict], dict]] = None,
) -> None:
self.log = log
self.model = model
Expand Down Expand Up @@ -95,7 +95,7 @@ def _train_loop(self) -> float:
self.scheduler.step()

if self.train_callback is not None:
self.train_callback(self.model_train_kwargs)
self.model_train_kwargs = self.train_callback(self.model_train_kwargs)

if self.train_len == 0:
self.train_len = batch + 1
Expand Down
12 changes: 5 additions & 7 deletions jobs/ltune_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,28 +120,26 @@ def gumbel_temp_schedule(
train_kwargs: dict,
decay_rate: float = 0.95,
floor: float = 0.01,
) -> None:
) -> dict:
train_kwargs["temp"] *= decay_rate
if train_kwargs["temp"] < floor:
train_kwargs["temp"] = floor

return
return train_kwargs

@staticmethod
def reinmax_temp_schedule(
train_kwargs: dict,
decay_rate: float = 0.9,
floor: float = 1,
) -> None:
) -> dict:
train_kwargs["temp"] *= decay_rate
if train_kwargs["temp"] < floor:
train_kwargs["temp"] = floor

return
return train_kwargs

def get_train_callback(self) -> Optional[Callable[[dict], None]]:
if not self.design == Policy.KHybrid:
return None
def get_train_callback(self) -> Optional[Callable[[dict], dict]]:
if self.config["ltune"]["model"]["categorical_mode"] == "reinmax":
return lambda train_kwargs: self.reinmax_temp_schedule(train_kwargs)
# default train_callback will be gumbel softmax
Expand Down
36 changes: 24 additions & 12 deletions jobs/mlos_exp_runs.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from endure.lsm.types import LSMBounds, LSMDesign, Policy, System, Workload
from mlos_core.optimizers import SmacOptimizer

NUM_SAMPLES = 100
NUM_ROUNDS = 20
NUM_ROUNDS = 100
NUM_TRIALS = 10


class ExperimentMLOS:
Expand Down Expand Up @@ -64,7 +64,8 @@ def _create_optimizer(self, parameter_space: CS.ConfigurationSpace):

def _train_model(
self,
workload_id: int,
wl_id: int,
trial: int,
workload: Workload,
system: System,
num_rounds: int = NUM_ROUNDS,
Expand All @@ -81,20 +82,26 @@ def _train_model(
optimizer.register(
configs=suggestion, scores=pd.DataFrame([{"cost": cost}])
)
self.log.info(f"Round {round}: Cost: {cost}")
self.db.log_round(workload_id, round, design, cost)
self.log.info(f"[ID {wl_id}][Trial {trial}][Round {round}] Cost: {cost}")
self.db.log_round(wl_id, trial, round, design, cost)

return

def run(self) -> None:
for _ in range(NUM_SAMPLES):
workload = Workload(*self.gen._sample_workload(4))
system = self.gen._sample_system()
system = System()
for rep_wl in self.config["workloads"]:
workload = Workload(
z0=rep_wl["z0"],
z1=rep_wl["z1"],
q=rep_wl["q"],
w=rep_wl["w"],
)
row_id = self.db.log_workload(workload, system)
self.log.info(f"Workload: {workload}")
self.log.info(f"System: {system}")
self.log.info(f"Environment ID: {row_id}")
self._train_model(row_id, workload, system)
for trial in range(NUM_TRIALS):
self.log.info(f"(Workload ID, Trial): ({row_id}, {trial})")
self._train_model(row_id, trial, workload, system)

return

Expand Down Expand Up @@ -127,6 +134,7 @@ def __init__(self, config: dict, db_path: str = "mlos_exp.db") -> None:
CREATE TABLE IF NOT EXISTS tunings (
idx INTEGER PRIMARY KEY AUTOINCREMENT,
env_id INTEGER,
trial INTEGER,
round INTEGER,
bits_per_elem REAL,
size_ratio INTEGER,
Expand Down Expand Up @@ -182,6 +190,7 @@ def log_workload(self, workload: Workload, system: System) -> int:
def log_round(
self,
workload_id: int,
trial: int,
round: int,
design: LSMDesign,
cost: float,
Expand All @@ -191,6 +200,7 @@ def log_round(
"""
INSERT INTO tunings (
env_id,
trial,
round,
bits_per_elem,
size_ratio,
Expand All @@ -200,9 +210,11 @@ def log_round(
kap15, kap16, kap17, kap18, kap19,
cost
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?,
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(workload_id, round, design.h, int(design.T)) + tuple(design.K) + (cost,),
(workload_id, trial, round, design.h, int(design.T))
+ tuple(design.K)
+ (cost,),
)
self.connector.commit()

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ toml
torch
torchdata
tqdm
mlos

0 comments on commit 7de96cb

Please sign in to comment.