#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 3 12:17:38 2018
@author: github.com/GustavZ
"""
import os
import tarfile
from six.moves import urllib
import numpy as np
import tensorflow as tf
import yaml
import cv2
from stuff.helper import FPS2, WebcamVideoStream
from skimage import measure
## LOAD CONFIG PARAMS ##
if os.path.isfile('config.yml'):
    with open("config.yml", 'r') as ymlfile:
        cfg = yaml.load(ymlfile)
else:
    with open("config.sample.yml", 'r') as ymlfile:
        cfg = yaml.load(ymlfile)
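# Note: yaml.load() without an explicit Loader argument is deprecated in newer
# PyYAML releases; yaml.safe_load(ymlfile) is the drop-in replacement for a plain
# config file like this one.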
VIDEO_INPUT = cfg['video_input']
FPS_INTERVAL = cfg['fps_interval']
ALPHA = cfg['alpha']
MODEL_NAME = cfg['model_name']
MODEL_PATH = cfg['model_path']
DOWNLOAD_BASE = cfg['download_base']
BBOX = cfg['bbox']
MINAREA = cfg['minArea']
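# The values above come straight from the YAML config: video_input selects the
# camera device (or stream) passed to WebcamVideoStream, fps_interval is the
# averaging window for the FPS counter, alpha is the overlay blending weight,
# model_name / model_path / download_base locate the frozen DeepLab model,
# bbox toggles drawing of bounding boxes and minArea is the minimum connected
# region size (in pixels) for a box to be drawn.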
# Hardcoded COCO_VOC Labels
LABEL_NAMES = np.asarray([
    '', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv'])
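# Index 0 is the background class; the remaining entries follow the 20 PASCAL VOC
# object classes, so the array can be indexed directly with the integer label ids
# produced in the segmentation map.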
def create_colormap(seg_map):
    """
    Takes a 2D array storing the segmentation labels.
    Returns a 2D array where each element is the color indexed
    by the corresponding element in the input label to the PASCAL color map.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)
    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3
    return colormap[seg_map]
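# The loop above distributes the bits of each label id across the R, G and B
# channels, which reproduces the standard PASCAL VOC colormap: e.g. label 1
# ('aeroplane') maps to (128, 0, 0) and label 15 ('person') to (192, 128, 128).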
# Download Model from TF-deeplab's Model Zoo
def download_model():
    model_file = MODEL_NAME + '.tar.gz'
    if not os.path.isfile(MODEL_PATH):
        print('> Model not found. Downloading it now.')
        opener = urllib.request.URLopener()
        opener.retrieve(DOWNLOAD_BASE + model_file, model_file)
        tar_file = tarfile.open(model_file)
        for file in tar_file.getmembers():
            file_name = os.path.basename(file.name)
            if 'frozen_inference_graph.pb' in file_name:
                tar_file.extract(file, os.getcwd() + '/models/')
        os.remove(os.getcwd() + '/' + model_file)
    else:
        print('> Model found. Proceed.')
# Visualize Text on OpenCV Image
def vis_text(image, string, pos):
    cv2.putText(image, string, pos,
                cv2.FONT_HERSHEY_SIMPLEX, 0.75, (77, 255, 9), 2)
# Load frozen Model
def load_frozenmodel():
    print('> Loading frozen model into memory')
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        seg_graph_def = tf.GraphDef()
        with tf.gfile.GFile(MODEL_PATH, 'rb') as fid:
            serialized_graph = fid.read()
            seg_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(seg_graph_def, name='')
    return detection_graph
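# The exported DeepLab frozen graphs expose 'ImageTensor:0' (a uint8 RGB image
# batch) as input and 'SemanticPredictions:0' (per-pixel label ids) as output;
# segmentation() below feeds and fetches these tensors by name.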
def segmentation(detection_graph, label_names):
    # fixed input sizes as model needs resize either way
    # (513 px is the longest-side input size the exported DeepLab models expect)
    vs = WebcamVideoStream(VIDEO_INPUT, 640, 480).start()
    resize_ratio = 1.0 * 513 / max(vs.real_width, vs.real_height)
    target_size = (int(resize_ratio * vs.real_width), int(resize_ratio * vs.real_height))
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    fps = FPS2(FPS_INTERVAL).start()
    print("> Starting Segmentation")
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph, config=config) as sess:
            while vs.isActive():
                image = cv2.resize(vs.read(), target_size)
                batch_seg_map = sess.run('SemanticPredictions:0',
                                         feed_dict={'ImageTensor:0': [cv2.cvtColor(image, cv2.COLOR_BGR2RGB)]})
                # visualization
                seg_map = batch_seg_map[0]
                seg_image = create_colormap(seg_map).astype(np.uint8)
                cv2.addWeighted(seg_image, ALPHA, image, 1 - ALPHA, 0, image)
                vis_text(image, "fps: {}".format(fps.fps_local()), (10, 30))
                # boxes (ymin, xmin, ymax, xmax)
                if BBOX:
                    map_labeled = measure.label(seg_map, connectivity=1)
                    for region in measure.regionprops(map_labeled):
                        if region.area > MINAREA:
                            box = region.bbox
                            p1 = (box[1], box[0])
                            p2 = (box[3], box[2])
                            cv2.rectangle(image, p1, p2, (77, 255, 9), 2)
                            vis_text(image, label_names[seg_map[tuple(region.coords[0])]], (p1[0], p1[1] - 10))
                cv2.imshow('segmentation', image)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
                fps.update()
    fps.stop()
    vs.stop()
    cv2.destroyAllWindows()
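# Entry point: make sure the frozen model is available on disk, load it into a
# graph, then run the live segmentation loop on the configured video input.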
if __name__ == '__main__':
    download_model()
    graph = load_frozenmodel()
    segmentation(graph, LABEL_NAMES)