forked from matpalm/bnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_sample_training_pipeline.sh
executable file
·65 lines (53 loc) · 1.69 KB
/
run_sample_training_pipeline.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env bash
# End-to-end smoke test of the pipeline:
#   labelling -> materialising label bitmaps -> patch generation
#   -> training -> prediction -> loss statistics on train / test data.
# Training runs NOWHERE near long enough to get a decent result; the point
# is only to exercise every stage.
# NOTE(review): the first stage runs the labelling UI, which is presumably
# interactive (needs a human to close it) — confirm before using in CI.
set -euo pipefail
set -x

# Every helper is invoked as ./<script>.py with repo-relative data paths, so
# run from the directory containing this script regardless of the caller's
# current working directory.
cd -- "$(dirname -- "${BASH_SOURCE[0]}")"

# Run id shared by train / predict / test. It also names the per-run output
# directories cleared below, so define it once to keep them in sync.
readonly run_id=r12

# Start from a clean slate: remove previously materialised label bitmaps,
# this run's ckpts/ and tb/ output, and old predictions. The label database
# sample_data/labels.db is deliberately kept, since it is the pipeline input.
rm -rf -- sample_data/labels/ "ckpts/${run_id}" "tb/${run_id}" sample_predictions.db

# run labelling UI over the training images
./label_ui.py \
  --image-dir sample_data/training/ \
  --label-db sample_data/labels.db \
  --width 768 --height 1024

# materialise label database into bitmaps
./materialise_label_db.py \
  --label-db sample_data/labels.db \
  --directory sample_data/labels/ \
  --width 768 --height 1024

# generate some 256x256 sample patches of the data.
./data.py \
  --image-dir sample_data/training/ \
  --label-dir sample_data/labels/ \
  --rotate --distort \
  --patch-width-height 256

# train for a bit using 256 square patches for training and
# full resolution for test.
./train.py \
  --run "$run_id" \
  --steps 2 \
  --train-steps 2 \
  --batch-size 4 \
  --train-image-dir sample_data/training/ \
  --test-image-dir sample_data/test/ \
  --label-dir sample_data/labels/ \
  --pos-weight 5 \
  --patch-width-height 256 \
  --width 768 --height 1024

# run inference against unlabelled data
./predict.py \
  --run "$run_id" \
  --image-dir sample_data/unlabelled \
  --output-label-db sample_predictions.db \
  --export-pngs predictions

# check loss statistics against training data
./test.py \
  --run "$run_id" \
  --image-dir sample_data/training/ \
  --label-db sample_data/labels.db

# check loss statistics against labelled test data
./test.py \
  --run "$run_id" \
  --image-dir sample_data/test/ \
  --label-db sample_data/labels.db