diff --git a/src/trainer.py b/src/trainer.py
index fcc0bad..e49bb82 100755
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -191,6 +191,29 @@ def load_model(self, advanced=False):
             advanced=advanced,
         )
 
+    def reset_learning_rate(self, new_lr=None):
+        """
+        Reset the learning rate during training, e.g. after loading a
+        model from a checkpoint that would otherwise inherit the
+        previous epoch's learning rate from the checkpoint.
+
+        Args:
+        ----
+        new_lr : float, optional
+            New learning rate; defaults to self.hp["learning_rate"].
+
+        Returns:
+        ----
+        None.
+
+        """
+        if new_lr is None:
+            new_lr = self.hp["learning_rate"]
+        for param_group in self.optimizer.param_groups:
+            param_group["lr"] = new_lr
+        # Recreate the scheduler so it tracks the new learning rate
+        self.configure_scheduler()
+
     def save_model(self):
         """
         Save the current model to checkpoint_dir.
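
For context, a minimal usage sketch of the new method. The Trainer constructor
signature, hyperparameter dict, and learning-rate values below are hypothetical,
inferred from the attributes the hunk references (self.hp, self.optimizer,
self.configure_scheduler); only load_model() and reset_learning_rate() come
from this diff:

    # Hypothetical usage sketch; constructor args and lr values are
    # assumptions, not part of this diff.
    from src.trainer import Trainer

    trainer = Trainer(hp={"learning_rate": 1e-4})
    trainer.load_model()  # restores optimizer state, including the stale lr

    # Override the inherited learning rate before resuming training:
    trainer.reset_learning_rate(3e-5)  # explicit value
    trainer.reset_learning_rate()      # falls back to self.hp["learning_rate"]

Assuming configure_scheduler() builds a torch.optim.lr_scheduler instance,
recreating it is the important step: PyTorch schedulers snapshot each param
group's lr into base_lrs at construction, so editing param_group["lr"] alone
would be overwritten on the next scheduler.step().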