-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathmain_train.py
55 lines (42 loc) · 1.31 KB
/
main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
Copyright (c) 2018. All rights reserved.
Created by Resnick Xing on 2018/5/10
"""
import os
from experiments.data_loaders.standard_loader import DataLoader
# from infers.simple_mnist_infer import SimpleMnistInfer
from perception.models.AttnUnet import SegmentionModel
from perception.trainers.segmention_trainer import SegmentionTrainer
from configs.utils.config_utils import process_config
import numpy as np
def main_train():
    """
    Train the segmentation model.

    Steps, in order: process the configuration file
    ``configs/segmention_config.json``, let the data loader prepare the
    dataset and return the training and validation data, build the model,
    and run the trainer on them.

    :return: None
    :raises SystemExit: with exit status 1 when the configuration cannot be
        processed.
    """
    print('[INFO] Reading Configs...')
    try:
        config = process_config('configs/segmention_config.json')
    except Exception as e:
        # Top-level boundary: report whatever went wrong while processing the
        # configuration, then stop with a NON-ZERO status so shells, CI jobs
        # and schedulers can tell that the run failed. ``SystemExit`` is
        # raised directly because the ``exit`` helper is injected by the
        # ``site`` module and is not guaranteed to exist in every interpreter.
        print('[Exception] Config Error, %s' % e)
        raise SystemExit(1)

    # np.random.seed(47)  # uncomment to fix the random seed

    print('[INFO] Preparing Data...')
    dataloader = DataLoader(config=config)
    dataloader.prepare_dataset()

    train_imgs, train_gt = dataloader.get_train_data()
    val_imgs, val_gt = dataloader.get_val_data()

    print('[INFO] Building Model...')
    model = SegmentionModel(config=config)

    print('[INFO] Training...')
    # The trainer receives the wrapped ``model.model`` attribute, not the
    # SegmentionModel wrapper itself.
    trainer = SegmentionTrainer(
        model=model.model,
        data=[train_imgs, train_gt, val_imgs, val_gt],
        config=config)
    trainer.train()
    print('[INFO] Finishing...')
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if "__main__" == __name__:
    main_train()