Callback refactor (ML4GW#163)
* refactor callbacks so they can be configured from the config file

* fix up configs
EthanMarx authored Oct 25, 2024
1 parent 4d9319c commit d1b1207
Showing 4 changed files with 34 additions and 18 deletions.
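For context: each entry under trainer.callbacks in the YAML diffs below names a callback class via class_path and its constructor arguments via init_args, which Lightning's CLI instantiates at startup. This replaces the callback setup previously hardcoded in BaseModel.configure_callbacks, so it can now be edited per-config. A rough sketch of that instantiation mechanism (illustrative only — the CLI actually goes through jsonargparse; PyYAML assumed available):

import importlib

import yaml

# Load one of the refactored configs and build its callbacks by hand,
# mirroring what LightningCLI does with class_path/init_args entries.
with open("amplfi/train/configs/flow/cbc.yaml") as f:
    config = yaml.safe_load(f)

callbacks = []
for spec in config["trainer"]["callbacks"]:
    module_name, class_name = spec["class_path"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    # Instantiate with the keyword arguments given in the config
    callbacks.append(cls(**spec.get("init_args", {})))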
13 changes: 12 additions & 1 deletion amplfi/train/configs/flow/cbc.yaml
@@ -16,6 +16,18 @@ trainer:
   check_val_every_n_epoch: 1
   log_every_n_steps: 20
   benchmark: false
+  callbacks:
+    - class_path: amplfi.train.callbacks.ModelCheckpoint
+      init_args:
+        monitor: "valid_loss"
+        save_top_k: 5
+        save_last: true
+        auto_insert_metric_name: false
+        mode: "min"
+    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+
 model:
   class_path: amplfi.train.models.flow.FlowModel
   init_args:
@@ -43,7 +55,6 @@ model:
         groups: 8
     patience: 10
     factor: 0.1
-    save_top_k_models: 10
     learning_rate: 3.7e-4
     weight_decay: 0.0
 data:
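The callbacks block above is the config-file counterpart of the configure_callbacks hook removed from base.py below. For reference, a hand-built equivalent — a sketch using the same values; Trainer(callbacks=...) is the standard Lightning entry point, and amplfi.train.callbacks.ModelCheckpoint is the repo's checkpoint subclass:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor

from amplfi.train.callbacks import ModelCheckpoint

# The same callbacks the YAML section above declares, constructed directly.
callbacks = [
    ModelCheckpoint(
        monitor="valid_loss",
        save_top_k=5,
        save_last=True,
        auto_insert_metric_name=False,
        mode="min",
    ),
    LearningRateMonitor(logging_interval="epoch"),
]
trainer = Trainer(callbacks=callbacks)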
12 changes: 11 additions & 1 deletion amplfi/train/configs/flow/sg.yaml
@@ -16,6 +16,17 @@ trainer:
   check_val_every_n_epoch: 1
   log_every_n_steps: 20
   benchmark: false
+  callbacks:
+    - class_path: amplfi.train.callbacks.ModelCheckpoint
+      init_args:
+        monitor: "valid_loss"
+        save_top_k: 5
+        save_last: true
+        auto_insert_metric_name: false
+        mode: "min"
+    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+      init_args:
+        logging_interval: epoch
 model:
   class_path: amplfi.train.models.flow.FlowModel
   init_args:
@@ -41,7 +52,6 @@ model:
         groups: 16
     patience: 10
     factor: 0.1
-    save_top_k_models: 10
     learning_rate: 3.77e-4
     weight_decay: 0.0
 data:
12 changes: 11 additions & 1 deletion amplfi/train/configs/similarity/cbc.yaml
@@ -16,6 +16,17 @@ trainer:
   check_val_every_n_epoch: 1
   log_every_n_steps: 20
   benchmark: false
+  callbacks:
+    - class_path: amplfi.train.callbacks.ModelCheckpoint
+      init_args:
+        monitor: "valid_loss"
+        save_top_k: 5
+        save_last: true
+        auto_insert_metric_name: false
+        mode: "min"
+    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
+      init_args:
+        logging_interval: epoch
 model:
   class_path: amplfi.train.models.SimilarityModel
   init_args:
@@ -32,7 +43,6 @@ model:
         groups: 16
     patience: 10
     factor: 0.1
-    save_top_k_models: 10
     learning_rate: 3.77e-4
     weight_decay: 0.0
 data:
15 changes: 0 additions & 15 deletions amplfi/train/models/base.py
@@ -5,11 +5,8 @@
 
 import lightning.pytorch as pl
 import torch
-from lightning.pytorch.callbacks import LearningRateMonitor
 from ml4gw.transforms import ChannelWiseScaler
 
-from ..callbacks import ModelCheckpoint
-
 Tensor = torch.Tensor
 Distribution = torch.distributions.Distribution
 
@@ -35,7 +32,6 @@ def __init__(
         outdir: Path,
         learning_rate: float,
         weight_decay: float = 0.0,
-        save_top_k_models: int = 10,
         patience: int = 10,
         factor: float = 0.1,
         checkpoint: Optional[Path] = None,
@@ -118,14 +114,3 @@ def configure_optimizers(self):
             "optimizer": optimizer,
             "lr_scheduler": {"scheduler": scheduler, "monitor": "valid_loss"},
         }
-
-    def configure_callbacks(self):
-        checkpoint = ModelCheckpoint(
-            monitor="valid_loss",
-            save_top_k=self.hparams.save_top_k_models,
-            save_last=True,
-            auto_insert_metric_name=False,
-            mode="min",
-        )
-        lr_monitor = LearningRateMonitor(logging_interval="epoch")
-        return [checkpoint, lr_monitor]
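With configure_callbacks removed, callback selection lives entirely in the config files, and the save_top_k_models hyperparameter is gone from the model signature; note the configs also tighten checkpoint retention from the old default of 10 to save_top_k: 5. A practical consequence is that callbacks can now be adjusted without code changes — for example, Lightning's CLI supports appending one from the command line (syntax per the Lightning CLI docs): --trainer.callbacks+=EarlyStopping --trainer.callbacks.patience=5.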
