-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_classification_2.py
52 lines (39 loc) · 1.69 KB
/
main_classification_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from comet_ml import experiment
from data_loader.uts_classification_data_loader import UtsClassificationDataLoader
from models.uts_classification_model import UtsClassificationModel
from trainers.uts_classification_trainer import UtsClassificationTrainer
from evaluater.uts_classification_evaluater import UtsClassificationEvaluater
from utils.config import process_config_UtsClassification
from utils.dirs import create_dirs
from utils.utils import get_args
import os
import time
def main():
    """Train and then evaluate a UTS classification model from a config file.

    Steps:
      1. Parse the command-line arguments and load the JSON configuration
         named by ``args.config``.
      2. Create the tensorboard-log, checkpoint, log, and result directories
         listed in the config.
      3. Build the data loader, model, and trainer, then train.
      4. Evaluate ``trainer.best_model`` on the loader's test data.

    NOTE(review): "UTS" is presumably "univariate time series"; confirm.
    """
    # Default to GPU 2, but respect a device selection the caller already
    # exported (e.g. ``CUDA_VISIBLE_DEVICES=0 python ...``) instead of
    # silently overwriting it. CUDA reads this variable when it initializes,
    # so it must be set before the model first touches the GPU.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '2')

    # Capture the config path from the run arguments, then process the
    # JSON configuration file.
    args = get_args()
    config = process_config_UtsClassification(args.config)

    # Create the experiment directories up front so that training callbacks
    # and the evaluator can write into them.
    create_dirs([config.callbacks.tensorboard_log_dir, config.callbacks.checkpoint_dir,
                 config.log_dir, config.result_dir])

    print('Create the data generator.')
    data_loader = UtsClassificationDataLoader(config)

    print('Create the model.')
    # The network's input shape and class count are taken from the data
    # rather than hard-coded.
    model = UtsClassificationModel(config, data_loader.get_inputshape(), data_loader.get_nbclasses())

    print('Create the trainer')
    trainer = UtsClassificationTrainer(model.model, data_loader.get_train_data(), config)

    print('Start training the model.')
    trainer.train()

    print('Create the evaluater.')
    # Evaluate the trainer's best model (as exposed by ``trainer.best_model``)
    # on the held-out test data.
    evaluater = UtsClassificationEvaluater(trainer.best_model, data_loader.get_test_data(), data_loader.get_nbclasses(),
                                           config)

    print('Start evaluating the model.')
    # NOTE(review): 'evluate' (sic) must match the method name defined on
    # UtsClassificationEvaluater; rename both together if fixing the typo.
    evaluater.evluate()
    print('done')
if __name__ == "__main__":
    # Run the train-then-evaluate pipeline only when executed as a script;
    # importing this module has no side effects beyond its imports.
    main()