forked from oyxhust/ssd-text_detection
-
Notifications
You must be signed in to change notification settings - Fork 4
/
demo_savefig.py
107 lines (100 loc) · 4.59 KB
/
demo_savefig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import tools.find_mxnet
import mxnet as mx
import os
import importlib
import sys
from detect.detector_savefig import Detector
import re
# Names of the object classes handled by this demo. The tuple's length is
# passed to the network symbol builder in get_detector(), and the tuple itself
# is handed to the detector's visualization call in the script entry point.
CLASSES = ('text', )
# Additional class names are kept below, commented out (presumably left over
# from a multi-class configuration -- TODO confirm before re-enabling).
#'bottle', 'bus', 'car', 'cat', 'chair',
#'cow', 'diningtable', 'dog', 'horse',
#'motorbike', 'person', 'pottedplant',
#'sheep', 'sofa', 'train', 'tvmonitor')
def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx,
                 nms_thresh=0.5, force_nms=True):
    """
    Build a Detector for the requested test network.

    Parameters:
    ----------
    net : str
        test network name; the module "symbol_<net>" is imported from the
        ./symbol directory under the current working directory
    prefix : str
        load model prefix; "_<data_shape>" is appended before loading
    epoch : int
        load model epoch
    data_shape : int
        resize image shape
    mean_pixels : tuple (float, float, float)
        mean pixel values (R, G, B)
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(?)
    nms_thresh : float
        non-maximum suppression threshold forwarded to get_symbol
    force_nms : bool
        force suppress different categories

    Returns:
    -------
    Detector
        the detector constructed from the network symbol and model files
    """
    # Make the symbol definitions importable by module name.
    symbol_dir = os.path.join(os.getcwd(), 'symbol')
    sys.path.append(symbol_dir)

    symbol_module = importlib.import_module("symbol_" + net)
    symbol = symbol_module.get_symbol(len(CLASSES), nms_thresh, force_nms)

    # Model files are named after both the prefix and the input size.
    model_prefix = prefix + "_" + str(data_shape)
    return Detector(symbol, model_prefix, epoch, data_shape, mean_pixels,
                    ctx=ctx)
def _str2bool(value):
    """Convert a command-line string such as 'true' or 'False' into a bool.

    argparse's ``type=bool`` cannot be used for this: ``bool('False')`` is
    True because every non-empty string is truthy, so the option could never
    be switched off from the command line.

    Raises argparse.ArgumentTypeError when the text is not a recognized
    boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string stays falsy, as it was with type=bool.
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


def parse_args():
    """Parse the command-line options of the detection demo.

    Reads the arguments from ``sys.argv`` and returns the populated
    ``argparse.Namespace``. Attribute names follow each option's ``dest``
    (e.g. ``--nms`` is stored as ``nms_thresh``, ``--ext`` as ``extension``).
    """
    parser = argparse.ArgumentParser(
        description='Single-shot detection network demo')
    parser.add_argument('--network', dest='network', type=str,
                        default='vgg16_reduced', choices=['vgg16_reduced'],
                        help='which network to use')
    # The script walks this path with os.walk, so it must be a directory.
    parser.add_argument('--images', dest='images', type=str,
                        default='./data/demo',
                        help='directory searched recursively for .jpg images '
                             'to run the demo on')
    parser.add_argument('--dir', dest='dir', nargs='?',
                        help='demo image directory, optional', type=str)
    parser.add_argument('--ext', dest='extension',
                        help='image extension, optional', type=str, nargs='?')
    parser.add_argument('--epoch', dest='epoch',
                        help='epoch of trained model', default=0, type=int)
    parser.add_argument('--prefix', dest='prefix',
                        help='trained model prefix',
                        default=os.path.join(os.getcwd(), 'model', 'ssd'),
                        type=str)
    parser.add_argument('--cpu', dest='cpu',
                        help='(override GPU) use CPU to detect',
                        action='store_true', default=False)
    parser.add_argument('--gpu', dest='gpu_id', type=int, default=0,
                        help='GPU device id to detect with')
    parser.add_argument('--data-shape', dest='data_shape', type=int,
                        default=300, help='set image shape')
    parser.add_argument('--mean-r', dest='mean_r', type=float, default=123,
                        help='red mean value')
    parser.add_argument('--mean-g', dest='mean_g', type=float, default=117,
                        help='green mean value')
    parser.add_argument('--mean-b', dest='mean_b', type=float, default=104,
                        help='blue mean value')
    parser.add_argument('--thresh', dest='thresh', type=float, default=0.5,
                        help='object visualize score threshold, default 0.5')
    parser.add_argument('--nms', dest='nms_thresh', type=float, default=0.5,
                        help='non-maximum suppression threshold, default 0.5')
    # _str2bool (not bool) so that "--force False" really yields False.
    parser.add_argument('--force', dest='force_nms', type=_str2bool,
                        default=True,
                        help='force non-maximum suppression on different '
                             'class (true/false)')
    parser.add_argument('--timer', dest='show_timer', type=_str2bool,
                        default=True,
                        help='show detection time (true/false)')
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: parse options, collect the .jpg files below the
    # --images directory, build the detector and run detection on them.
    args = parse_args()

    # --cpu overrides the GPU selection.
    if args.cpu:
        ctx = mx.cpu()
    else:
        ctx = mx.gpu(args.gpu_id)

    # Build the image list by walking the --images directory recursively.
    jpg_suffix = '.jpg'
    image_list = []
    for root, _, files in os.walk(args.images):
        for f in files:
            # The name must END in ".jpg" with at least one character before
            # it. The previous unanchored regex r'.+\.jpg' also accepted
            # names such as "a.jpg.bak" or "a.jpgx".
            if len(f) > len(jpg_suffix) and f.endswith(jpg_suffix):
                image_list.append(os.path.join(root, f))

    # Explicit check instead of assert, which is stripped under "python -O".
    if not image_list:
        raise ValueError("No valid image specified to detect")

    detector = get_detector(args.network, args.prefix, args.epoch,
                            args.data_shape,
                            (args.mean_r, args.mean_g, args.mean_b),
                            ctx, args.nms_thresh, args.force_nms)
    # run detection
    detector.detect_and_visualize(args.images, image_list, args.dir,
                                  args.extension, CLASSES, args.thresh,
                                  args.show_timer)