# main.py
import torch
import argparse
import util
import os
import datetime
import random
import mlconfig
import loss
import models
import dataset
import shutil
from evaluator import Evaluator
from trainer import Trainer
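# NOTE: `loss` and `models` are never referenced directly below; presumably importing
# them is what registers their classes with mlconfig so the YAML config can
# instantiate the model and criterion by name in main().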
# ArgParse
parser = argparse.ArgumentParser(description='Normalized Loss Functions for Deep Learning with Noisy Labels')
# Training
parser.add_argument('--resume', action='store_true', default=False)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--config_path', type=str, default='configs')
parser.add_argument('--version', type=str, default='ce')
parser.add_argument('--exp_name', type=str, default="run1")
parser.add_argument('--load_model', action='store_true', default=False)
parser.add_argument('--data_parallel', action='store_true', default=False)
parser.add_argument('--asym', action='store_true', default=False)
parser.add_argument('--noise_rate', type=float, default=0.0)
parser.add_argument('--removal_rate', type=float, default=0.0)
parser.add_argument('--repeat_rate', type=float, default=0.0)
parser.add_argument('--model_name', type=str, default="ToyModel")
args = parser.parse_args()
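# Example invocation (a sketch using only the flags defined above; assumes a
# matching configs/ce.yaml exists):
#   python main.py --config_path configs --version ce --exp_name run1 \
#                  --noise_rate 0.4 --asym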
# Set up
if args.exp_name == '' or args.exp_name is None:
    args.exp_name = 'exp_' + datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
exp_path = os.path.join(args.exp_name, args.version)
log_file_path = os.path.join(exp_path, args.version)
checkpoint_path = os.path.join(exp_path, 'checkpoints')
checkpoint_path_file = os.path.join(checkpoint_path, args.version)
util.build_dirs(exp_path)
util.build_dirs(checkpoint_path)
logger = util.setup_logger(name=args.version, log_file=log_file_path + ".log")
for arg in vars(args):
logger.info("%s: %s" % (arg, getattr(args, arg)))
random.seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda')
    logger.info("Using CUDA!")
    device_list = [torch.cuda.get_device_name(i) for i in range(0, torch.cuda.device_count())]
    logger.info("GPU List: %s" % (device_list))
else:
    device = torch.device('cpu')
logger.info("PyTorch Version: %s" % (torch.__version__))
config_file = os.path.join(args.config_path, args.version) + '.yaml'
config = mlconfig.load(config_file)
config.model.name = args.model_name
config.set_immutable()
shutil.copyfile(config_file, os.path.join(exp_path, args.version+'.yaml'))
for key in config:
logger.info("%s: %s" % (key, config[key]))
def train(starting_epoch, model, optimizer, scheduler, criterion, trainer, evaluator, ENV, model_name):
    for epoch in range(starting_epoch, config.epochs):
        logger.info("="*20 + "Training" + "="*20)
        # Train
        ENV['global_step'] = trainer.train(epoch, ENV['global_step'], model, optimizer, criterion)
        scheduler.step()
        # Eval
        logger.info("="*20 + "Eval" + "="*20)
        evaluator.eval(epoch, ENV['global_step'], model, torch.nn.CrossEntropyLoss(), args.noise_rate, args.removal_rate, args.repeat_rate, model_name)
        payload = ('Eval Loss:%.4f\tEval acc: %.2f' % (evaluator.loss_meters.avg, evaluator.acc_meters.avg*100))
        logger.info(payload)
        ENV['train_history'].append(trainer.acc_meters.avg*100)
        ENV['eval_history'].append(evaluator.acc_meters.avg*100)
        ENV['current_acc'] = evaluator.acc_meters.avg*100
        ENV['best_acc'] = max(ENV['current_acc'], ENV['best_acc'])
        # Reset Stats
        trainer._reset_stats()
        evaluator._reset_stats()
        # Save Model
        target_model = model.module if args.data_parallel else model
        util.save_model(ENV=ENV,
                        epoch=epoch,
                        model=target_model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        filename=checkpoint_path_file)
        logger.info('Model Saved at %s', checkpoint_path_file)
    return
def main():
    if config.dataset.name == 'DatasetGenerator':
        data_loader = config.dataset(seed=args.seed, noise_rate=args.noise_rate, removal_rate=args.removal_rate, repeat_rate=args.repeat_rate, asym=args.asym)
    else:
        data_loader = config.dataset()
    model = config.model()
    if isinstance(data_loader, dataset.Clothing1MDatasetLoader):
        model.fc = torch.nn.Linear(2048, 14)
    model = model.to(device)
    data_loader = data_loader.getDataLoader()
    logger.info("param size = %fMB", util.count_parameters_in_MB(model))
    if args.data_parallel:
        model = torch.nn.DataParallel(model)
    optimizer = config.optimizer(model.parameters())
    scheduler = config.scheduler(optimizer)
    if config.criterion.name == 'NLNL':
        criterion = config.criterion(train_loader=data_loader['train_dataset'])
    else:
        criterion = config.criterion()
    trainer = Trainer(data_loader['train_dataset'], logger, config)
    evaluator = Evaluator(data_loader['test_dataset'], logger, config)
    starting_epoch = 0
    ENV = {'global_step': 0,
           'best_acc': 0.0,
           'current_acc': 0.0,
           'train_history': [],
           'eval_history': []}
    if args.load_model:
        checkpoint = util.load_model(filename=checkpoint_path_file,
                                     model=model,
                                     optimizer=optimizer,
                                     scheduler=scheduler)
        starting_epoch = checkpoint['epoch']
        ENV = checkpoint['ENV']
        trainer.global_step = ENV['global_step']
        logger.info("File %s loaded!" % (checkpoint_path_file))
    print("Model selected: " + args.model_name)
    train(starting_epoch, model, optimizer, scheduler, criterion, trainer, evaluator, ENV, args.model_name)
    return
if __name__ == '__main__':
    main()