-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
379 lines (329 loc) · 18.1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
import traceback
import argparse
import sys
import os
import csv
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
import vgg
from torch.utils.data import DataLoader
from torch.optim import SGD, Adam
from hypergrad import SGD_HD, Adam_HD
from op_adam_lop_adam import op_Adam_lop_Adam
from op_sgd_lop_adam import op_Sgd_lop_Adam
from op_adam_lop_sgdn import op_Adam_lop_Sgdn
from op_sgd_lop_sgdn import op_Sgd_lop_Sgdn
# =======================================================================
# LOGREG AND MLP MODELS
# =======================================================================
class LogReg(nn.Module):
    """Single affine layer over flattened inputs.

    Trained with cross-entropy in ``train``, this is multinomial logistic
    regression.

    Args:
        input_dim: number of features each sample is flattened to.
        output_dim: number of output logits.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self._input_dim = input_dim
        self.lin1 = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and return the raw logits."""
        flat = x.view(-1, self._input_dim)
        return self.lin1(flat)
class MLP(nn.Module):
    """Three-layer perceptron with ReLU after each of the two hidden layers.

    Args:
        input_dim: number of features each sample is flattened to.
        hidden_dim: width of both hidden layers.
        output_dim: number of output logits.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self._input_dim = input_dim
        # Construction order matters: it fixes how the seeded RNG initialises each layer.
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
        self.lin3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and return the raw logits."""
        h = x.view(-1, self._input_dim)
        for hidden_layer in (self.lin1, self.lin2):
            h = F.relu(hidden_layer(h))
        return self.lin3(h)
def _build_model(opt):
    """Instantiate the network named by ``opt.model``; move it to the GPU if ``opt.cuda``.

    Raises:
        ValueError: if ``opt.model`` is not ``logreg``, ``mlp`` or ``vgg``.
    """
    if opt.model == 'logreg':
        model = LogReg(28 * 28, 10)
    elif opt.model == 'mlp':
        model = MLP(28 * 28, 1000, 10)
    elif opt.model == 'vgg':
        model = vgg.vgg16_bn()
        if opt.parallel:
            # Only the ``features`` submodule is wrapped in DataParallel.
            model.features = torch.nn.DataParallel(model.features)
    else:
        raise ValueError('Unknown model: {}'.format(opt.model))
    if opt.cuda:
        model = model.cuda()
    return model


def _build_loaders(opt):
    """Return ``(task_name, train_loader, valid_loader)`` for the dataset matching ``opt.model``.

    Raises:
        ValueError: if ``opt.model`` has no associated dataset.
    """
    if opt.model in ('logreg', 'mlp'):
        # 0.1307 / 0.3081: the usual MNIST mean / std normalisation constants.
        mnist_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_loader = DataLoader(
            datasets.MNIST('./data', train=True, download=True, transform=mnist_transform),
            batch_size=opt.batchSize, shuffle=True)
        valid_loader = DataLoader(
            datasets.MNIST('./data', train=False, transform=mnist_transform),
            batch_size=opt.batchSize, shuffle=False)
        return 'MNIST', train_loader, valid_loader
    if opt.model == 'vgg':
        # These are the commonly used ImageNet channel statistics, reused here for CIFAR-10.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_loader = DataLoader(
            datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=opt.batchSize, shuffle=True,
            num_workers=opt.workers, pin_memory=True)
        valid_loader = DataLoader(
            datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=opt.batchSize, shuffle=False,
            num_workers=opt.workers, pin_memory=True)
        return 'CIFAR10', train_loader, valid_loader
    raise ValueError('Unknown model: {}'.format(opt.model))


def _build_optimizer(opt, model):
    """Create the optimizer named by ``opt.method`` over ``model``'s parameters.

    The ``*_Hd`` and ``op_*`` methods also receive ``opt.beta`` as ``hypergrad_lr``.
    The Nesterov variants (``sgdn*``) use ``opt.mu`` as momentum.

    Raises:
        ValueError: if ``opt.method`` is not recognised.
    """
    params = model.parameters()  # generator: consumed by exactly one branch below
    lr, wd, beta = opt.alpha_0, opt.weightDecay, opt.beta
    if opt.method == 'sgd':
        return SGD(params, lr=lr, weight_decay=wd)
    if opt.method == 'sgd_Hd':
        return SGD_HD(params, lr=lr, weight_decay=wd, hypergrad_lr=beta)
    if opt.method == 'op_sgd_lop_sgdn':
        return op_Sgd_lop_Sgdn(params, lr=lr, weight_decay=wd, momentum=0, nesterov=False, hypergrad_lr=beta)
    if opt.method == 'sgdn':
        return SGD(params, lr=lr, weight_decay=wd, momentum=opt.mu, nesterov=True)
    if opt.method == 'sgdn_Hd':
        return SGD_HD(params, lr=lr, weight_decay=wd, momentum=opt.mu, nesterov=True, hypergrad_lr=beta)
    if opt.method == 'op_sgdn_lop_sgdn':
        return op_Sgd_lop_Sgdn(params, lr=lr, weight_decay=wd, momentum=opt.mu, nesterov=True, hypergrad_lr=beta)
    if opt.method == 'adam':
        return Adam(params, lr=lr, weight_decay=wd)
    if opt.method == 'adam_Hd':
        return Adam_HD(params, lr=lr, weight_decay=wd, hypergrad_lr=beta)
    if opt.method == 'op_adam_lop_adam':
        return op_Adam_lop_Adam(params, lr=lr, weight_decay=wd, hypergrad_lr=beta)
    # NOTE: the op_sgd_lop_adam / op_sgdn_lop_adam (op_Sgd_lop_Adam) and
    # op_adam_lop_sgdn (op_Adam_lop_Sgdn) variants are imported but currently disabled.
    raise ValueError('Unknown method: {}'.format(opt.method))


def _evaluate_loss(model, loader, cuda):
    """Put ``model`` in eval mode and return its mean per-sample cross-entropy over ``loader``."""
    model.eval()
    total = 0
    with torch.no_grad():
        for data, target in loader:
            if cuda:
                data, target = data.cuda(), target.cuda()
            # Sum per batch and divide by the dataset size once, so a short last batch is weighted correctly.
            total = total + F.cross_entropy(model(data), target, reduction='sum')
    return float(total) / len(loader.dataset)


def train(opt, log_func=None):
    """Train the model/optimizer selected by ``opt`` and report progress through ``log_func``.

    Args:
        opt: parsed CLI namespace (see ``main``). Reads seed, cuda, device, model,
            parallel, batchSize, workers, method, alpha_0, weightDecay, beta, mu,
            silent, continue_training, begin_epoch, epochs, iterations and
            lossThreshold.
        log_func: optional callback
            ``(epoch, iteration, time, loss, loss_epoch, valid_loss, alpha, alpha_epoch, beta)``.
            It is called once per iteration; per-epoch fields are NaN except on
            an epoch's last batch. A row for epoch 0 is emitted before training
            unless resuming.

    Returns:
        ``(loss, iteration)``: the last training-batch loss (a detached tensor) and
        the iteration counter when training stopped.

    When ``opt.continue_training`` is set, the checkpoint
    ``{model}_{method}_{beta:+.0e}_epochs{begin_epoch}.pth`` is loaded. When the
    epoch limit is reached, a checkpoint is written to
    ``..._epochs{begin_epoch + epochs}.pth``.
    """
    torch.manual_seed(opt.seed)
    if opt.cuda:
        torch.cuda.set_device(opt.device)
        torch.cuda.manual_seed(opt.seed)
        torch.backends.cudnn.enabled = True

    # -------------------------------------------------------------------------
    # SETUP MODEL, DATASET, DATALOADER, OPTIMIZER
    # -------------------------------------------------------------------------
    model = _build_model(opt)
    task, train_loader, valid_loader = _build_loaders(opt)
    optimizer = _build_optimizer(opt, model)
    if not opt.silent:
        print('Task: {}, Model: {}, Method: {}'.format(task, opt.model, opt.method))

    # -------------------------------------------------------------------------
    # RESUME FROM A GENERAL CHECKPOINT (optional)
    # -------------------------------------------------------------------------
    begin_epoch, begin_iteration, time_already = 0, 0, 0
    if opt.continue_training:
        begin_epoch = opt.begin_epoch
        model_load_path = '{}_{}_{:+.0e}_epochs{}.pth'.format(opt.model, opt.method, opt.beta, begin_epoch)
        checkpoint = torch.load(model_load_path)
        if checkpoint['epoch'] != begin_epoch:
            print("Provide the correct number of epochs after which to continue training")
            sys.exit(1)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        begin_iteration = checkpoint['iteration']
        time_already = checkpoint['time']
    model_save_path = '{}_{}_{:+.0e}_epochs{}.pth'.format(opt.model, opt.method, opt.beta, begin_epoch + opt.epochs)

    # Loss on one training batch before any update; it seeds the epoch-0 log row.
    model.eval()
    with torch.no_grad():
        data, target = next(iter(train_loader))
        if opt.cuda:
            data, target = data.cuda(), target.cuda()
        loss = F.cross_entropy(model(data), target)
    if not opt.continue_training and log_func is not None:
        valid_loss = _evaluate_loss(model, valid_loader, opt.cuda)
        log_func(0, 0, 0, loss.item(), loss.item(), valid_loss, opt.alpha_0, opt.alpha_0, opt.beta)

    # -------------------------------------------------------------------------
    # TRAINING LOOP
    # -------------------------------------------------------------------------
    time_start = time.time()
    epoch = 1
    iteration = 1
    done = False
    while not done:
        model.train()
        loss_epoch = 0
        alpha_epoch = 0.0
        for batch_id, (data, target) in enumerate(train_loader):
            if opt.cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            loss = F.cross_entropy(model(data), target)
            loss.backward()
            optimizer.step()
            loss = loss.detach()
            loss_epoch += loss
            # lr is a float for regular optimizers and a tensor for the HD counterparts.
            alpha = optimizer.param_groups[0]['lr']
            if isinstance(alpha, torch.Tensor):
                alpha = alpha.item()
            alpha_epoch += alpha

            # Early stopping on loss threshold (0 or negative: disabled, as the CLI help states).
            if opt.lossThreshold > 0 and loss <= opt.lossThreshold:
                print('Early stopping: loss <= {}'.format(opt.lossThreshold))
                done = True
                break

            # Early stopping on iteration budget (0: disabled).
            if opt.iterations != 0 and iteration + 1 > opt.iterations:
                print('Early stopping: iteration > {}'.format(opt.iterations))
                done = True
                with open("iterations.csv", "a", newline="") as fl:
                    csv.writer(fl).writerow([opt.alpha_0, opt.method, iteration])
                break

            if batch_id + 1 >= len(train_loader):
                # Last batch of the epoch: report epoch averages and validation loss.
                loss_epoch /= len(train_loader)
                alpha_epoch /= len(train_loader)
                valid_loss = _evaluate_loss(model, valid_loader, opt.cuda)
                if log_func is not None:
                    log_func(begin_epoch + epoch, begin_iteration + iteration,
                             time.time() - time_start + time_already,
                             loss.item(), loss_epoch.item(), valid_loss,
                             alpha, alpha_epoch, opt.beta)
            elif log_func is not None:
                log_func(begin_epoch + epoch, begin_iteration + iteration,
                         time.time() - time_start + time_already,
                         loss.item(), float('nan'), float('nan'),
                         alpha, float('nan'), opt.beta)
            iteration += 1

        # Epoch budget reached (0: disabled): stop and write a resumable checkpoint.
        if opt.epochs != 0 and epoch + 1 > opt.epochs:
            print('Early stopping: epoch > {}'.format(opt.epochs))
            done = True
            torch.save({'epoch': begin_epoch + epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        # iteration was already advanced past the last completed step
                        'iteration': begin_iteration + iteration - 1,
                        'time': time.time() - time_start + time_already},
                       model_save_path)
        epoch += 1
    return loss, iteration
def _build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='Hypergradient descent PyTorch tests', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--cuda', help='use CUDA', action='store_true')
    parser.add_argument('--device', help='selected CUDA device', default=0, type=int)
    parser.add_argument('--seed', help='random seed', default=1, type=int)
    parser.add_argument('--dir', help='directory to write the output files', default='test', type=str)
    parser.add_argument('--model', help='model (logreg, mlp, vgg)', default='logreg', type=str)
    parser.add_argument('--method', help='method (sgd, sgd_Hd, op_sgd_lop_sgdn, sgdn, sgdn_Hd, op_sgdn_lop_sgdn, adam, adam_Hd, op_adam_lop_adam)', default='adam_Hd', type=str)
    parser.add_argument('--alpha_0', help='initial learning rate', default=0.001, type=float)
    parser.add_argument('--beta', help='learning learning rate', default=0.000001, type=float)
    parser.add_argument('--mu', help='momentum', default=0.9, type=float)
    parser.add_argument('--weightDecay', help='regularization', default=0.0001, type=float)
    parser.add_argument('--batchSize', help='minibatch size', default=128, type=int)
    # --epochs / --iterations count from the point training (re)starts.
    parser.add_argument('--epochs', help='stop after this many epochs (0: disregard)', default=1, type=int)
    parser.add_argument('--iterations', help='stop after this many iterations (0: disregard)', default=0, type=int)
    parser.add_argument('--lossThreshold', help='stop after reaching this loss (0: disregard)', default=0, type=float)
    parser.add_argument('--silent', help='do not print output', action='store_true')
    parser.add_argument('--workers', help='number of data loading workers', default=4, type=int)
    parser.add_argument('--parallel', help='parallelize', action='store_true')
    # Passing --save enables CSV output; without it results are only printed.
    parser.add_argument('--save', help='save output to a CSV file', action='store_true')
    # Resuming from a saved model. The checkpoint is saved as
    # "{model}_{method}_{beta:+.0e}_epochs{X}.pth", where X is the number of epochs trained.
    parser.add_argument('--continue_training', help='whether to continue training or start new', action='store_true')
    parser.add_argument('--begin_epoch', help='Number of epochs after which to resume training', default=0, type=int)
    return parser


def _print_progress(opt, epoch, iteration, time_spent, loss, valid_loss, alpha, beta):
    """Print one progress line unless ``opt.silent`` is set."""
    if not opt.silent:
        print('{} | {} | Epoch: {} | Iter: {} | Time: {:+.3e} | Loss: {:+.3e} | Valid. loss: {:+.3e} | Alpha: {:+.3e} | Beta: {:+.3e}'.format(opt.model, opt.method, epoch, iteration, time_spent, loss, valid_loss, alpha, beta))


def main():
    """CLI entry point: parse arguments, set up logging, and run ``train``.

    With ``--save``, rows are appended to ``{dir}/{model}/{alpha_0:+.0e}_{beta:+.0e}/{method}.csv``.
    A header row is written unless ``--continue_training`` is set.
    The process exits with status 1 if training raised, 0 otherwise.
    """
    exit_code = 0
    try:
        opt = _build_parser().parse_args()
        torch.manual_seed(opt.seed)
        if opt.cuda:
            torch.cuda.set_device(opt.device)
            torch.cuda.manual_seed(opt.seed)
            torch.backends.cudnn.enabled = True

        # Device configs
        if torch.cuda.is_available():
            current = torch.cuda.current_device()
            print("Running on : {}".format(current))
            print(torch.cuda.device_count())
            print(torch.cuda.get_device_name(current))

        if not opt.save:
            def log_func(epoch, iteration, time_spent, loss, loss_epoch, valid_loss, alpha, alpha_epoch, beta):
                _print_progress(opt, epoch, iteration, time_spent, loss, valid_loss, alpha, beta)
            train(opt, log_func)
        else:
            file_name = '{}/{}/{:+.0e}_{:+.0e}/{}.csv'.format(opt.dir, opt.model, opt.alpha_0, opt.beta, opt.method)
            if not opt.silent:
                print('Output file: {}'.format(file_name))
            if os.path.isfile(file_name):
                print('File with previous results exists, appending to that file...')
            else:
                os.makedirs(os.path.dirname(file_name), exist_ok=True)
            # newline='' is required by the csv module to avoid blank lines on some platforms.
            with open(file_name, 'a', newline='') as f:
                writer = csv.writer(f)
                if not opt.continue_training:
                    writer.writerow(['Epoch', 'Iteration', 'Time', 'Loss', 'LossEpoch', 'ValidLossEpoch', 'Alpha', 'AlphaEpoch', 'Beta'])

                def log_func(epoch, iteration, time_spent, loss, loss_epoch, valid_loss, alpha, alpha_epoch, beta):
                    writer.writerow([epoch, iteration, time_spent, loss, loss_epoch, valid_loss, alpha, alpha_epoch, beta])
                    _print_progress(opt, epoch, iteration, time_spent, loss, valid_loss, alpha, beta)
                train(opt, log_func)
    except KeyboardInterrupt:
        print('Stopped')
    except Exception:
        traceback.print_exc(file=sys.stdout)
        # Report failure to the caller instead of masking it with status 0.
        exit_code = 1
    sys.exit(exit_code)
if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script, never on import.
    main()