-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
33 lines (26 loc) · 963 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from data import FlowersDataModule
from model import Resnet50Model
def main():
    """Fit ``Resnet50Model`` on ``FlowersDataModule`` with a Lightning trainer.

    The run is configured entirely by the constants below:

    * checkpoints are written under ``./models``, selected by the lowest
      ``val_loss``;
    * early stopping also monitors ``val_loss`` (``mode="min"``) with a
      patience of 3 and verbose reporting;
    * training is capped at 5 epochs and uses the ``"gpu"`` accelerator when
      ``torch.cuda.is_available()`` is true, ``"cpu"`` otherwise;
    * metrics go to a TensorBoard logger rooted at ``logs/`` with name
      ``resnet`` and a fixed version of 1.
    """
    datamodule = FlowersDataModule()
    model = Resnet50Model()

    # Both callbacks watch the same validation metric; lower is better.
    save_best = ModelCheckpoint(
        dirpath="./models",
        monitor="val_loss",
        mode="min",
    )
    stop_early = EarlyStopping(
        monitor="val_loss",
        patience=3,
        verbose=True,
        mode="min",
    )

    # Pick the accelerator explicitly instead of relying on auto-detection.
    if torch.cuda.is_available():
        accelerator = "gpu"
    else:
        accelerator = "cpu"

    tensorboard = pl.loggers.TensorBoardLogger(
        "logs/", name="resnet", version=1
    )

    trainer_options = {
        "default_root_dir": "logs",
        "accelerator": accelerator,
        "max_epochs": 5,
        "fast_dev_run": False,
        "logger": tensorboard,
        "callbacks": [save_best, stop_early],
    }
    trainer = pl.Trainer(**trainer_options)
    trainer.fit(model, datamodule)
# Start training only when this file is run as a script, so importing the
# module (e.g. to reuse ``main``) has no side effects.
if __name__ == "__main__":
    main()