-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
141 lines (129 loc) · 4.12 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import argparse
from agent import Agent
# Default parameters (note: entries marked "disabled" are currently unused).
DefaultParam = {
    # Run mode: "training" or "testing" ("savePb" is set by the --pb flag).
    "mode": "testing",
    # Training mode: "segment" trains only the segment net, "decision"
    # trains only the decision net, "total" trains both.
    "train_mode": "decision",
    "epochs_num": 50,
    "batch_size": 1,
    "learn_rate": 0.001,
    "momentum": 0.9,  # optimizer parameter (disabled)
    "data_dir": "../Datasets/KolektorSDD",  # dataset path
    "checkPoint_dir": "checkpoint",  # where model checkpoints are saved
    "Log_dir": "Log",  # where logs are written
    "valid_ratio": 0,  # fraction of the dataset used for validation (disabled)
    "valid_frequency": 3,  # validate once every this many epochs (disabled)
    "save_frequency": 2,  # save the model once every this many epochs
    "max_to_keep": 10,  # maximum number of saved models to keep
    "b_restore": True,  # restore (load) saved parameters
    "b_saveNG": True,  # save wrongly handled samples when testing (disabled)
}
def parse_arguments():
    """Parse the command line arguments of the program.

    The mode switches (``--train_segment``, ``--train_decision``,
    ``--train_total``, ``--pb``, ``--test``, ``--anew``) are boolean flags
    that are ``False`` unless given.  The value options (``--valid_ratio``,
    ``--checkPoint_dir``, ``--data_dir``, ``--batch_size``, ``--epochs_num``)
    fall back to the matching entry of ``DefaultParam``, both when the option
    is absent and when it is given without a value.

    Returns:
        argparse.Namespace: the parsed options, one attribute per option.
    """
    parser = argparse.ArgumentParser(
        description='Train, test or export the segment/decision networks.'
    )

    # Boolean switches, registered in the order they appear in --help.
    switches = (
        ("--train_segment", "Define if we wanna to train the segment net"),
        ("--train_decision", "Define if we wanna to train the decision net"),
        ("--train_total", "Define if we wanna to train the total net"),
        ("--pb", "Define if we wanna to get the pbmodel"),
        ("--test", "Define if we wanna test the model"),
        (
            "--anew",
            "Define if we try to start from scratch instead of loading a "
            "checkpoint file from the save folder",
        ),
    )
    for flag, text in switches:
        parser.add_argument(flag, action="store_true", help=text)

    # (short flag, DefaultParam key / long option name, value type, help).
    value_options = (
        (
            "-vr",
            "valid_ratio",
            float,
            "How the data will be split between training and testing",
        ),
        (
            "-ckpt",
            "checkPoint_dir",
            str,
            "The path where the pretrained model can be found or where the "
            "model will be saved",
        ),
        (
            "-dd",
            "data_dir",
            str,
            "The path to the file containing the examples (training samples)",
        ),
        ("-bs", "batch_size", int, "Size of a batch"),
        ("-en", "epochs_num", int, "How many iteration in training"),
    )
    for short_flag, key, value_type, text in value_options:
        fallback = DefaultParam[key]
        parser.add_argument(
            short_flag,
            "--" + key,
            type=value_type,
            nargs="?",
            # With nargs="?" a bare flag yields `const`; without it the
            # result would be None, which would wipe out the default.
            const=fallback,
            default=fallback,
            help=text,
        )

    return parser.parse_args()
def main():
    """Build the run configuration from the command line and run the agent.

    Starts from a copy of ``DefaultParam``, applies the mode switches and
    value options returned by ``parse_arguments()``, then hands the resulting
    dict to ``Agent`` and calls its ``run()`` method.

    When several mode switches are given, the one checked last wins
    (``--pb`` over ``--test`` over ``--train_total`` over
    ``--train_decision`` over ``--train_segment``).  When none is given a
    notice is printed and the run continues with the default mode.
    """
    # Copy the defaults so the module-level dict is never mutated.
    param = dict(DefaultParam)

    # Update the parameters from the command line.
    args = parse_arguments()

    mode_requested = any((
        args.train_segment,
        args.train_decision,
        args.train_total,
        args.test,
        args.pb,
    ))
    if not mode_requested:
        print("If we are not training, and not testing, what is the point?")

    if args.train_segment:
        param["mode"] = "training"
        param["train_mode"] = "segment"
    if args.train_decision:
        param["mode"] = "training"
        param["train_mode"] = "decision"
    if args.train_total:
        param["mode"] = "training"
        param["train_mode"] = "total"
    if args.test:
        param["mode"] = "testing"
    if args.pb:
        param["mode"] = "savePb"
    if args.anew:
        param["b_restore"] = False

    # An option given without a value may parse to None; keep the default
    # in that case instead of overwriting it.
    for key in (
        "data_dir",
        "valid_ratio",
        "batch_size",
        "epochs_num",
        "checkPoint_dir",
    ):
        value = getattr(args, key)
        if value is not None:
            param[key] = value

    agent = Agent(param)
    agent.run()
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()