Commit

refactor GPU-based code
sthanhng committed Jan 14, 2019
1 parent 12978b8 commit 3eb897a
Showing 3 changed files with 31 additions and 44 deletions.
16 changes: 0 additions & 16 deletions yolo/utils.py

This file was deleted.

56 changes: 30 additions & 26 deletions YOLO.py → yolo/yolo.py
@@ -6,7 +6,7 @@
#
# Face detection using the YOLOv3 algorithm
#
# Description : YOLO.py
# Description : yolo.py
# Contains methods of YOLO
#
# *******************************************************************
@@ -17,7 +17,6 @@
import cv2

from yolo.model import eval
from yolo.utils import letterbox_image

from keras import backend as K
from keras.models import load_model
@@ -42,7 +41,6 @@ def _get_class(self):
with open(classes_path) as f:
class_names = f.readlines()
class_names = [c.strip() for c in class_names]
print(class_names)
return class_names

def _get_anchors(self):
@@ -57,7 +55,7 @@ def _generate(self):
assert model_path.endswith(
'.h5'), 'Keras model or weights must be a .h5 file'

# Load model, or construct model and load weights
# load model, or construct model and load weights
num_anchors = len(self.anchors)
num_classes = len(self.class_names)
try:
@@ -70,24 +68,23 @@ def _generate(self):
num_anchors / len(self.yolo_model.output) * (
num_classes + 5), \
'Mismatch between model and given anchor and class sizes'

print(
'[i] ==> {} model, anchors, and classes loaded.'.format(model_path))
'*** {} model, anchors, and classes loaded.'.format(model_path))

# Generate colors for drawing bounding boxes
# generate colors for drawing bounding boxes
hsv_tuples = [(x / len(self.class_names), 1., 1.)
for x in range(len(self.class_names))]
self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
self.colors = list(
map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
self.colors))

# Shuffle colors to decorrelate adjacent classes.
# shuffle colors to decorrelate adjacent classes.
np.random.seed(102)
np.random.shuffle(self.colors)
np.random.seed(None)

# Generate output tensor targets for filtered bounding boxes.
# generate output tensor targets for filtered bounding boxes.
self.input_image_shape = K.placeholder(shape=(2,))
boxes, scores, classes = eval(self.yolo_model.output, self.anchors,
len(self.class_names),
@@ -98,7 +95,6 @@ def _generate(self):

def detect_image(self, image):
start_time = timer()

if self.model_image_size != (None, None):
assert self.model_image_size[
0] % 32 == 0, 'Multiples of 32 required'
@@ -111,28 +107,24 @@ def detect_image(self, image):
image.height - (image.height % 32))
boxed_image = letterbox_image(image, new_image_size)
image_data = np.array(boxed_image, dtype='float32')

print(image_data.shape)
image_data /= 255.
# Add batch dimension
# add batch dimension
image_data = np.expand_dims(image_data, 0)

out_boxes, out_scores, out_classes = self.sess.run(
[self.boxes, self.scores, self.classes],
feed_dict={
self.yolo_model.input: image_data,
self.input_image_shape: [image.size[1], image.size[0]],
K.learning_phase(): 0
})

print('[i] ==> Found {} face(s) for this image'.format(len(out_boxes)))
print('*** Found {} face(s) for this image'.format(len(out_boxes)))
thickness = (image.size[0] + image.size[1]) // 400

for i, c in reversed(list(enumerate(out_classes))):
predicted_class = self.class_names[c]
box = out_boxes[i]
score = out_scores[i]

text = '{} {:.2f}'.format(predicted_class, score)
draw = ImageDraw.Draw(image)

@@ -151,21 +143,36 @@ def detect_image(self, image):
del draw

end_time = timer()
print('[i] ==> Processing time: {:.2f}ms'.format((end_time -
print('*** Processing time: {:.2f}ms'.format((end_time -
start_time) * 1000))
return image, out_boxes

def close_session(self):
self.sess.close()


def letterbox_image(image, size):
'''Resize image with unchanged aspect ratio using padding'''

img_width, img_height = image.size
w, h = size
scale = min(w / img_width, h / img_height)
nw = int(img_width * scale)
nh = int(img_height * scale)

image = image.resize((nw, nh), Image.BICUBIC)
new_image = Image.new('RGB', size, (128, 128, 128))
new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
return new_image


def detect_img(yolo):
while True:
img = input('[i] ==> Input image filename: ')
img = input('*** Input image filename: ')
try:
image = Image.open(img)
except:
print('[!] ==> Open Error! Try again!')
print('*** Open Error! Try again!')
continue
else:
res_image, _ = yolo.detect_image(image)
Expand All @@ -179,16 +186,15 @@ def detect_video(model, video_path=None, output=None):
vid = cv2.VideoCapture(0)
else:
vid = cv2.VideoCapture(video_path)

if not vid.isOpened():
raise IOError("Couldn't open webcam or video")

# The video format and fps
# the video format and fps
# video_fourcc = int(vid.get(cv2.CAP_PROP_FOURCC))
video_fourcc = cv2.VideoWriter_fourcc('M', 'G', 'P', 'G')
video_fps = vid.get(cv2.CAP_PROP_FPS)

# The size of the frames to write
# the size of the frames to write
video_size = (int(vid.get(cv2.CAP_PROP_FRAME_WIDTH)),
int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT)))
isOutput = True if output != "" else False
@@ -218,12 +224,11 @@ def detect_video(model, video_path=None, output=None):
fps = curr_fps
curr_fps = 0

# Initialize the set of information we'll displaying on the frame
# initialize the set of information we'll displaying on the frame
info = [
('FPS', '{}'.format(fps)),
('Faces detected', '{}'.format(len(faces)))
]

cv2.rectangle(result, (5, 5), (120, 50), (0, 0, 0), cv2.FILLED)

for (i, (txt, val)) in enumerate(info):
@@ -239,10 +244,9 @@ def detect_video(model, video_path=None, output=None):
break
else:
break

vid.release()
out.release()
cv2.destroyAllWindows()

# Close the session
# close the session
model.close_session()
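
Since yolo/utils.py is deleted by this commit, the letterbox_image helper is now defined directly in yolo/yolo.py (see the hunk above). A minimal usage sketch of the padding behaviour; the input file name and target size below are hypothetical:

```python
from PIL import Image

from yolo.yolo import letterbox_image  # now a module-level helper in yolo/yolo.py

# A 640x480 photo fitted into a 416x416 canvas: scale = min(416/640, 416/480) = 0.65,
# so the photo is resized to 416x312 and pasted centered on a grey 416x416 background.
img = Image.open('photo.jpg')             # hypothetical input image
boxed = letterbox_image(img, (416, 416))
print(boxed.size)                          # (416, 416)
```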
3 changes: 1 addition & 2 deletions yoloface_gpu.py
@@ -1,7 +1,6 @@
import argparse

from PIL import Image
from YOLO import YOLO, detect_video, detect_img
from yolo.yolo import YOLO, detect_video, detect_img


#####################################################################
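With the module renamed from YOLO.py to yolo/yolo.py, callers import the class and helpers from the package path, as the updated yoloface_gpu.py does above. A hedged sketch of a minimal driver; the constructor keyword arguments and file paths are assumptions for illustration and are not part of this diff:

```python
from yolo.yolo import YOLO, detect_video, detect_img

# NOTE: the keyword arguments below are assumptions; the actual YOLO.__init__
# signature is not shown in this commit.
model = YOLO(model_path='model-weights/yolo_face.h5',
             anchors_path='cfg/yolo_anchors.txt',
             classes_path='cfg/face_classes.txt')

# Interactive image mode: prompts for filenames and draws detected faces.
detect_img(model)

# Or process a video file and write the annotated result (hypothetical paths):
# detect_video(model, video_path='samples/video.mp4', output='outputs/video_out.avi')
```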
