Skip to content

Commit

Permalink
Merge pull request #412 from wangzijian1010/main
Browse files Browse the repository at this point in the history
[Feature] Add YOLOv8Face ORT support
  • Loading branch information
DefTruth authored Jul 10, 2024
2 parents cfb0826 + 0ba5991 commit 05fe5a4
Show file tree
Hide file tree
Showing 8 changed files with 296 additions and 1 deletion.
Binary file added examples/hub/onnx/cv/yoloface_8n.onnx
Binary file not shown.
1 change: 1 addition & 0 deletions examples/lite/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,5 @@ add_lite_executable(lite_mobile_hair_seg cv)
add_lite_executable(lite_yolov6 cv)
add_lite_executable(lite_face_parsing_bisenet cv)
add_lite_executable(lite_face_parsing_bisenet_dyn cv)
add_lite_executable(lite_yolov8face cv)

35 changes: 35 additions & 0 deletions examples/lite/cv/test_lite_yolov8face.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//
// Created by ai-test1 on 24-7-8.
//

#include "lite/lite.h"

// Runs the YOLOv8-face detector (default ONNXRuntime engine) on a bundled
// sample image, draws the detected boxes on it and writes the annotated image
// to the logs directory. All paths are relative to the example's working dir.
static void test_default()
{
  const std::string onnx_path = "../../../examples/hub/onnx/cv/yoloface_8n.onnx";
  const std::string test_img_path = "../../../examples/lite/resources/test_lite_face_detector_2.jpg";
  const std::string save_img_path = "../../../examples/logs/test_lite_yolov8face.jpg";

  // 1. Test Default Engine ONNXRuntime
  // Automatic storage: the detector is released on every exit path, including
  // exceptions thrown by OpenCV or ONNXRuntime (the former new/delete leaked).
  lite::cv::face::detect::YOLOV8Face yolov8_face(onnx_path);

  cv::Mat img_bgr = cv::imread(test_img_path);
  if (img_bgr.empty())
  {
    // imread signals failure with an empty Mat; drawing on / writing an empty
    // image would fail further down with a much less helpful error.
    std::cerr << "failed to read image: " << test_img_path << std::endl;
    return;
  }

  std::vector<lite::types::Boxf> detected_boxes;
  yolov8_face.detect(img_bgr, detected_boxes);
  lite::utils::draw_boxes_inplace(img_bgr, detected_boxes);

  cv::imwrite(save_img_path, img_bgr);

  std::cout<<"face detect done!"<<std::endl;
}

// Program entry point: runs the single YOLOv8-face demo case.
// Both command-line parameters are marked __unused and are never read.
int main(__unused int argc, __unused char *argv[])
{
test_default();
return 0;
}
3 changes: 3 additions & 0 deletions lite/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
#include "lite/ort/cv/yolov6.h"
#include "lite/ort/cv/face_parsing_bisenet.h"
#include "lite/ort/cv/face_parsing_bisenet_dyn.h"
#include "lite/ort/cv/yolofacev8.h"

#endif

Expand Down Expand Up @@ -472,6 +473,7 @@ namespace lite
typedef ortcv::YOLOv6 _ONNXYOLOv6;
typedef ortcv::FaceParsingBiSeNet _ONNXFaceParsingBiSeNet;
typedef ortcv::FaceParsingBiSeNetDyn _ONNXFaceParsingBiSeNetDyn;
typedef ortcv::YoloFaceV8 _ONNXYOLOFaceNet;

// 1. classification
namespace classification
Expand Down Expand Up @@ -528,6 +530,7 @@ namespace lite
typedef _ONNXYOLO5Face YOLO5Face;
typedef _ONNXFaceBoxesV2 FaceBoxesV2;
typedef _ONNXYOLOv5BlazeFace YOLOv5BlazeFace;
typedef _ONNXYOLOFaceNet YOLOV8Face;
}

namespace align
Expand Down
1 change: 1 addition & 0 deletions lite/ort/core/ort_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ namespace ortcv
class LITE_EXPORTS YOLOv6; // [103] * reference: https://github.com/meituan/YOLOv6
class LITE_EXPORTS FaceMesh; // [104] * reference: https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_mesh
class LITE_EXPORTS IrisLandmarks; // [105] * reference: https://github.com/google/mediapipe/tree/master/mediapipe/graphs/iris_tracking
class LITE_EXPORTS YoloFaceV8; // [106] * reference: https://github.com/derronqi/yolov8-face
}

namespace ortnlp
Expand Down
202 changes: 202 additions & 0 deletions lite/ort/cv/yolofacev8.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
//
// Created by ai-test1 on 24-7-8.
//

#include "yolofacev8.h"
#include "lite/ort/core/ort_utils.h"
#include "lite/utils.h"

using ortcv::YoloFaceV8;

/**
 * Intersection-over-union of two axis-aligned boxes given by their
 * (x1, y1) / (x2, y2) corners.
 *
 * @param box1 first box
 * @param box2 second box
 * @return intersection area divided by union area; 0 when the intersection
 *         area is zero.
 */
float YoloFaceV8::get_iou(const lite::types::Boxf box1, const lite::types::Boxf box2) {
  // Corners of the intersection rectangle.
  const float left = std::max(box1.x1, box2.x1);
  const float top = std::max(box1.y1, box2.y1);
  const float right = std::min(box1.x2, box2.x2);
  const float bottom = std::min(box1.y2, box2.y2);

  // Negative extents mean the boxes do not overlap on that axis.
  const float inter_w = std::max(0.f, right - left);
  const float inter_h = std::max(0.f, bottom - top);
  const float inter_area = inter_w * inter_h;
  if (inter_area == 0)
  {
    return 0.0f;
  }

  const float area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);
  const float area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1);
  const float union_area = area1 + area2 - inter_area;
  return inter_area / union_area;
}

/**
 * Greedy non-maximum suppression.
 *
 * Candidates are visited from highest to lowest confidence; every visited
 * candidate that has not been suppressed is kept, and it suppresses all
 * lower-scored candidates whose IoU with it exceeds nms_thresh.
 *
 * The previous implementation called sort() on the confidence values with a
 * comparator that interpreted those values as indices, so the boxes were never
 * actually ordered by score and suppression ran in input order.
 *
 * @param boxes        candidate boxes
 * @param confidences  confidences[i] is the score of boxes[i]
 * @param nms_thresh   IoU above which the lower-scored box is dropped
 * @return indices into boxes of the kept candidates, ordered by descending
 *         confidence (ties keep their input order).
 */
std::vector<int> YoloFaceV8::nms(std::vector<lite::types::Boxf> boxes, std::vector<float> confidences, const float nms_thresh) {
  // Only indices valid in both containers are considered, so a length
  // mismatch cannot cause an out-of-bounds read.
  const size_t num_box = std::min(boxes.size(), confidences.size());

  // Sort indices rather than the scores themselves, so every score stays
  // paired with its box.
  std::vector<int> order(num_box);
  for (size_t i = 0; i < num_box; ++i)
  {
    order[i] = static_cast<int>(i);
  }
  std::stable_sort(order.begin(), order.end(),
                   [&confidences](int lhs, int rhs)
                   { return confidences[lhs] > confidences[rhs]; });

  std::vector<char> suppressed(num_box, 0);
  std::vector<int> keep_inds;
  for (size_t i = 0; i < num_box; ++i)
  {
    if (suppressed[i])
    {
      continue;
    }
    const int best = order[i];
    keep_inds.emplace_back(best);

    for (size_t j = i + 1; j < num_box; ++j)
    {
      if (suppressed[j])
      {
        continue;
      }
      if (this->get_iou(boxes[best], boxes[order[j]]) > nms_thresh)
      {
        suppressed[j] = 1;
      }
    }
  }
  return keep_inds;
}


/**
 * Prepares an image for the network.
 *
 * Steps:
 *   1. If the image is taller or wider than 640, shrink it (keeping the aspect
 *      ratio) so that it fits into 640x640. Smaller images are not enlarged.
 *   2. Record ratio_width / ratio_height (original size divided by the size
 *      after step 1); generate_box() uses them to scale boxes back.
 *   3. Pad with zeros on the bottom and right up to 640x640.
 *   4. Convert to 32-bit float, mapping a pixel value v to (v - 127.5) / 128.
 *
 * @param srcimg input image; must be non-empty (detect() checks this before
 *               calling).
 * @return the padded, normalized float image.
 */
cv::Mat YoloFaceV8::normalize(cv::Mat srcimg) {
  const int height = srcimg.rows;
  const int width = srcimg.cols;
  const int input_height = 640;
  const int input_width = 640;

  // Shallow header copy: no pixel data is duplicated when no resize is needed
  // (the previous clone() always deep-copied the whole image).
  cv::Mat fitted = srcimg;
  if (height > input_height || width > input_width)
  {
    const float shrink = std::min(static_cast<float>(input_height) / height,
                                  static_cast<float>(input_width) / width);
    // Clamp to 1 so extreme aspect ratios cannot truncate a side to 0, which
    // cv::resize rejects.
    const cv::Size new_size(std::max(1, static_cast<int>(width * shrink)),
                            std::max(1, static_cast<int>(height * shrink)));
    cv::Mat resized;
    cv::resize(srcimg, resized, new_size);
    fitted = resized;
  }

  ratio_height = static_cast<float>(height) / fitted.rows;
  ratio_width = static_cast<float>(width) / fitted.cols;

  cv::Mat padded;
  cv::copyMakeBorder(fitted, padded, 0, input_height - fitted.rows,
                     0, input_width - fitted.cols, cv::BORDER_CONSTANT, 0);

  // convertTo applies the same scale/offset to every channel, so splitting
  // into planes and merging them again is unnecessary.
  cv::Mat normalized_image;
  padded.convertTo(normalized_image, CV_32F, 1 / 128.0, -127.5 / 128.0);
  return normalized_image;
}


/**
 * Wraps the preprocessed image in an ONNXRuntime input tensor.
 *
 * Delegates to ortcv::utils::transform::create_tensor with this handler's
 * input dims, memory info and value buffer, requesting the CHW data format.
 *
 * @param mat_rs image returned by normalize()
 * @return the tensor produced by create_tensor
 */
Ort::Value YoloFaceV8::transform(const cv::Mat &mat_rs) {
  namespace xform = ortcv::utils::transform;
  return xform::create_tensor(mat_rs,
                              input_node_dims,
                              memory_info_handler,
                              input_values_handler,
                              xform::CHW);
}


/**
 * Decodes the raw network output into face boxes in original-image
 * coordinates and filters them with NMS.
 *
 * Per the original author's note the output has shape (1, 20, 8400): ignoring
 * the batch dimension, each column holds 20 values - box (cx, cy, w, h),
 * one confidence, then 15 values for 5 keypoints (x, y, confidence). Only the
 * first five rows are read here; keypoints are ignored.
 *
 * @param ort_outputs     session outputs; only ort_outputs[0] is used
 * @param boxes           receives the kept boxes (cleared first)
 * @param conf_threshold  candidates must score strictly above this
 * @param iou_threshold   IoU threshold forwarded to nms()
 */
void YoloFaceV8::generate_box(std::vector<Ort::Value> &ort_outputs,
                              std::vector<lite::types::Boxf> &boxes,
                              float conf_threshold, float iou_threshold)
{
  boxes.clear();

  // Validate before indexing: the original read ort_outputs[0] and shape[2]
  // unconditionally.
  if (ort_outputs.empty())
  {
    return;
  }
  const std::vector<int64_t> out_shape =
      ort_outputs[0].GetTensorTypeAndShapeInfo().GetShape();
  // Rows 0..4 (cx, cy, w, h, score) are read below, so dim 1 must be >= 5.
  if (out_shape.size() < 3 || out_shape[1] < 5 || out_shape[2] <= 0)
  {
    return;
  }

  float *pdata = ort_outputs[0].GetTensorMutableData<float>();
  const int num_box = static_cast<int>(out_shape[2]);

  std::vector<lite::types::BoundingBoxType<float, float>> bounding_box_raw;
  std::vector<float> score_raw;
  for (int i = 0; i < num_box; i++)
  {
    // Data is laid out row-major: row r of candidate i is at r * num_box + i.
    const float score = pdata[4 * num_box + i];
    if (score <= conf_threshold)
    {
      continue;
    }

    const float cx = pdata[i];
    const float cy = pdata[num_box + i];
    const float w = pdata[2 * num_box + i];
    const float h = pdata[3 * num_box + i];

    // (cx, cy, w, h) -> corner form, scaled back to the original picture.
    // TODO: clamp the coordinates to the image bounds.
    lite::types::BoundingBoxType<float, float> bbox;
    bbox.x1 = (cx - 0.5 * w) * ratio_width;
    bbox.y1 = (cy - 0.5 * h) * ratio_height;
    bbox.x2 = (cx + 0.5 * w) * ratio_width;
    bbox.y2 = (cy + 0.5 * h) * ratio_height;
    bbox.score = score;
    bbox.flag = true;
    // Remaining members keep their default values.
    bounding_box_raw.emplace_back(bbox);
    score_raw.emplace_back(score);
  }

  const std::vector<int> keep_inds = this->nms(bounding_box_raw, score_raw, iou_threshold);
  boxes.reserve(keep_inds.size());
  for (const int ind : keep_inds)
  {
    boxes.push_back(bounding_box_raw[ind]);
  }
#if LITEORT_DEBUG
  std::cout << "detected num_anchors: " << num_box << "\n";
  std::cout << "generate_bboxes num: " << boxes.size() << "\n";
#endif
}


/**
 * Debug helper: dumps a float tensor to a text file, one element per line, in
 * the order the elements are stored in the tensor buffer.
 *
 * Non-float tensors and I/O failures are reported on std::cerr; the function
 * never throws on its own.
 *
 * @param tensor    tensor to dump; must hold float elements
 * @param filename  path of the text file to create or overwrite
 */
void save_tensor_to_file(const Ort::Value& tensor, const std::string& filename) {
  // Query element type and count (the shape itself is not needed here).
  auto type_and_shape_info = tensor.GetTensorTypeAndShapeInfo();
  const size_t element_count = type_and_shape_info.GetElementCount();

  if (type_and_shape_info.GetElementType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
    std::cerr << "Unsupported tensor data type. Only float tensors are supported." << std::endl;
    return;
  }

  const float* pdata = tensor.GetTensorData<float>();

  std::ofstream file(filename);
  if (!file.is_open()) {
    std::cerr << "Could not open file for writing: " << filename << std::endl;
    return;
  }

  for (size_t i = 0; i < element_count; ++i) {
    file << pdata[i] << "\n";
  }
  file.close();

  // Surface write errors (e.g. disk full) instead of silently producing a
  // truncated dump.
  if (!file) {
    std::cerr << "Failed while writing tensor to file: " << filename << std::endl;
  }
}


/**
 * Detects faces in an image.
 *
 * Pipeline: normalize() -> transform() -> session Run -> generate_box().
 * For an empty input the function returns immediately and leaves boxes
 * untouched.
 *
 * @param mat             input image
 * @param boxes           receives the detected boxes
 * @param conf_threshold  confidence threshold forwarded to generate_box()
 * @param iou_threshold   IoU threshold forwarded to generate_box()
 */
void YoloFaceV8::detect(const cv::Mat &mat,std::vector<lite::types::Boxf> &boxes,
                        float conf_threshold, float iou_threshold) {
  if (mat.empty())
  {
    return;
  }

  // 1. Preprocess and wrap the image as the input tensor.
  cv::Mat preprocessed = this->normalize(mat);
  Ort::Value input_tensor = this->transform(preprocessed);

  // 2. Run inference with default run options.
  Ort::RunOptions run_options;
  auto outputs = ort_session->Run(run_options,
                                  input_node_names.data(),
                                  &input_tensor,
                                  1,
                                  output_node_names.data(),
                                  num_outputs);

  // 3. Decode the raw output into boxes.
  this->generate_box(outputs, boxes, conf_threshold, iou_threshold);
}
52 changes: 52 additions & 0 deletions lite/ort/cv/yolofacev8.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//
// Created by ai-test1 on 24-7-8.
//

#ifndef LITE_AI_TOOLKIT_YOLOFACEV8_H
#define LITE_AI_TOOLKIT_YOLOFACEV8_H
#include "lite/ort/core/ort_core.h"
#include "algorithm"
#include <vector>

namespace ortcv {

/**
 * YOLOv8-face detector running on ONNXRuntime.
 *
 * Usage: construct with the path to the ONNX model, then call detect().
 *
 * Note: detect() stores per-image scale factors in ratio_width/ratio_height,
 * so concurrent detect() calls on the same instance would race on them.
 */
class LITE_EXPORTS YoloFaceV8 : public BasicOrtHandler
{
public:
  /**
   * @param _onnx_path    path to the ONNX model file
   * @param _num_threads  thread count forwarded to BasicOrtHandler
   */
  explicit YoloFaceV8(const std::string &_onnx_path, unsigned int _num_threads = 1) :
      BasicOrtHandler(_onnx_path, _num_threads)
  {}

  ~YoloFaceV8() override = default;

private:
  // Normalization constants; same values as the literals used in normalize().
  float mean = -127.5 / 128.0;
  float scale = 1 / 128.0;
  // Original size / resized size of the last image passed to normalize().
  // Initialized to 1 so they never hold indeterminate values.
  float ratio_width = 1.0f;
  float ratio_height = 1.0f;

private:
  // Required override: wraps the preprocessed image in an input tensor.
  Ort::Value transform(const cv::Mat &mat_rs) override;

  // Intersection-over-union of two boxes.
  float get_iou(const lite::types::Boxf box1, const lite::types::Boxf box2);

  // Non-maximum suppression; returns indices into boxes of the kept entries.
  std::vector<int> nms(std::vector<lite::types::Boxf> boxes, std::vector<float> confidences, const float nms_thresh);

  // Resizes/pads the image to the network input size and normalizes pixels;
  // updates ratio_width / ratio_height.
  cv::Mat normalize(cv::Mat srcImg);

  // Decodes the network output into boxes in original-image coordinates.
  void generate_box(std::vector<Ort::Value> &ort_outputs, std::vector<lite::types::Boxf> &boxes,
                    float conf_threshold = 0.25f, float iou_threshold = 0.45f);

public:
  /**
   * Detects faces in mat and writes them to boxes.
   *
   * @param mat             input image; an empty image is ignored
   * @param boxes           output boxes
   * @param conf_threshold  minimum confidence for a candidate
   * @param iou_threshold   IoU threshold used by non-maximum suppression
   */
  void detect(const cv::Mat &mat, std::vector<lite::types::Boxf> &boxes,
              float conf_threshold = 0.25f, float iou_threshold = 0.45f);
};
}


#endif //LITE_AI_TOOLKIT_YOLOFACEV8_H
3 changes: 2 additions & 1 deletion lite/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ namespace lite {
{ types::__assert_type<value_type, score_type>(); }
}; // End BoundingBox.

// specific alias.

// specific alias.
template class LITE_EXPORTS BoundingBoxType<int, float>;
template class LITE_EXPORTS BoundingBoxType<float, float>;
template class LITE_EXPORTS BoundingBoxType<double, double>;
Expand Down

0 comments on commit 05fe5a4

Please sign in to comment.