-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
36 lines (33 loc) · 1.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from models.classifiers import *
from datasets.CIFAR import *
from datasets.LSUN import *
from datasets.SVHN import *
from utils.args import *
# Root directory under which one checkpoint folder per "<dataset>_<model>"
# combination is created.
CKPT_ROOT = './workspace/model_ckpts'


def checkpoint_dir_for(model_name, root=CKPT_ROOT):
    """Return the checkpoint directory for *model_name*, ending in a separator.

    The trailing separator is kept on purpose: the directory string is passed
    to ``ModelCheckpoint(filepath=...)``, and keeping it identical to the
    previous ``root + '/' + model_name + '/'`` form avoids any change in how
    that argument is interpreted.

    Args:
        model_name: Combined name such as ``'CIFAR10_VGG'``.
        root: Parent directory for all checkpoint folders.

    Returns:
        ``os.path.join(root, model_name, '')``.
    """
    return os.path.join(root, model_name, '')


def build_datamodule(dataset):
    """Return ``(datamodule, max_epochs)`` for a supported dataset name.

    Args:
        dataset: ``'CIFAR10'`` or ``'CIFAR100'``.

    Returns:
        A freshly constructed datamodule and the epoch budget for it.

    Raises:
        ValueError: if *dataset* has no configuration. This keeps a newly
            listed dataset from silently reusing the previous iteration's
            datamodule and epoch count.
    """
    if dataset == 'CIFAR10':
        return CIFAR10DataModule(), 60
    if dataset == 'CIFAR100':
        return CIFAR100DataModule(), 180
    raise ValueError(f'no datamodule configured for dataset {dataset!r}')


def main():
    """Train (or reload) and then test every dataset/architecture combination.

    For each combination, if ``<ckpt dir>/final.ckpt`` already exists the
    model is loaded from it and training is skipped. Otherwise the model is
    trained and that checkpoint is written. In both cases the model is then
    evaluated with ``trainer.test``.
    """
    dataset_names = ['CIFAR10', 'CIFAR100']
    model_families = ['VGG', 'Resnet', 'WideResnet', 'Densenet_BC', 'Densenet']

    for dataset in dataset_names:
        dm, max_epochs = build_datamodule(dataset)
        for family in model_families:
            model_name = dataset + '_' + family
            # Model classes are looked up by name, e.g. CIFAR10_VGG; presumably
            # they come from the `models.classifiers` star import - confirm.
            model = globals()[model_name]()

            modelpath = checkpoint_dir_for(model_name)
            os.makedirs(modelpath, exist_ok=True)
            final_ckpt = os.path.join(modelpath, 'final.ckpt')

            # NOTE(review): `filepath=` and passing a ModelCheckpoint via
            # `checkpoint_callback=` match older PyTorch Lightning releases;
            # newer ones expect `dirpath=` and `callbacks=[...]`. Confirm the
            # pinned version before changing.
            checkpoint_callback = ModelCheckpoint(filepath=modelpath)
            trainer = Trainer(
                checkpoint_callback=checkpoint_callback,
                gpus=1,
                num_nodes=1,
                max_epochs=max_epochs,
            )

            if os.path.isfile(final_ckpt):
                # load_from_checkpoint builds a new model from the file;
                # the instance created above is discarded.
                model = model.load_from_checkpoint(checkpoint_path=final_ckpt)
            else:
                trainer.fit(model, dm)
                trainer.save_checkpoint(final_ckpt)
            trainer.test(model, datamodule=dm)


if __name__ == '__main__':
    main()