-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy pathmain.py
221 lines (195 loc) · 10.3 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
#!/usr/bin/env python
# encoding: utf-8
'''
@author: Xu Yan
@file: main.py
@time: 2021/12/7 22:21
'''
import os
import yaml
import torch
import datetime
import importlib
import numpy as np
import pytorch_lightning as pl
from easydict import EasyDict
from argparse import ArgumentParser
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.profiler import SimpleProfiler
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from dataloader.dataset import get_model_class, get_collate_class
from dataloader.pc_dataset import get_pc_model_class
from pytorch_lightning.callbacks import LearningRateMonitor
import warnings
warnings.filterwarnings("ignore")
def load_yaml(file_name):
    """Load a YAML configuration file and return the parsed document.

    Args:
        file_name: Path of the YAML file to read.

    Returns:
        The deserialized YAML document (for the configs this script reads it
        is used as a nested dict by the caller).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # ``yaml.FullLoader`` only exists in PyYAML >= 5.1; older releases used the
    # full ``yaml.Loader`` implicitly. Pick the loader up front instead of
    # retrying in a bare ``except``: a retry would re-read a stream that the
    # first attempt already consumed, and would hide real YAML syntax errors.
    # NOTE: FullLoader is not meant for untrusted input; config files are
    # assumed to be trusted local files.
    loader = getattr(yaml, 'FullLoader', yaml.Loader)
    with open(file_name, 'r', encoding='utf-8') as f:
        return yaml.load(f, Loader=loader)
def parse_config():
    """Build the run configuration from the command line and the YAML file.

    The command line is parsed first. The YAML file named by
    ``--config_path`` is then loaded, and every parsed CLI value is written
    over it, so a CLI option wins over a YAML key of the same name.
    In ``--test`` mode the validation loader's batch size becomes
    ``--num_vote``, and rotation/transform augmentation is switched on when
    more than one vote is requested. ``--debug`` then forces a batch size of
    2 and single-process loading for the validation loader.

    Returns:
        EasyDict: The merged configuration.
    """
    cli = ArgumentParser()
    add = cli.add_argument

    # general
    add('--gpu', type=int, nargs='+', default=(0,), help='specify gpu devices')
    add("--seed", default=0, type=int)
    add('--config_path', default='config/2DPASS-semantickitti.yaml')

    # training
    add('--log_dir', type=str, default='default', help='log location')
    add('--monitor', type=str, default='val/mIoU', help='the maximum metric')
    add('--stop_patience', type=int, default=50, help='patience for stop training')
    add('--save_top_k', type=int, default=1, help='save top k checkpoints, use -1 to checkpoint every epoch')
    add('--check_val_every_n_epoch', type=int, default=1, help='check_val_every_n_epoch')
    add('--SWA', action='store_true', default=False, help='StochasticWeightAveraging')
    add('--baseline_only', action='store_true', default=False, help='training without 2D')

    # testing
    add('--test', action='store_true', default=False, help='test mode')
    add('--fine_tune', action='store_true', default=False, help='fine tune mode')
    add('--pretrain2d', action='store_true', default=False, help='use pre-trained 2d network')
    add('--num_vote', type=int, default=1, help='number of voting in the test')
    add('--submit_to_server', action='store_true', default=False, help='submit on benchmark')
    add('--checkpoint', type=str, default=None, help='load checkpoint')

    # debug
    add('--debug', default=False, action='store_true')

    opts = cli.parse_args()

    merged = load_yaml(opts.config_path)
    # CLI values take precedence over YAML entries with the same key.
    merged.update(vars(opts))

    if opts.test:
        # Voting test: the validation batch size is set to the vote count
        # (presumably one scan replicated per vote — verify in the dataset).
        val_loader_cfg = merged['dataset_params']['val_data_loader']
        val_loader_cfg['batch_size'] = opts.num_vote
        if opts.num_vote > 1:
            val_loader_cfg.update(rotate_aug=True, transform_aug=True)

    if opts.debug:
        # Applied after the test-mode settings, so debug overrides them.
        val_loader_cfg = merged['dataset_params']['val_data_loader']
        val_loader_cfg.update(batch_size=2, num_workers=0)

    return EasyDict(merged)
def _make_loader(dataset, loader_config, collate_fn, pin_memory=False, drop_last=False):
    """Wrap ``dataset`` in a DataLoader configured from a ``*_data_loader`` section.

    Args:
        dataset: The torch Dataset to iterate.
        loader_config: Mapping providing ``batch_size``, ``shuffle`` and
            ``num_workers``.
        collate_fn: Batch collation callable.
        pin_memory: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.

    Returns:
        torch.utils.data.DataLoader
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=loader_config["batch_size"],
        collate_fn=collate_fn,
        shuffle=loader_config["shuffle"],
        num_workers=loader_config["num_workers"],
        pin_memory=pin_memory,
        drop_last=drop_last,
    )


def build_loader(config):
    """Create the data loaders needed for the current run mode.

    Training mode (``config['test']`` false) builds a training loader and a
    single-vote validation loader, both with pinned memory. The training
    loader also drops the last incomplete batch.

    Test mode builds one loader over ``val_data_loader.data_path``. The
    ``imageset`` is 'test' when ``submit_to_server`` is set, otherwise 'val',
    and the vote count passed to the datasets is the validation batch size.

    Args:
        config: Run configuration (as returned by ``parse_config``).

    Returns:
        tuple: ``(train_loader, val_loader, test_loader)``. Loaders not needed
        for the current mode are ``None``.
    """
    params = config['dataset_params']
    pc_dataset = get_pc_model_class(params['pc_dataset_type'])
    dataset_type = get_model_class(params['dataset_type'])
    collate_fn = get_collate_class(params['collate_type'])
    train_config = params['train_data_loader']
    val_config = params['val_data_loader']

    train_dataset_loader, val_dataset_loader, test_dataset_loader = None, None, None

    if not config['test']:
        train_pt_dataset = pc_dataset(config, data_path=train_config['data_path'], imageset='train')
        val_pt_dataset = pc_dataset(config, data_path=val_config['data_path'], imageset='val')
        train_dataset_loader = _make_loader(
            dataset_type(train_pt_dataset, config, train_config),
            train_config, collate_fn, pin_memory=True, drop_last=True)
        val_dataset_loader = _make_loader(
            dataset_type(val_pt_dataset, config, val_config, num_vote=1),
            val_config, collate_fn, pin_memory=True)
    else:
        # parse_config sets the validation batch size to --num_vote in test mode.
        num_vote = val_config["batch_size"]
        imageset = 'test' if config['submit_to_server'] else 'val'
        pt_dataset = pc_dataset(config, data_path=val_config['data_path'], imageset=imageset, num_vote=num_vote)
        loader = _make_loader(
            dataset_type(pt_dataset, config, val_config, num_vote=num_vote),
            val_config, collate_fn)
        if config['submit_to_server']:
            test_dataset_loader = loader
        else:
            val_dataset_loader = loader

    return train_dataset_loader, val_dataset_loader, test_dataset_loader
if __name__ == '__main__':
    import shutil  # used by the best-effort source backup below

    # parameters
    configs = parse_config()
    print(configs)

    # setting
    # Restrict the visible GPUs; the Trainer below then addresses them as
    # 0..num_gpu-1. This only takes effect if no CUDA context exists yet.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, configs.gpu))
    num_gpu = len(configs.gpu)

    # output path
    log_folder = 'logs/' + configs['dataset_params']['pc_dataset_type']
    tb_logger = pl_loggers.TensorBoardLogger(log_folder, name=configs.log_dir, default_hp_metric=False)
    os.makedirs(f'{log_folder}/{configs.log_dir}', exist_ok=True)
    profiler = SimpleProfiler(output_filename=f'{log_folder}/{configs.log_dir}/profiler.txt')
    np.set_printoptions(precision=4, suppress=True)

    # save the backup files (snapshot of the code/config used for this run)
    backup_dir = os.path.join(log_folder, configs.log_dir, 'backup_files_%s' % str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
    if not configs['test']:
        os.makedirs(backup_dir, exist_ok=True)
        backup_sources = [
            'main.py',
            'dataloader/dataset.py',
            'dataloader/pc_dataset.py',
            configs.config_path,
            'network/base_model.py',
            'network/baseline.py',
            'network/' + configs['model_params']['model_architecture'] + '.py',
        ]
        for src in backup_sources:
            # Copy without a shell, so config-supplied paths are never
            # interpreted by the shell. A failed copy is reported but does not
            # abort the run.
            try:
                shutil.copy(src, backup_dir)
            except OSError as err:
                print('warning: could not back up {}: {}'.format(src, err))

    # reproducibility
    torch.manual_seed(configs.seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning (benchmark=True) may select different convolution
    # algorithms from run to run, which defeats the deterministic setting
    # above. Keep it disabled for reproducible results.
    torch.backends.cudnn.benchmark = False
    np.random.seed(configs.seed)

    train_dataset_loader, val_dataset_loader, test_dataset_loader = build_loader(configs)
    model_file = importlib.import_module('network.' + configs['model_params']['model_architecture'])
    my_model = model_file.get_model(configs)

    pl.seed_everything(configs.seed)
    checkpoint_callback = ModelCheckpoint(
        monitor=configs.monitor,
        mode='max',
        save_last=True,
        save_top_k=configs.save_top_k)

    if configs.checkpoint is not None:
        print('load pre-trained model...')
        if configs.fine_tune or configs.test or configs.pretrain2d:
            # Rebuild with the current config. Strict key matching is disabled
            # for --pretrain2d (presumably because that checkpoint covers only
            # part of the model — verify).
            my_model = my_model.load_from_checkpoint(configs.checkpoint, config=configs, strict=(not configs.pretrain2d))
        else:
            # continue last training
            my_model = my_model.load_from_checkpoint(configs.checkpoint)

    if configs.SWA:
        swa = [StochasticWeightAveraging(swa_epoch_start=configs.train_params.swa_epoch_start, annealing_epochs=1)]
    else:
        swa = []

    if not configs.test:
        # init trainer
        print('Start training...')
        trainer = pl.Trainer(gpus=list(range(num_gpu)),
                             accelerator='ddp',
                             max_epochs=configs['train_params']['max_num_epochs'],
                             # Resume full trainer state only when continuing a run,
                             # not when fine-tuning or loading a 2D pre-train.
                             resume_from_checkpoint=configs.checkpoint if not configs.fine_tune and not configs.pretrain2d else None,
                             callbacks=[checkpoint_callback,
                                        LearningRateMonitor(logging_interval='step'),
                                        EarlyStopping(monitor=configs.monitor,
                                                      patience=configs.stop_patience,
                                                      mode='max',
                                                      verbose=True),
                                        ] + swa,
                             logger=tb_logger,
                             profiler=profiler,
                             check_val_every_n_epoch=configs.check_val_every_n_epoch,
                             gradient_clip_val=1,
                             accumulate_grad_batches=1
                             )
        trainer.fit(my_model, train_dataset_loader, val_dataset_loader)

    else:
        print('Start testing...')
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if num_gpu != 1:
            raise ValueError('only support single GPU testing!')
        trainer = pl.Trainer(gpus=list(range(num_gpu)),
                             accelerator='ddp',
                             resume_from_checkpoint=configs.checkpoint,
                             logger=tb_logger,
                             profiler=profiler)
        trainer.test(my_model, test_dataset_loader if configs.submit_to_server else val_dataset_loader)