-
Notifications
You must be signed in to change notification settings - Fork 20
/
trainer.py
276 lines (244 loc) · 10.5 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import torch
import torchvision
from torchvision.models import vgg16
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import Compose, CenterCrop, Normalize, Scale, Resize, ToTensor, ToPILImage
from torch.optim.lr_scheduler import LambdaLR, StepLR
import numpy as np
import glob
import PIL.Image as Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import json
import pickle
from dataset import BatchData
from model import PreResNet, BiasLayer
from cifar import Cifar100
from exemplar import Exemplar
from copy import deepcopy
class Trainer:
    """Class-incremental trainer with per-increment bias-correction layers.

    The schedule appears to follow BiC (bias correction for incremental
    learning). For each of ``self.dataset.batch_num`` increments:

    * stage 1 trains the backbone on the new classes plus the exemplar memory,
      with knowledge distillation from the previous increment's model once
      ``inc_i > 0``;
    * stage 2 (only when ``inc_i > 0``) trains just the current increment's
      ``BiasLayer`` on the exemplar validation split while the backbone is
      kept in eval mode.
    """

    def __init__(self, total_cls):
        """Create the dataset, backbone, bias layers and image transforms.

        Args:
            total_cls: number of output logits of the backbone. ``bias_forward``
                splits the logits evenly over the five bias layers, so this is
                expected to be a multiple of 5 (100 for CIFAR-100).
        """
        self.total_cls = total_cls
        self.seen_cls = 0
        self.dataset = Cifar100()
        self.model = PreResNet(32, total_cls).cuda()
        print(self.model)
        # NOTE(review): hard-wired to GPUs 0 and 1; likely fails on machines
        # with fewer visible CUDA devices.
        self.model = nn.DataParallel(self.model, device_ids=[0, 1])
        # One bias-correction layer per group of classes (five groups).
        self.bias_layer1 = BiasLayer().cuda()
        self.bias_layer2 = BiasLayer().cuda()
        self.bias_layer3 = BiasLayer().cuda()
        self.bias_layer4 = BiasLayer().cuda()
        self.bias_layer5 = BiasLayer().cuda()
        self.bias_layers = [self.bias_layer1, self.bias_layer2, self.bias_layer3,
                            self.bias_layer4, self.bias_layer5]
        # Distillation teacher: a snapshot of ``self.model`` that ``train``
        # takes at the end of every increment; unset until then.
        self.previous_model = None
        # Training-time augmentation (random flip + padded random crop)
        # followed by per-channel normalization.
        self.input_transform = Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            ToTensor(),
            Normalize([0.5071, 0.4866, 0.4409], [0.2673, 0.2564, 0.2762])])
        self.input_transform_eval = Compose([
            ToTensor(),
            Normalize([0.5071, 0.4866, 0.4409], [0.2673, 0.2564, 0.2762])])
        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print("Solver total trainable parameters : ", total_params)

    def test(self, testdata):
        """Return top-1 accuracy (a fraction in [0, 1]) over the seen classes.

        Logits go through ``bias_forward`` and predictions are restricted to
        the first ``self.seen_cls`` classes. Returns 0.0 for an empty loader.
        The model is left in train mode afterwards.
        """
        print("test data number : ", len(testdata))
        self.model.eval()
        correct = 0
        total = 0
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for image, label in testdata:
                image = image.cuda()
                label = label.view(-1).cuda()
                p = self.bias_forward(self.model(image))
                pred = p[:, :self.seen_cls].argmax(dim=-1)
                correct += (pred == label).sum().item()
                total += label.numel()
        acc = correct / total if total else 0.0
        print("Test Acc: {}".format(acc * 100))
        self.model.train()
        print("---------------------------------------------")
        return acc

    def eval(self, criterion, evaldata):
        """Print mean loss and accuracy on ``evaldata``, restricted to seen classes.

        The loss is computed on the first ``self.seen_cls`` logits, matching
        how the training losses are computed. The model is left in train mode.
        """
        self.model.eval()
        losses = []
        correct = 0
        total = 0
        with torch.no_grad():
            for image, label in evaldata:
                image = image.cuda()
                label = label.view(-1).cuda()
                p = self.bias_forward(self.model(image))[:, :self.seen_cls]
                losses.append(criterion(p, label).item())
                pred = p.argmax(dim=-1)
                correct += (pred == label).sum().item()
                total += label.numel()
        print("Validation Loss: {}".format(np.mean(losses)))
        print("Validation Acc: {}".format(100 * correct / total if total else 0.0))
        self.model.train()
        return

    def get_lr(self, optimizer):
        """Return the learning rate of the first parameter group, or None if there is none."""
        for param_group in optimizer.param_groups:
            return param_group['lr']
        return None

    def train(self, batch_size, epoches, lr, max_size):
        """Run the full class-incremental schedule and print per-increment accuracy.

        For each increment ``inc_i`` of ``self.dataset``:
          1. build the stage-1 training set from the exemplar memory plus the
             new classes' train and val samples;
          2. train the backbone for ``epoches`` epochs (with distillation from
             ``self.previous_model`` when ``inc_i > 0``);
          3. when ``inc_i > 0``, train only ``self.bias_layers[inc_i]`` for
             ``epoches`` epochs on the exemplar validation split;
          4. snapshot the model as the teacher for the next increment.

        The reported per-increment figure is the best test accuracy among the
        ones recorded in that increment.

        Args:
            batch_size: mini-batch size of the training and test loaders.
            epoches: number of epochs for stage 1 and, separately, stage 2.
            lr: initial SGD learning rate for the backbone.
            max_size: exemplar memory budget passed to ``Exemplar``.
        """
        total_cls = self.total_cls
        criterion = nn.CrossEntropyLoss()
        exemplar = Exemplar(max_size, total_cls)
        dataset = self.dataset
        # The test set is cumulative: each increment is evaluated on every
        # class seen so far.
        test_xs = []
        test_ys = []
        test_accs = []
        for inc_i in range(dataset.batch_num):
            print(f"Incremental num : {inc_i}")
            train, val, test = dataset.getNextClasses(inc_i)
            print(len(train), len(val), len(test))
            train_x, train_y = zip(*train)
            val_x, val_y = zip(*val)
            test_x, test_y = zip(*test)
            test_xs.extend(test_x)
            test_ys.extend(test_y)
            # Build fresh lists so the exemplar's own storage can never be
            # mutated here (harmless if it already returns new lists).
            # NOTE(review): val_x is both trained on in stage 1 and handed to
            # the exemplar as its validation store, which presumably feeds the
            # stage-2 bias data; confirm this overlap is intended.
            ex_xs, ex_ys = exemplar.get_exemplar_train()
            train_xs = list(ex_xs) + list(train_x) + list(val_x)
            train_ys = list(ex_ys) + list(train_y) + list(val_y)
            train_data = DataLoader(BatchData(train_xs, train_ys, self.input_transform),
                                    batch_size=batch_size, shuffle=True, drop_last=True)
            test_data = DataLoader(BatchData(test_xs, test_ys, self.input_transform_eval),
                                   batch_size=batch_size, shuffle=False)
            optimizer = optim.SGD(self.model.parameters(), lr=lr, momentum=0.9, weight_decay=2e-4)
            scheduler = StepLR(optimizer, step_size=70, gamma=0.1)
            # Only the bias layer of the current increment is trained in stage 2.
            bias_optimizer = optim.Adam(self.bias_layers[inc_i].parameters(), lr=0.001)
            exemplar.update(total_cls // dataset.batch_num, (train_x, train_y), (val_x, val_y))
            self.seen_cls = exemplar.get_cur_cls()
            print("seen cls number : ", self.seen_cls)
            val_xs, val_ys = exemplar.get_exemplar_val()
            val_bias_data = DataLoader(BatchData(val_xs, val_ys, self.input_transform),
                                       batch_size=100, shuffle=True, drop_last=False)
            test_acc = []
            self._train_backbone(inc_i, epoches, train_data, test_data,
                                 criterion, optimizer, scheduler)
            if inc_i > 0:
                test_acc.extend(self._train_bias(epoches, val_bias_data, test_data,
                                                 criterion, bias_optimizer))
            for i, layer in enumerate(self.bias_layers):
                layer.printParam(i)
            self.previous_model = deepcopy(self.model)
            test_acc.append(self.test(test_data))
            test_accs.append(max(test_acc))
            print(test_accs)

    def _train_backbone(self, inc_i, epoches, train_data, test_data, criterion, optimizer, scheduler):
        """Stage 1: train the backbone for ``epoches`` epochs, testing after each one."""
        for epoch in range(epoches):
            print("---" * 50)
            print("Epoch", epoch)
            cur_lr = self.get_lr(optimizer)
            print("Current Learning Rate : ", cur_lr)
            self.model.train()
            # Bias layers run in eval mode here; their parameters are not in
            # ``optimizer``, so stage 1 does not update them.
            for layer in self.bias_layers:
                layer.eval()
            if inc_i > 0:
                self.stage1_distill(train_data, criterion, optimizer)
            else:
                self.stage1(train_data, criterion, optimizer)
            # Since PyTorch 1.1 the scheduler must step after the epoch's
            # optimizer updates; stepping first skips the initial LR value.
            scheduler.step()
            self.test(test_data)

    def _train_bias(self, epoches, val_bias_data, test_data, criterion, bias_optimizer):
        """Stage 2: train the bias layers with the backbone in eval mode.

        Returns the test accuracies recorded every 50 epochs (epoch 0 included).
        """
        accs = []
        for bias_epoch in range(epoches):
            # ``test`` switches the model back to train mode, so re-enter eval
            # mode every epoch.
            self.model.eval()
            for layer in self.bias_layers:
                layer.train()
            self.stage2(val_bias_data, criterion, bias_optimizer)
            if bias_epoch % 50 == 0:
                accs.append(self.test(test_data))
        return accs

    def bias_forward(self, input):
        """Apply each bias layer to its own contiguous slice of the logits.

        The logit columns are split into ``len(self.bias_layers)`` equal groups
        (20 columns each for 100 logits and five layers). Any remainder columns
        go to the last group. Group ``g`` is passed through
        ``self.bias_layers[g]`` and the results are concatenated along dim 1.

        Args:
            input: 2-D logits tensor, sliced as ``input[:, cols]``.
        Returns:
            The concatenated per-group outputs. This has the same width as
            ``input``, assuming each bias layer preserves its input shape.
        """
        n_groups = len(self.bias_layers)
        width = input.size(1)
        group = width // n_groups
        outs = []
        for g, layer in enumerate(self.bias_layers):
            start = g * group
            end = width if g == n_groups - 1 else start + group
            outs.append(layer(input[:, start:end]))
        return torch.cat(outs, dim=1)

    def stage1(self, train_data, criterion, optimizer):
        """Run one epoch of plain cross-entropy training on the seen classes."""
        print("Training ... ")
        losses = []
        for image, label in tqdm(train_data):
            image = image.cuda()
            label = label.view(-1).cuda()
            p = self.bias_forward(self.model(image))
            loss = criterion(p[:, :self.seen_cls], label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print("stage1 loss :", np.mean(losses))

    def stage1_distill(self, train_data, criterion, optimizer):
        """Run one epoch of cross-entropy plus knowledge distillation.

        The soft-target loss compares temperature-``T`` softmaxes of the
        previous and current models over the old classes. It is scaled by
        ``T * T`` and added to ``(1 - alpha)`` times the cross-entropy on all
        seen classes, where ``alpha = old / seen``.
        """
        print("Training ... ")
        distill_losses = []
        ce_losses = []
        T = 2
        # Classes added by one increment (20 for 100 classes, five layers).
        new_cls = self.total_cls // len(self.bias_layers)
        old_cls = self.seen_cls - new_cls
        alpha = old_cls / self.seen_cls
        print("classification proportion 1-alpha = ", 1 - alpha)
        # The teacher snapshot may have been taken while the model was in
        # train mode. Force eval mode so mode-dependent layers (e.g.
        # BatchNorm/dropout, if PreResNet has them) behave deterministically
        # and do not update the teacher's running statistics.
        self.previous_model.eval()
        for image, label in tqdm(train_data):
            image = image.cuda()
            label = label.view(-1).cuda()
            p = self.bias_forward(self.model(image))
            with torch.no_grad():
                pre_p = self.bias_forward(self.previous_model(image))
                pre_p = F.softmax(pre_p[:, :old_cls] / T, dim=1)
            logp = F.log_softmax(p[:, :old_cls] / T, dim=1)
            loss_soft_target = -torch.mean(torch.sum(pre_p * logp, dim=1))
            loss_hard_target = criterion(p[:, :self.seen_cls], label)
            loss = loss_soft_target * T * T + (1 - alpha) * loss_hard_target
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            distill_losses.append(loss_soft_target.item())
            ce_losses.append(loss_hard_target.item())
        print("stage1 distill loss :", np.mean(distill_losses), "ce loss :", np.mean(ce_losses))

    def stage2(self, val_bias_data, criterion, optimizer):
        """Run one epoch of bias-layer training on ``val_bias_data``."""
        print("Evaluating ... ")
        losses = []
        for image, label in tqdm(val_bias_data):
            image = image.cuda()
            label = label.view(-1).cuda()
            # Only bias-layer parameters are optimized here, so the backbone
            # forward needs no autograd graph.
            with torch.no_grad():
                p = self.model(image)
            p = self.bias_forward(p)
            loss = criterion(p[:, :self.seen_cls], label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print("stage2 loss :", np.mean(losses))