-
Notifications
You must be signed in to change notification settings - Fork 0
/
exe_train_models.py
72 lines (54 loc) · 1.69 KB
/
exe_train_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from DLBio import pt_run_parallel
import copy
import argparse
# GPU ids handed to pt_run_parallel.run(...) as `available_gpus`.
AVAILABLE_GPUS = [0, 1, 2, 3]

# Model names to sweep over. Other names listed here previously (disabled):
# 'smp_resnet50', 'smp_resnet152', 'smp_mobilenet_v2'.
USED_MODELS = ['smp_resnet18']

# Keyword arguments shared by every training run; each parameter generator
# deep-copies this dict before adding run-specific entries.
# use_rgb=None: per the original note this translates to use_rgb = True
# (presumably emitted as a bare flag by pt_run_parallel -- confirm).
DEFAULT_KWARGS = dict(
    use_rgb=None,
    dataset='simulation',
    seed=0,
)

# Suffix appended to each run's output folder name.
POST_FIX = 'sim_15032021'
def get_options(argv=None):
    """Parse the launcher's command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing an explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace with a string attribute ``type`` (default
        ``'models'``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--type',
        type=str,
        default='models',
        help="parameter sweep to launch; run() currently supports only 'models'",
    )
    return parser.parse_args(argv)
def run():
    """Parse the CLI, select a parameter generator, and dispatch it.

    The generator chosen by ``--type`` is called and its output is passed to
    ``pt_run_parallel.run``, together with a ``MakeObject`` factory for
    ``TrainingProcess`` and the GPU ids in ``AVAILABLE_GPUS``. (How the jobs
    are scheduled is up to DLBio's pt_run_parallel -- not visible here.)

    Raises:
        ValueError: if ``--type`` does not name a known parameter generator.
    """
    options = get_options()

    # Supported --type values and the parameter generator each one selects.
    generators_by_type = {'models': different_models_pg}
    param_generator = generators_by_type.get(options.type)
    if param_generator is None:
        raise ValueError(f'unknown type {options.type}')

    make_object = pt_run_parallel.MakeObject(TrainingProcess)
    pt_run_parallel.run(
        param_generator(),
        make_object,
        available_gpus=AVAILABLE_GPUS,
    )
class TrainingProcess(pt_run_parallel.ITrainingProcess):
    """Process descriptor for one training run, built by pt_run_parallel.

    All keyword arguments are forwarded to the base initializer and also
    kept unchanged on ``self.kwargs``. ``kwargs`` must contain
    ``'model_type'`` (a ``KeyError`` is raised otherwise, after the base
    initializer has run); it is used to build the process name.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Assigned after the base initializer so these values win over
        # anything the base class may have set.
        model_type = kwargs["model_type"]
        self.__name__ = f'train_model_{model_type}'
        # Presumably the script pt_run_parallel executes for this job -- confirm.
        self.module_name = 'run_training.py'
        self.kwargs = kwargs
# -----------------------------------------------------------------------------
# ---------------------PARAM GENERATORS----------------------------------------
# -----------------------------------------------------------------------------
def different_models_pg():
    """Yield one parameter dict per model name in ``USED_MODELS``.

    Each dict is a fresh deep copy of ``DEFAULT_KWARGS`` extended with
    ``'model_type'`` (the model name) and ``'folder'`` (the model name and
    ``POST_FIX`` joined by an underscore).
    """
    for name in USED_MODELS:
        # Deep copy so yielded dicts never share state with each other or
        # with the module-level defaults.
        params = copy.deepcopy(DEFAULT_KWARGS)
        params['model_type'] = name
        params['folder'] = name + '_' + POST_FIX
        yield params
if __name__ == '__main__':
    # Only launch the sweep when executed directly, not when imported.
    run()