diff --git a/utils.py b/utils.py
index 55ba7b9..2895d72 100644
--- a/utils.py
+++ b/utils.py
@@ -12,3 +12,125 @@
 #
 # *******************************************************************
+
+import datetime
+import numpy as np
+import cv2
+
+# -------------------------------------------------------------------
+# Parameters
+# -------------------------------------------------------------------
+
+CONF_THRESHOLD = 0.5
+NMS_THRESHOLD = 0.4
+IMG_WIDTH = 416
+IMG_HEIGHT = 416
+
+# Default colors (OpenCV uses BGR channel order)
+COLOR_BLUE = (255, 0, 0)
+COLOR_GREEN = (0, 255, 0)
+COLOR_RED = (0, 0, 255)
+COLOR_WHITE = (255, 255, 255)
+COLOR_YELLOW = (0, 255, 255)
+
+
+# -------------------------------------------------------------------
+# Helper functions
+# -------------------------------------------------------------------
+
+# Get the names of the output layers
+def get_outputs_names(net):
+    # Get the names of all the layers in the network
+    layers_names = net.getLayerNames()
+
+    # Get the names of the output layers, i.e. the layers with unconnected
+    # outputs. getUnconnectedOutLayers() returns 1-based indices; flatten()
+    # copes with both the nested (older OpenCV) and flat return shapes.
+    return [layers_names[i - 1]
+            for i in np.array(net.getUnconnectedOutLayers()).flatten()]
+
+
+# Draw the predicted bounding box
+def draw_predict(frame, conf, left, top, right, bottom):
+    # Draw a bounding box.
+    cv2.rectangle(frame, (left, top), (right, bottom), COLOR_YELLOW, 2)
+
+    text = '{:.2f}'.format(conf)
+
+    # Display the confidence label at the top of the bounding box,
+    # clamping it so it stays inside the frame
+    label_size, base_line = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX,
+                                            0.4, 1)
+
+    top = max(top, label_size[1])
+    cv2.putText(frame, text, (left, top - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.4,
+                COLOR_WHITE, 1)
+
+
+def post_process(frame, outs, conf_threshold, nms_threshold):
+    frame_height = frame.shape[0]
+    frame_width = frame.shape[1]
+
+    # Scan through all the bounding boxes output from the network and keep
+    # only the ones with high confidence scores. Assign the box's class
+    # label as the class with the highest score.
+    confidences = []
+    boxes = []
+    final_boxes = []
+    for out in outs:
+        for detection in out:
+            scores = detection[5:]
+            class_id = np.argmax(scores)
+            confidence = scores[class_id]
+            if confidence > conf_threshold:
+                # The network outputs box coordinates normalized to [0, 1];
+                # scale them back to pixel coordinates
+                center_x = int(detection[0] * frame_width)
+                center_y = int(detection[1] * frame_height)
+                width = int(detection[2] * frame_width)
+                height = int(detection[3] * frame_height)
+                left = int(center_x - width / 2)
+                top = int(center_y - height / 2)
+                confidences.append(float(confidence))
+                boxes.append([left, top, width, height])
+
+    # Perform non-maximum suppression to eliminate redundant
+    # overlapping boxes with lower confidences.
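+    # NMSBoxes keeps the highest-scoring box first, then discards any
+    # remaining box whose overlap (IoU) with an already-kept box exceeds
+    # nms_threshold, and returns the indices of the surviving boxes.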
+    indices = cv2.dnn.NMSBoxes(boxes, confidences, conf_threshold,
+                               nms_threshold)
+
+    # np.array(...).flatten() copes with both the nested (older OpenCV)
+    # and flat return shapes of NMSBoxes
+    for i in np.array(indices).flatten():
+        box = boxes[i]
+        left = box[0]
+        top = box[1]
+        width = box[2]
+        height = box[3]
+        final_boxes.append(box)
+        draw_predict(frame, confidences[i], left, top, left + width,
+                     top + height)
+    return final_boxes
+
+
+class FPS:
+    def __init__(self):
+        # store the start time, end time, and total number of frames
+        # that were examined between the start and end intervals
+        self._start = None
+        self._end = None
+        self._num_frames = 0
+
+    def start(self):
+        self._start = datetime.datetime.now()
+        return self
+
+    def stop(self):
+        self._end = datetime.datetime.now()
+
+    def update(self):
+        # increment the total number of frames examined during the
+        # start and end intervals
+        self._num_frames += 1
+
+    def elapsed(self):
+        # return the total number of seconds between the start and
+        # end intervals
+        return (self._end - self._start).total_seconds()
+
+    def fps(self):
+        # compute the (approximate) frames per second, guarding against
+        # a zero elapsed time on timers with coarse resolution
+        seconds = self.elapsed()
+        return self._num_frames / seconds if seconds > 0 else 0.0
diff --git a/yoloface.py b/yoloface.py
index 89f48f9..daa0e6a 100644
--- a/yoloface.py
+++ b/yoloface.py
@@ -12,3 +12,150 @@
 #
 # *******************************************************************
+# Usage examples: python yoloface.py --image samples/outside_000001.jpg \
+#                                    --output-dir outputs/
+#                 python yoloface.py --video samples/subway.mp4 \
+#                                    --output-dir outputs/
+#                 python yoloface.py --src 1 --output-dir outputs/
+
+
+import argparse
+import sys
+import os
+
+from utils import *
+
+#####################################################################
+parser = argparse.ArgumentParser()
+parser.add_argument('--model-cfg', type=str, default='./cfg/yolov3-face.cfg',
+                    help='path to config file')
+parser.add_argument('--model-weights', type=str,
+                    default='./model-weights/yolov3-wider_16000.weights',
+                    help='path to weights of model')
+parser.add_argument('--image', type=str, default='',
+                    help='path to image file')
+parser.add_argument('--video', type=str, default='',
+                    help='path to video file')
+parser.add_argument('--src', type=int, default=0,
+                    help='source of the camera')
+parser.add_argument('--output-dir', type=str, default='outputs/',
+                    help='path to the output directory')
+args = parser.parse_args()
+
+#####################################################################
+# Print the arguments
+print('----- info -----')
+print('[i] The config file: ', args.model_cfg)
+print('[i] The weights of model file: ', args.model_weights)
+print('[i] Path to image file: ', args.image)
+print('[i] Path to video file: ', args.video)
+print('###########################################################\n')
+
+# Check the output directory
+if not os.path.exists(args.output_dir):
+    print('==> Creating the {} directory...'.format(args.output_dir))
+    os.makedirs(args.output_dir)
+else:
+    print('==> Skipping creation of the {} directory...'.format(args.output_dir))
+
+# Give the configuration and weight files for the model and load the network
+# using them.
+net = cv2.dnn.readNetFromDarknet(args.model_cfg, args.model_weights)
+net.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
+net.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)
+
+
+def _main():
+    wind_name = 'face detection using YOLOv3'
+    cv2.namedWindow(wind_name, cv2.WINDOW_NORMAL)
+
+    output_file = ''
+
+    if args.image:
+        if not os.path.isfile(args.image):
+            print("[!] ==> Input image file {} doesn't exist".format(args.image))
+            sys.exit(1)
+        cap = cv2.VideoCapture(args.image)
+        output_file = args.image[:-4].rsplit('/')[-1] + '_yoloface.jpg'
+    elif args.video:
+        if not os.path.isfile(args.video):
+            print("[!] ==> Input video file {} doesn't exist".format(args.video))
+            sys.exit(1)
+        cap = cv2.VideoCapture(args.video)
+        output_file = args.video[:-4].rsplit('/')[-1] + '_yoloface.avi'
+    else:
+        # Get data from the camera
+        cap = cv2.VideoCapture(args.src)
+
+    # Get the video writer initialized to save the output video
+    if not args.image:
+        video_writer = cv2.VideoWriter(
+            os.path.join(args.output_dir, output_file),
+            cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
+            cap.get(cv2.CAP_PROP_FPS),
+            (round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
+             round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
+
+    while True:
+
+        has_frame, frame = cap.read()
+
+        # Stop the program if we have reached the end of the video
+        if not has_frame:
+            print('[i] ==> Done processing!!!')
+            print('[i] ==> Output file is stored at',
+                  os.path.join(args.output_dir, output_file))
+            cv2.waitKey(1000)
+            break
+
+        fps = FPS().start()
+
+        # Create a 4D blob from a frame: scale pixel values to [0, 1] and
+        # resize to the network input size
+        blob = cv2.dnn.blobFromImage(frame, 1 / 255, (IMG_WIDTH, IMG_HEIGHT),
+                                     [0, 0, 0], 1, crop=False)
+
+        # Set the input to the network
+        net.setInput(blob)
+
+        # Run the forward pass to get the output of the output layers
+        outs = net.forward(get_outputs_names(net))
+
+        # Remove the bounding boxes with low confidence
+        faces = post_process(frame, outs, CONF_THRESHOLD, NMS_THRESHOLD)
+        print('[i] ==> # detected faces: {}'.format(len(faces)))
+        print('#' * 60)
+
+        # Update the fps counter
+        fps.update()
+        fps.stop()
+
+        # Initialize the set of information we'll be displaying on the frame
+        info = [
+            ('number of faces detected', '{}'.format(len(faces)))
+        ]
+
+        for (i, (txt, val)) in enumerate(info):
+            text = '{}: {}'.format(txt, val)
+            cv2.putText(frame, text, (10, (i * 20) + 20),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, COLOR_RED, 2)
+
+        # Save the output frame to file
+        if args.image:
+            cv2.imwrite(os.path.join(args.output_dir, output_file),
+                        frame.astype(np.uint8))
+        else:
+            video_writer.write(frame.astype(np.uint8))
+
+        cv2.imshow(wind_name, frame)
+
+        key = cv2.waitKey(1)
+        if key == 27 or key == ord('q'):
+            print('[i] ==> Interrupted by user!')
+            break
+
+    cap.release()
+    # Release the writer so the output video file is finalized properly
+    if not args.image:
+        video_writer.release()
+    cv2.destroyAllWindows()
+
+    print('==> All done!')
+    print('***********************************************************')
+
+
+if __name__ == '__main__':
+    _main()
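
A quick way to smoke-test the utilities added in this diff is to run a single
forward pass without the CLI wrapper. The following is a minimal sketch, not
part of the patch: it assumes the default cfg/weights paths from yoloface.py
are present, and samples/test.jpg is a hypothetical stand-in for any local
test image.

    import cv2
    import numpy as np

    from utils import (CONF_THRESHOLD, NMS_THRESHOLD, IMG_WIDTH, IMG_HEIGHT,
                       get_outputs_names, post_process)

    # Paths taken from the argparse defaults above; IMAGE_PATH is hypothetical
    CFG_PATH = './cfg/yolov3-face.cfg'
    WEIGHTS_PATH = './model-weights/yolov3-wider_16000.weights'
    IMAGE_PATH = 'samples/test.jpg'

    net = cv2.dnn.readNetFromDarknet(CFG_PATH, WEIGHTS_PATH)

    frame = cv2.imread(IMAGE_PATH)
    if frame is None:
        raise FileNotFoundError(IMAGE_PATH)

    # Same preprocessing as _main(): scale pixels to [0, 1] and resize to 416x416
    blob = cv2.dnn.blobFromImage(frame, 1 / 255, (IMG_WIDTH, IMG_HEIGHT),
                                 [0, 0, 0], 1, crop=False)
    net.setInput(blob)
    outs = net.forward(get_outputs_names(net))

    # post_process draws the surviving boxes onto frame and returns them
    # as [left, top, width, height] lists
    faces = post_process(frame, outs, CONF_THRESHOLD, NMS_THRESHOLD)
    print('detected {} face(s)'.format(len(faces)))
    cv2.imwrite('test_yoloface.jpg', frame)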