
Commit

ability to control learning rate during training
Added a reset_learning_rate() method to the Trainer class, intended to allow training scripts, such as a customized train_prod, to control the learning rate at particular epochs or folds (see the usage sketch after the diff).
alanngnet committed Jul 24, 2024
1 parent e2efff7 commit 3939a7d
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/trainer.py
@@ -191,6 +191,29 @@ def load_model(self, advanced=False):
advanced=advanced,
)

def reset_learning_rate(self, new_lr=None):
    """
    Utility to allow custom control of the learning rate during training,
    after loading a model from a checkpoint that would otherwise inherit
    the learning rate from the previous epoch via the checkpoint.

    Args:
    ----
    new_lr : float, optional
        Learning rate to apply to every optimizer parameter group. If
        None, falls back to the hyperparameter self.hp["learning_rate"].
        The default is None.

    Returns:
    ----
    None.
    """
    if new_lr is None:
        new_lr = self.hp["learning_rate"]
    for param_group in self.optimizer.param_groups:
        param_group["lr"] = new_lr
    # Recreate the scheduler so it starts from the new learning rate
    self.configure_scheduler()

def save_model(self):
    """
    Save the current model to checkpoint_dir.
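As a usage illustration (not part of this commit), here is a minimal sketch of how a custom train_prod-style script might call the new method at fold boundaries. The fold loop, num_folds, and trainer.train(fold) are assumed names for illustration only; load_model() and reset_learning_rate() are the Trainer methods shown in the diff above.

# Hypothetical sketch of a custom training script.
# Assumes `trainer` is an already-configured Trainer instance;
# num_folds and trainer.train(fold) are assumed, not part of this commit.
num_folds = 5
for fold in range(num_folds):
    trainer.load_model()           # checkpoint restores the previous optimizer LR
    trainer.reset_learning_rate()  # revert to hp["learning_rate"] for this fold
    # or pin an explicit value instead:
    # trainer.reset_learning_rate(new_lr=1e-4)
    trainer.train(fold)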
