Skip to content

Commit

Permalink
add the main code of yoloface
Browse files Browse the repository at this point in the history
  • Loading branch information
sthanhng committed Oct 12, 2018
1 parent 838dbc8 commit c03dd78
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 0 deletions.
122 changes: 122 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,125 @@
#
# *******************************************************************


import datetime
import numpy as np
import cv2

# -------------------------------------------------------------------
# Parameters
# -------------------------------------------------------------------

# Minimum class score a detection needs in order to be kept; yoloface.py
# passes this to post_process() as conf_threshold.
CONF_THRESHOLD = 0.5
# Threshold handed to cv2.dnn.NMSBoxes() when overlapping boxes are
# suppressed; yoloface.py passes this to post_process() as nms_threshold.
NMS_THRESHOLD = 0.4
# Spatial size (in pixels) of the blob built by cv2.dnn.blobFromImage() in
# yoloface.py before it is fed to the network.
IMG_WIDTH = 416
IMG_HEIGHT = 416

# Default colors
# NOTE(review): the tuples appear to be in BGR channel order (COLOR_BLUE sets
# the first channel), which matches OpenCV's drawing functions -- confirm.
COLOR_BLUE = (255, 0, 0)
COLOR_GREEN = (0, 255, 0)
COLOR_RED = (0, 0, 255)
COLOR_WHITE = (255, 255, 255)
COLOR_YELLOW = (0, 255, 255)


# -------------------------------------------------------------------
# Helper functions
# -------------------------------------------------------------------

# Get the names of the output layers
def get_outputs_names(net):
    """Return the names of the network's output layers.

    The output layers are the layers whose outputs are not connected to any
    other layer.

    Args:
        net: A network object exposing ``getLayerNames()`` and
            ``getUnconnectedOutLayers()`` (e.g. a ``cv2.dnn`` network).

    Returns:
        A list with the name of every output layer, in the order reported
        by ``getUnconnectedOutLayers()``.
    """
    # Get the names of all the layers in the network
    layers_names = net.getLayerNames()

    # getUnconnectedOutLayers() reports 1-based layer ids. Depending on the
    # OpenCV version they come back either as an N x 1 array or as a flat
    # array, so flatten first instead of indexing each entry with [0].
    out_layer_ids = np.asarray(net.getUnconnectedOutLayers()).flatten()

    return [layers_names[int(layer_id) - 1] for layer_id in out_layer_ids]


# Draw the predicted bounding box
def draw_predict(frame, conf, left, top, right, bottom):
    """Draw one detection on ``frame``: a box plus its confidence label.

    Args:
        frame: Image that is drawn on in place.
        conf: Confidence value; rendered with two decimals.
        left, top, right, bottom: Pixel coordinates of the box corners.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    caption = '{:.2f}'.format(conf)

    # Outline of the detection.
    cv2.rectangle(frame, (left, top), (right, bottom), COLOR_YELLOW, 2)

    # Only the text height is needed: it keeps the caption from being placed
    # above the top edge of the image when the box touches it.
    (_, caption_height), _ = cv2.getTextSize(caption, font, 0.5, 1)
    caption_y = max(top, caption_height) - 4

    cv2.putText(frame, caption, (left, caption_y), font, 0.4, COLOR_WHITE, 1)


def post_process(frame, outs, conf_threshold, nms_threshold):
    """Filter raw network detections and draw the surviving boxes.

    Detections whose best class score is not above ``conf_threshold`` are
    dropped, the rest go through non-maximum suppression, and every box that
    survives is drawn on ``frame`` via ``draw_predict``.

    Args:
        frame: Image the detections belong to; it is drawn on in place and
            its ``shape[0]`` / ``shape[1]`` are used as height / width.
        outs: Iterable of network outputs. Each row is read as
            ``[center_x, center_y, width, height, ...]`` with the first four
            values relative to the frame size, and the class scores starting
            at index 5.
        conf_threshold: Minimum class score for a detection to be kept.
        nms_threshold: Threshold passed to ``cv2.dnn.NMSBoxes``.

    Returns:
        A list of ``[left, top, width, height]`` boxes (in pixels) that
        survived non-maximum suppression.
    """
    frame_height, frame_width = frame.shape[:2]

    # Scan through all the bounding boxes output from the network and keep
    # only the ones with high confidence scores.
    confidences = []
    boxes = []
    for out in outs:
        for detection in out:
            scores = detection[5:]
            class_id = np.argmax(scores)
            confidence = scores[class_id]
            if confidence <= conf_threshold:
                continue

            center_x = int(detection[0] * frame_width)
            center_y = int(detection[1] * frame_height)
            width = int(detection[2] * frame_width)
            height = int(detection[3] * frame_height)
            left = int(center_x - width / 2)
            top = int(center_y - height / 2)
            confidences.append(float(confidence))
            boxes.append([left, top, width, height])

    # Nothing passed the confidence filter, so there is nothing to suppress
    # or draw.
    if not boxes:
        return []

    # Perform non maximum suppression to eliminate redundant
    # overlapping boxes with lower confidences.
    indices = cv2.dnn.NMSBoxes(boxes, confidences, conf_threshold,
                               nms_threshold)

    # Depending on the OpenCV version the kept indices come back as an
    # N x 1 array, a flat array, or an empty tuple; flattening handles all
    # three instead of indexing each entry with [0].
    final_boxes = []
    for index in np.asarray(indices).flatten():
        index = int(index)
        left, top, width, height = boxes[index]
        final_boxes.append(boxes[index])
        draw_predict(frame, confidences[index], left, top, left + width,
                     top + height)
    return final_boxes


class FPS:
    """Wall-clock stopwatch that also counts processed frames.

    Typical use: ``start()`` once, ``update()`` after every frame,
    ``stop()`` at the end, then read ``elapsed()`` / ``fps()``.
    """

    def __init__(self):
        # store the start time, end time, and total number of frames
        # that were examined between the start and end intervals
        self._start = None
        self._end = None
        self._num_frames = 0

    def start(self):
        """Record the start time and return ``self`` so calls can be chained."""
        self._start = datetime.datetime.now()
        return self

    def stop(self):
        """Record the end time."""
        self._end = datetime.datetime.now()

    def update(self):
        """Count one more processed frame."""
        # increment the total number of frames examined during the
        # start and end intervals
        self._num_frames += 1

    def elapsed(self):
        """Return the measured interval in seconds.

        Returns ``0.0`` if the timer was never started. If it was started
        but not stopped yet, the time elapsed so far is returned instead of
        failing on the missing end time.
        """
        if self._start is None:
            return 0.0
        end = self._end if self._end is not None else datetime.datetime.now()
        return (end - self._start).total_seconds()

    def fps(self):
        """Return the (approximate) number of frames per second.

        Returns ``0.0`` when no measurable time has elapsed, so a very short
        or not-yet-started measurement does not divide by zero.
        """
        seconds = self.elapsed()
        if seconds <= 0:
            return 0.0
        return self._num_frames / seconds
147 changes: 147 additions & 0 deletions yoloface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,150 @@
#
# *******************************************************************

# Usage example: python yoloface.py --image samples/outside_000001.jpg \
# --output-dir outputs/
# python yoloface.py --video samples/subway.mp4 \
# --output-dir outputs/
# python yoloface.py --src 1 --output-dir outputs/


import argparse
import sys
import os

from utils import *

#####################################################################
# Command-line interface: every option is described once in the table below
# as (flag, value type, default, help text) and registered in that order.
parser = argparse.ArgumentParser()
for _flag, _type, _default, _help in (
    ('--model-cfg', str, './cfg/yolov3-face.cfg', 'path to config file'),
    ('--model-weights', str,
     './model-weights/yolov3-wider_16000.weights',
     'path to weights of model'),
    ('--image', str, '', 'path to image file'),
    ('--video', str, '', 'path to video file'),
    ('--src', int, 0, 'source of the camera'),
    ('--output-dir', str, 'outputs/', 'path to the output directory'),
):
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
args = parser.parse_args()

#####################################################################
# Echo the effective settings so a run can be traced back from its log.
print('----- info -----')
for _label, _value in (
    ('[i] The config file: ', args.model_cfg),
    ('[i] The weights of model file: ', args.model_weights),
    ('[i] Path to image file: ', args.image),
    ('[i] Path to video file: ', args.video),
):
    print(_label, _value)
print('###########################################################\n')

# Make sure the output directory is in place before anything is written.
if os.path.exists(args.output_dir):
    print('==> Skipping create the {} directory...'.format(args.output_dir))
else:
    print('==> Creating the {} directory...'.format(args.output_dir))
    os.makedirs(args.output_dir)

# Build the network from the Darknet config/weights pair given on the
# command line, then select the OpenCV backend with the CPU target.
net = cv2.dnn.readNetFromDarknet(args.model_cfg, args.model_weights)
net.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
net.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)


def _main():
    """Detect faces in an image, a video file or a camera stream.

    The input is picked from the command-line arguments in this order:
    ``--image``, then ``--video``, otherwise the camera given by ``--src``.
    Every frame is run through the network, annotated with the detected
    boxes and a face count, shown in a window and written to the output
    directory (a ``.jpg`` for an image, an MJPG ``.avi`` otherwise).
    Processing stops when the input is exhausted or when the user presses
    ``Esc`` or ``q``.

    Exits with status 1 if the requested image or video file does not exist.
    """
    wind_name = 'face detection using YOLOv3'
    cv2.namedWindow(wind_name, cv2.WINDOW_NORMAL)

    if args.image:
        if not os.path.isfile(args.image):
            print("[!] ==> Input image file {} doesn't exist".format(args.image))
            sys.exit(1)
        cap = cv2.VideoCapture(args.image)
        # splitext/basename cope with any extension length and with the
        # platform's path separator, unlike slicing off the last 4 chars.
        stem = os.path.splitext(os.path.basename(args.image))[0]
        output_file = stem + '_yoloface.jpg'
    elif args.video:
        if not os.path.isfile(args.video):
            print("[!] ==> Input video file {} doesn't exist".format(args.video))
            sys.exit(1)
        cap = cv2.VideoCapture(args.video)
        stem = os.path.splitext(os.path.basename(args.video))[0]
        output_file = stem + '_yoloface.avi'
    else:
        # Get data from the camera
        cap = cv2.VideoCapture(args.src)
        # Without a file name the writer would be pointed at the output
        # directory itself and nothing would be saved.
        output_file = 'camera_yoloface.avi'

    output_path = os.path.join(args.output_dir, output_file)
    video_writer = None

    try:
        # Get the video writer initialized to save the output video
        if not args.image:
            frame_rate = cap.get(cv2.CAP_PROP_FPS)
            if not frame_rate or frame_rate <= 0:
                # NOTE(review): some capture sources report 0 for the frame
                # rate; 25 is an arbitrary fallback so the writer still gets
                # a usable value -- adjust if a better default is known.
                frame_rate = 25.0
            video_writer = cv2.VideoWriter(
                output_path,
                cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
                frame_rate,
                (round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                 round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))

        # The output layers do not change between frames; look them up once.
        output_layers = get_outputs_names(net)

        while True:
            has_frame, frame = cap.read()

            # Stop the program if reached end of video
            if not has_frame:
                print('[i] ==> Done processing!!!')
                print('[i] ==> Output file is stored at', output_path)
                cv2.waitKey(1000)
                break

            # Create a 4D blob from a frame.
            blob = cv2.dnn.blobFromImage(frame, 1 / 255,
                                         (IMG_WIDTH, IMG_HEIGHT),
                                         [0, 0, 0], 1, crop=False)

            # Sets the input to the network
            net.setInput(blob)

            # Runs the forward pass to get output of the output layers
            outs = net.forward(output_layers)

            # Remove the bounding boxes with low confidence
            faces = post_process(frame, outs, CONF_THRESHOLD, NMS_THRESHOLD)
            print('[i] ==> # detected faces: {}'.format(len(faces)))
            print('#' * 60)

            # initialize the set of information we'll be displaying on the
            # frame
            info = [
                ('number of faces detected', '{}(s)'.format(len(faces)))
            ]

            for (i, (txt, val)) in enumerate(info):
                text = '{}: {}'.format(txt, val)
                cv2.putText(frame, text, (10, (i * 20) + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, COLOR_RED, 2)

            # Save the annotated frame: a single image file for --image,
            # otherwise one more frame of the output video.
            if args.image:
                cv2.imwrite(output_path, frame.astype(np.uint8))
            else:
                video_writer.write(frame.astype(np.uint8))

            cv2.imshow(wind_name, frame)

            key = cv2.waitKey(1)
            if key == 27 or key == ord('q'):
                print('[i] ==> Interrupted by user!')
                break
    finally:
        # Release the capture and the writer even if processing failed, so
        # the device is freed and the output video is finalized on disk.
        cap.release()
        if video_writer is not None:
            video_writer.release()
        cv2.destroyAllWindows()

    print('==> All done!')
    print('***********************************************************')


if __name__ == '__main__':
    _main()

0 comments on commit c03dd78

Please sign in to comment.