From 841eaf23eee25c8939218af02974c90ac701982e Mon Sep 17 00:00:00 2001 From: wangzijian <107230768+wangzijian1010@users.noreply.github.com> Date: Thu, 25 Jul 2024 18:16:54 +0800 Subject: [PATCH] [TensorRT] support YOLOv8 with TensorRT backend (#419) * test yolox code,TODO clean code * add tensorrt yolox test code * update var name to trt engine model path * update transform func,to avoid pointers not being freed * update transform func,to avoid pointers not being freed * update infer code to use new transform func * add yolov8 in namespace * implement yolov8 code and test code * add yolov8 compile * update var name in test code * add TRTYoloV8 reference * update code to fix bug --- examples/lite/CMakeLists.txt | 1 + examples/lite/cv/test_lite_yolov8.cpp | 43 ++++++++ lite/models.h | 3 + lite/trt/core/trt_core.h | 1 + lite/trt/cv/trt_yolov8.cpp | 137 ++++++++++++++++++++++++++ lite/trt/cv/trt_yolov8.h | 64 ++++++++++++ 6 files changed, 249 insertions(+) create mode 100644 examples/lite/cv/test_lite_yolov8.cpp create mode 100644 lite/trt/cv/trt_yolov8.cpp create mode 100644 lite/trt/cv/trt_yolov8.h diff --git a/examples/lite/CMakeLists.txt b/examples/lite/CMakeLists.txt index 946462a7..a176ae56 100644 --- a/examples/lite/CMakeLists.txt +++ b/examples/lite/CMakeLists.txt @@ -101,4 +101,5 @@ add_lite_executable(lite_yolov6 cv) add_lite_executable(lite_face_parsing_bisenet cv) add_lite_executable(lite_face_parsing_bisenet_dyn cv) add_lite_executable(lite_yolov8face cv) +add_lite_executable(lite_yolov8 cv) diff --git a/examples/lite/cv/test_lite_yolov8.cpp b/examples/lite/cv/test_lite_yolov8.cpp new file mode 100644 index 00000000..ea88b88a --- /dev/null +++ b/examples/lite/cv/test_lite_yolov8.cpp @@ -0,0 +1,43 @@ +// +// Created by ai-test1 on 24-7-8. +// + +#include "lite/lite.h" + + + +static void test_tensorrt() +{ +#ifdef ENABLE_TENSORRT + std::string engine_path = "../../..//examples/hub/trt/yolov8n_fp32.engine"; + std::string test_img_path = "../../..//examples/lite/resources/test_lite_yolov5_1.jpg"; + std::string save_img_path = "../../..//examples/logs/test_lite_yolov8_trt_1.jpg"; + + lite::trt::cv::detection::YOLOV8 *yolov8 = new lite::trt::cv::detection::YOLOV8(engine_path); + + cv::Mat test_image = cv::imread(test_img_path); + + std::vector detected_boxes; + + yolov8->detect(test_image,detected_boxes,0.5f,0.4f); + + std::cout<<"trt yolov8 detect done!"< &input, std::vector &output, + float iou_threshold, unsigned int topk, unsigned int nms_type) +{ + if (nms_type == NMS::BLEND) lite::utils::blending_nms(input, output, iou_threshold, topk); + else if (nms_type == NMS::OFFSET) lite::utils::offset_nms(input, output, iou_threshold, topk); + else lite::utils::hard_nms(input, output, iou_threshold, topk); +} + +void TRTYoloV8::generate_bboxes(std::vector &bbox_collection, float* output, float score_threshold, + int img_height, int img_width) { + auto pred_dims = output_node_dims[0]; + const unsigned int num_anchors = pred_dims[2]; // 8400 + const unsigned int num_classes = pred_dims[1] - 4; // 80 + + float x_factor = float(img_width) / input_node_dims[3]; + float y_factor = float(img_height) / input_node_dims[2]; + + bbox_collection.clear(); + unsigned int count = 0; + + for (unsigned int i = 0; i < num_anchors; ++i) { + + std::vector class_scores(num_classes); + for (unsigned int j = 0; j < num_classes; ++j) { + class_scores[j] = output[(4 + j) * num_anchors + i]; + } + + auto max_it = std::max_element(class_scores.begin(), class_scores.end()); + float max_cls_conf = *max_it; + unsigned int label = std::distance(class_scores.begin(), max_it); + + float conf = max_cls_conf; + if (conf < score_threshold) continue; + + float cx = output[0 * num_anchors + i]; + float cy = output[1 * num_anchors + i]; + float w = output[2 * num_anchors + i]; + float h = output[3 * num_anchors + i]; + + float x1 = (cx - w / 2.f) * x_factor; + float y1 = (cy - h / 2.f) * y_factor; + + w = w * x_factor; + h = h * y_factor; + + float x2 = x1 + w ; + float y2 = y1 + h; + + types::Boxf box; + box.x1 = std::max(0.f, x1); + box.y1 = std::max(0.f, y1); + box.x2 = std::min(x2, (float) img_width - 1.f); + box.y2 = std::min(y2, (float) img_height - 1.f); + box.score = conf; + box.label = label; + box.label_text = class_names[label]; + box.flag = true; + bbox_collection.push_back(box); + + count += 1; + if (count > max_nms) + break; + } + +#if LITETRT_DEBUG + std::cout << "detected num_anchors: " << num_anchors << "\n"; + std::cout << "generate_bboxes num: " << bbox_collection.size() << "\n"; +#endif +} + +void TRTYoloV8::preprocess(cv::Mat &input_image) { + + // Convert color space from BGR to RGB + cv::cvtColor(input_image, input_image, cv::COLOR_BGR2RGB); + + // Resize image + cv::resize(input_image, input_image, cv::Size(input_node_dims[2], input_node_dims[3]), 0, 0, cv::INTER_LINEAR); + + // Normalize image + input_image.convertTo(input_image, CV_32F, scale_val, mean_val); +} + + +void TRTYoloV8::detect(const cv::Mat &mat, std::vector &detected_boxes, float score_threshold, + float iou_threshold, unsigned int topk, unsigned int nms_type) { + + if (mat.empty()) return; + int img_height = static_cast(mat.rows); + int img_width = static_cast(mat.cols); + + // resize & unscale + cv::Mat mat_rs = mat.clone(); + + preprocess(mat_rs); + + //1. make the input + std::vector input; + trtcv::utils::transform::create_tensor(mat_rs,input,input_node_dims,trtcv::utils::transform::CHW); + + //2. infer + cudaMemcpyAsync(buffers[0], input.data(), input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float), + cudaMemcpyHostToDevice, stream); + + cudaStreamSynchronize(stream); + + bool status = trt_context->enqueueV3(stream); + cudaStreamSynchronize(stream); + if (!status){ + std::cerr << "Failed to infer by TensorRT." << std::endl; + return; + } + + cudaStreamSynchronize(stream); + + // get the first output dim + auto pred_dims = output_node_dims[0]; + + std::vector output(pred_dims[0] * pred_dims[1] * pred_dims[2]); + + cudaMemcpyAsync(output.data(), buffers[1], pred_dims[0] * pred_dims[1] * pred_dims[2] * sizeof(float), + cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + std::vector bbox_collection; + generate_bboxes(bbox_collection,output.data(),score_threshold,img_height,img_width); + nms(bbox_collection, detected_boxes, iou_threshold, topk, nms_type); + +} + diff --git a/lite/trt/cv/trt_yolov8.h b/lite/trt/cv/trt_yolov8.h new file mode 100644 index 00000000..7d80ea1f --- /dev/null +++ b/lite/trt/cv/trt_yolov8.h @@ -0,0 +1,64 @@ +// +// Created by wangzijian on 7/24/24. +// + +#ifndef LITE_AI_TOOLKIT_TRT_YOLOV8_H +#define LITE_AI_TOOLKIT_TRT_YOLOV8_H + + +#include "lite/trt/core/trt_core.h" +#include "lite/utils.h" +#include "lite/trt/core/trt_utils.h" +#include + +namespace trtcv { + class LITE_EXPORTS TRTYoloV8 : public BasicTRTHandler { + public: + explicit TRTYoloV8(const std::string &_trt_model_path, unsigned int _num_threads = 1) : + BasicTRTHandler(_trt_model_path, _num_threads) {}; + + ~TRTYoloV8() override = default; + + + private: + static constexpr const float mean_val = 0.f; + static constexpr const float scale_val = 1.0 / 255.f; + const char *class_names[80] = { + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", + "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", + "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", + "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", + "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", + "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", + "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", + "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", + "scissors", "teddy bear", "hair drier", "toothbrush" + }; + enum NMS { + HARD = 0, BLEND = 1, OFFSET = 2 + }; + static constexpr const unsigned int max_nms = 30000; + + private: + + void preprocess(cv::Mat &input_image); + + void normalized(cv::Mat &input_image); + + void generate_bboxes( + std::vector &bbox_collection, + float *output, + float score_threshold, int img_height, + int img_width); // rescale & exclude + + void nms(std::vector &input, std::vector &output, + float iou_threshold, unsigned int topk, unsigned int nms_type); + + public: + void detect(const cv::Mat &mat, std::vector &detected_boxes, + float score_threshold = 0.25f, float iou_threshold = 0.45f, + unsigned int topk = 100, unsigned int nms_type = NMS::OFFSET); + }; + +#endif //LITE_AI_TOOLKIT_TRT_YOLOV8_H +} \ No newline at end of file