@@ -18,6 +18,20 @@ def main():
18
18
parser .add_argument (
19
19
"test_files_path" , metavar = "path" , help = "Path to the test files" ,
20
20
)
21
+ parser .add_argument (
22
+ "--checkpoints_save_path" ,
23
+ dest = "checkpoints_save_path" ,
24
+ metavar = "path" ,
25
+ type = str ,
26
+ help = "Path where the checkpoints should be saved." ,
27
+ )
28
+ parser .add_argument (
29
+ "--last_checkpoint_path" ,
30
+ dest = "last_checkpoint_path" ,
31
+ metavar = "path" ,
32
+ type = str ,
33
+ help = "Path to the last checkpoint to continue training." ,
34
+ )
21
35
parser .add_argument (
22
36
"--plot_history" ,
23
37
dest = "plot_history" ,
@@ -48,16 +62,22 @@ def main():
48
62
train_files_path , input_shape , validation_split = 0
49
63
)
50
64
51
- eval_x , eval_y , t1 , t2 , = load_dataset (
65
+ eval_x , eval_y , _ , _ , = load_dataset (
52
66
eval_files_path , input_shape , validation_split = 0
53
67
)
54
68
55
69
backbone = backbones .SegmentationVanillaUnet (input_shape )
56
- # optimizer = SGD(lr=0.0001, momentum=0.9, decay=0.0 )
70
+ # optimizer = SGD(lr=0.0001, momentum=0.9, decay=0.0005 )
57
71
optimizer = Adam (lr = 0.00001 )
58
72
59
73
train_engine = TrainEngine (
60
- input_shape , backbone .model , optimizer , loss = "binary_crossentropy"
74
+ input_shape ,
75
+ backbone .model ,
76
+ optimizer ,
77
+ loss = "binary_crossentropy" ,
78
+ checkpoints_save_path = args .checkpoints_save_path ,
79
+ checkpoint_save_period = 100 ,
80
+ last_checkpoint_path = args .last_checkpoint_path ,
61
81
)
62
82
63
83
loss , acc , val_loss , val_acc = train_engine .train (
@@ -71,7 +91,7 @@ def main():
71
91
)
72
92
if plot_history :
73
93
plots .plot_history (loss , acc , val_loss , val_acc )
74
- for idx in range (len (eval_x [:3 ])):
94
+ for idx in range (len (eval_x [:2 ])):
75
95
predictions = train_engine .model .predict (
76
96
np .array ([eval_x [idx ]], dtype = np .float32 ), batch_size = 1
77
97
)
0 commit comments