Skip to content
This repository was archived by the owner on Sep 20, 2024. It is now read-only.

Commit e9bbf59

Browse files
committed
Load checkpoint
1 parent e91e062 commit e9bbf59

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

src/train.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,20 @@ def main():
1818
parser.add_argument(
1919
"test_files_path", metavar="path", help="Path to the test files",
2020
)
21+
parser.add_argument(
22+
"--checkpoints_save_path",
23+
dest="checkpoints_save_path",
24+
metavar="path",
25+
type=str,
26+
help="Path where the checkpoints should be saved.",
27+
)
28+
parser.add_argument(
29+
"--last_checkpoint_path",
30+
dest="last_checkpoint_path",
31+
metavar="path",
32+
type=str,
33+
help="Path to the last checkpoint to continue training.",
34+
)
2135
parser.add_argument(
2236
"--plot_history",
2337
dest="plot_history",
@@ -48,16 +62,22 @@ def main():
4862
train_files_path, input_shape, validation_split=0
4963
)
5064

51-
eval_x, eval_y, t1, t2, = load_dataset(
65+
eval_x, eval_y, _, _, = load_dataset(
5266
eval_files_path, input_shape, validation_split=0
5367
)
5468

5569
backbone = backbones.SegmentationVanillaUnet(input_shape)
56-
# optimizer = SGD(lr=0.0001, momentum=0.9, decay=0.0)
70+
# optimizer = SGD(lr=0.0001, momentum=0.9, decay=0.0005)
5771
optimizer = Adam(lr=0.00001)
5872

5973
train_engine = TrainEngine(
60-
input_shape, backbone.model, optimizer, loss="binary_crossentropy"
74+
input_shape,
75+
backbone.model,
76+
optimizer,
77+
loss="binary_crossentropy",
78+
checkpoints_save_path=args.checkpoints_save_path,
79+
checkpoint_save_period=100,
80+
last_checkpoint_path=args.last_checkpoint_path,
6181
)
6282

6383
loss, acc, val_loss, val_acc = train_engine.train(
@@ -71,7 +91,7 @@ def main():
7191
)
7292
if plot_history:
7393
plots.plot_history(loss, acc, val_loss, val_acc)
74-
for idx in range(len(eval_x[:3])):
94+
for idx in range(len(eval_x[:2])):
7595
predictions = train_engine.model.predict(
7696
np.array([eval_x[idx]], dtype=np.float32), batch_size=1
7797
)

src/train_engine.py

+1
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def _train(
113113
initial_epoch=self._initial_epoch,
114114
shuffle=True,
115115
validation_data=(eval_x, eval_y),
116+
callbacks=self.callbacks,
116117
verbose=0,
117118
)
118119
self._initial_epoch = epochs

0 commit comments

Comments
 (0)