# train.py
import torch
import models
import losses
from data import DataManager
import argparse
class Trainer:
    """Run a simple supervised training loop over a data manager's train loader.

    Args:
        model: Module that is called on the image batch to produce logits.
        dm: Data manager; only ``dm.train_loader()`` is used here.
        criterion: Callable invoked as ``criterion(logits, pids)`` that returns
            the loss tensor to back-propagate.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        use_gpu: When true, every batch is moved to CUDA before the forward pass.
        args: Parsed options; ``max_epoch`` and ``print_freq`` are read from it
            (see ``add_model_specific_args``).
    """

    def __init__(self, model, dm, criterion, optimizer, use_gpu, args: argparse.Namespace):
        self.args = args
        self.model = model
        self.train_loader = dm.train_loader()
        self.criterion, self.optimizer = criterion, optimizer
        self.use_gpu = use_gpu

    def do_train(self) -> None:
        """Train for ``args.max_epoch`` epochs, then save the model weights."""
        # Read the options from the instance rather than from a module-level
        # ``args`` global, so the class also works when imported elsewhere.
        for epoch in range(self.args.max_epoch):
            self.train_epoch(epoch=epoch)
        torch.save(self.model.state_dict(), './pretrained_model.pth')

    def train_epoch(self, epoch) -> None:
        """Run one pass over the train loader, printing running averages.

        Loss and accuracy are averaged over windows of ``args.print_freq``
        batches; a trailing partial window is not reported.
        """
        print_freq = self.args.print_freq
        self.model.train()
        running_loss = 0.0
        running_acc = 0.0
        for batch_idx, data in enumerate(self.train_loader):
            loss, acc = self.forward_backward(data)
            running_loss += loss
            running_acc += acc
            if (batch_idx + 1) % print_freq == 0:
                print('[%d, %5d]\t loss: %.3f\t accuracy: %.3f'
                      % (epoch + 1, batch_idx + 1, running_loss / print_freq, running_acc / print_freq))
                # Start a fresh averaging window.
                running_loss = 0.0
                running_acc = 0.0

    def forward_backward(self, data):
        """Perform one optimization step on a batch.

        Args:
            data: Batch accepted by ``parse_data_for_train``.

        Returns:
            Tuple ``(loss, accuracy)`` of Python numbers, where accuracy is the
            percentage of samples in the batch whose arg-max prediction along
            dim 1 equals the label.
        """
        imgs, pids = self.parse_data_for_train(data)
        if self.use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()
        logits = self.model(imgs)  # presumably (batch, num_classes) - TODO confirm
        loss = self.criterion(logits, pids)
        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        self.optimizer.step()
        # Accuracy is measured on the logits computed before the weight update.
        _, prediction = torch.max(logits, dim=1)
        acc = (prediction == pids).sum() * 100 / len(pids)
        return loss.item(), acc.item()

    def parse_data_for_train(self, data):
        """Extract ``(imgs, pids)`` from a batch via its ``'img'`` and ``'pid'`` keys."""
        imgs = data['img']
        pids = data['pid']
        return imgs, pids

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register the trainer options on ``parent_parser`` and return it.

        Adds ``--max_epoch`` (int, default 20) and ``--print_freq`` (int,
        default 10) in an argument group named "Trainer".
        """
        parser = parent_parser.add_argument_group("Trainer")
        parser.add_argument("--max_epoch", type=int, default=20)
        parser.add_argument("--print_freq", type=int, default=10)
        return parent_parser
if __name__ == '__main__':

    def _str2bool(value):
        """Convert a command-line string such as ``true``/``false`` to a bool."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    # Collect the options contributed by the data manager and the trainer.
    parser = argparse.ArgumentParser()
    parser = DataManager.add_model_specific_args(parser)
    parser = Trainer.add_model_specific_args(parser)
    # ``type=bool`` would treat every non-empty string (including "False") as
    # True, so parse the value explicitly while keeping ``--use_gpu <value>``.
    parser.add_argument("--use_gpu", type=_str2bool, default=False)
    args = parser.parse_args()

    dm = DataManager(args)
    model = models.OSNet(num_classes=dm.num_train_classes)
    if args.use_gpu:
        model = model.cuda()

    criterion = losses.CrossEntropyLoss(num_classes=dm.num_train_classes, use_gpu=args.use_gpu)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

    trainer = Trainer(model, dm, criterion, optimizer, args.use_gpu, args)
    trainer.do_train()