-
Notifications
You must be signed in to change notification settings - Fork 6
/
main.py
105 lines (77 loc) · 3.48 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os, sys, time
import numpy as np
import torch
import torch.nn as nn
import torchvision
PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, PATH + '/../..')
from option import get_args
from learning.trainer import Trainer
from learning.evaluator import Evaluator
from utils import get_model, make_optimizer, make_scheduler, make_dataloader, plot_learning_curves
def main():
    """Train and/or evaluate a GPU-based image-classification model.

    All configuration comes from command-line arguments (``get_args``).
    With ``--evaluate``, a saved checkpoint is loaded and tested; otherwise
    the full training loop runs, checkpointing whenever validation accuracy
    improves, then re-testing with both the final and the best checkpoint.
    Per-epoch train/validation wall-clock times are written as CSV files
    next to the model checkpoint.

    Raises:
        ValueError: if no CUDA device is available (CPU runs unsupported).
    """
    args = get_args()
    torch.manual_seed(args.seed)
    # NOTE(review): numpy / python `random` seeds are not set here, so data
    # augmentation randomness may still vary between runs — confirm intended.

    # Input resolution is hard-coded to 224x224 RGB, (H, W, C).
    shape = (224, 224, 3)

    # Data loaders.
    train_loader, valid_loader, test_loader = make_dataloader(args)

    # Model. GPU-only: AMP and the .cuda() calls below assume CUDA exists.
    model = get_model(args, shape, args.num_classes)
    if torch.cuda.device_count() >= 1:
        print('Model pushed to {} GPU(s), type {}.'.format(
            torch.cuda.device_count(), torch.cuda.get_device_name(0)))
        model = model.cuda()
    else:
        raise ValueError('CPU training is not supported')

    # Loss criterion, optimizer, LR scheduler.
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = make_optimizer(args, model)
    scheduler = make_scheduler(args, optimizer)

    # Loss scaler for automatic mixed precision.
    # NOTE(review): torch.cuda.amp.GradScaler is deprecated in recent torch
    # releases in favor of torch.amp.GradScaler('cuda'); kept as-is for
    # compatibility with older torch versions — confirm target version.
    scaler = torch.cuda.amp.GradScaler()

    # Shared result bookkeeping for trainer/evaluator.
    result_dict = {
        'args': vars(args),
        'epoch': [],
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'test_acc': [],
    }
    trainer = Trainer(model, criterion, optimizer, scheduler, scaler)
    evaluator = Evaluator(model, criterion)

    train_time_list = []
    valid_time_list = []

    if args.evaluate:
        # Evaluation-only mode: load checkpoint and report test accuracy.
        model.load()
        result_dict = evaluator.test(test_loader, args, result_dict)
    else:
        evaluator.save(result_dict)
        best_val_acc = 0.0
        for epoch in range(args.epochs):
            # NOTE(review): 'epoch' was initialised as a list above but is
            # overwritten with an int here; preserved as-is in case
            # downstream code relies on it — confirm intent.
            result_dict['epoch'] = epoch

            # Time the training pass; synchronize so queued CUDA work is
            # included in the measurement.
            torch.cuda.synchronize()
            tic1 = time.time()
            result_dict = trainer.train(train_loader, epoch, args, result_dict)
            torch.cuda.synchronize()
            tic2 = time.time()
            train_time_list.append(tic2 - tic1)

            # Time the validation pass.
            torch.cuda.synchronize()
            tic3 = time.time()
            result_dict = evaluator.evaluate(valid_loader, epoch, args, result_dict)
            torch.cuda.synchronize()
            tic4 = time.time()
            valid_time_list.append(tic4 - tic3)

            # Checkpoint whenever validation accuracy improves.
            if result_dict['val_acc'][-1] > best_val_acc:
                print("{} epoch, best epoch was updated! {}%".format(
                    epoch, result_dict['val_acc'][-1]))
                best_val_acc = result_dict['val_acc'][-1]
                model.save(checkpoint_name='best_model')

            evaluator.save(result_dict)
            plot_learning_curves(result_dict, epoch, args)

        # Test accuracy with the final-epoch weights...
        result_dict = evaluator.test(test_loader, args, result_dict)
        evaluator.save(result_dict)

        # ...then with the best-validation checkpoint.
        model.load(checkpoint_name='best_model')
        result_dict = evaluator.test(test_loader, args, result_dict)
        evaluator.save(result_dict)
        print(result_dict)

        # Persist per-epoch timing data alongside the checkpoint.
        out_dir = os.path.join(model.checkpoint_dir, model.checkpoint_name)
        np.savetxt(os.path.join(out_dir, 'train_time_amp.csv'),
                   train_time_list, delimiter=',', fmt='%s')
        np.savetxt(os.path.join(out_dir, 'valid_time_amp.csv'),
                   valid_time_list, delimiter=',', fmt='%s')
# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()