[ORT][TRT] support FaceFusion (#445)
1 parent 4252d27 · commit c70e1ff · 55 changed files with 3,313 additions and 34 deletions.
@@ -0,0 +1,112 @@
//
// Created by wangzijian on 11/1/24.
//
#include "lite/lite.h"
#include "lite/trt/cv/trt_face_68landmarks_mt.h"

static void test_default()
{
#ifdef ENABLE_ONNXRUNTIME
    std::string onnx_path = "/home/lite.ai.toolkit/examples/hub/onnx/cv/2dfan4.onnx";
    std::string test_img_path = "/home/lite.ai.toolkit/examples/lite/resources/test_lite_facefusion_pipeline_source.jpg";

    // 1. Test Default Engine ONNXRuntime
    lite::cv::faceid::Face_68Landmarks *face68Landmarks = new lite::cv::faceid::Face_68Landmarks(onnx_path);

    lite::types::BoundingBoxType<float, float> bbox;
    bbox.x1 = 487;
    bbox.y1 = 236;
    bbox.x2 = 784;
    bbox.y2 = 624;

    cv::Mat img_bgr = cv::imread(test_img_path);
    std::vector<cv::Point2f> face_landmark_5of68;
    face68Landmarks->detect(img_bgr, bbox, face_landmark_5of68);

    std::cout << "face id detect done!" << std::endl;

    delete face68Landmarks;
#endif
}

static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT
    std::string engine_path = "/home/lite.ai.toolkit/examples/hub/trt/2dfan4_fp16.engine";
    std::string test_img_path = "/home/lite.ai.toolkit/1.jpg";

    // 1. Test TensorRT Engine
    lite::trt::cv::faceid::FaceFusionFace68Landmarks *face68Landmarks = new lite::trt::cv::faceid::FaceFusionFace68Landmarks(engine_path);
    lite::types::BoundingBoxType<float, float> bbox;
    bbox.x1 = 487;
    bbox.y1 = 236;
    bbox.x2 = 784;
    bbox.y2 = 624;

    cv::Mat img_bgr = cv::imread(test_img_path);
    std::vector<cv::Point2f> face_landmark_5of68;
    face68Landmarks->detect(img_bgr, bbox, face_landmark_5of68);

    std::cout << "face id detect done!" << std::endl;

    delete face68Landmarks;
#endif
}

static void test_tensorrt_mt()
{
#ifdef ENABLE_TENSORRT
    std::string engine_path = "/home/lite.ai.toolkit/examples/hub/trt/2dfan4_fp16.engine";
    std::string test_img_path = "/home/lite.ai.toolkit/1.jpg";

    // 1. Test TensorRT Engine (multi-threaded, 4 worker threads)
    trt_face_68landmarks_mt *face68Landmarks = new trt_face_68landmarks_mt(engine_path, 4);

    lite::types::BoundingBoxType<float, float> bbox;
    bbox.x1 = 487;
    bbox.y1 = 236;
    bbox.x2 = 784;
    bbox.y2 = 624;

    // Submit several detection tasks asynchronously, one per image.
    cv::Mat img_bgr = cv::imread(test_img_path);
    std::vector<cv::Point2f> face_landmark_5of68;
    face68Landmarks->detect_async(img_bgr, bbox, face_landmark_5of68);

    cv::Mat img_bgr2 = cv::imread(test_img_path);
    std::vector<cv::Point2f> face_landmark_5of682;
    face68Landmarks->detect_async(img_bgr2, bbox, face_landmark_5of682);

    cv::Mat img_bgr3 = cv::imread(test_img_path);
    std::vector<cv::Point2f> face_landmark_5of683;
    face68Landmarks->detect_async(img_bgr3, bbox, face_landmark_5of683);

    cv::Mat img_bgr4 = cv::imread(test_img_path);
    std::vector<cv::Point2f> face_landmark_5of684;
    face68Landmarks->detect_async(img_bgr4, bbox, face_landmark_5of684);

    // Block until all queued tasks have finished, then stop the worker threads.
    face68Landmarks->wait_for_completion();
    face68Landmarks->shutdown();

    std::cout << "face id detect done!" << std::endl;

    delete face68Landmarks;
#endif
}

int main(__unused int argc, __unused char *argv[])
{
    // test_tensorrt();
    test_tensorrt_mt();
    // test_default();
    return 0;
}
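Not part of the commit: once detect or detect_async has filled face_landmark_5of68, the result can be checked visually with plain OpenCV calls. A minimal sketch, assuming the image and landmark vector from the test above; the helper name and output path are illustrative only.

#include <opencv2/opencv.hpp>
#include <string>
#include <vector>

// Draw the detected 68-point landmarks onto the image and save it for inspection.
static void draw_landmarks(cv::Mat &img, const std::vector<cv::Point2f> &landmarks,
                           const std::string &out_path)
{
    for (const auto &p : landmarks)
        cv::circle(img, p, 2, cv::Scalar(0, 255, 0), -1);  // small green dot per landmark
    cv::imwrite(out_path, img);
}

// Example usage (hypothetical output path):
// draw_landmarks(img_bgr, face_landmark_5of68, "/home/lite.ai.toolkit/landmarks_vis.jpg");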
@@ -0,0 +1,39 @@
//
// Created by wangzijian on 11/5/24.
//
#include "lite/lite.h"

static void test_default()
{
#ifdef ENABLE_ONNXRUNTIME
    std::string onnx_path = "../../../examples/hub/onnx/cv/arcface_w600k_r50.onnx";
    std::string test_img_path = "../../../examples/lite/resources/test_lite_facefusion_pipeline_source.jpg";

    // 1. Test Default Engine ONNXRuntime
    lite::cv::faceid::Face_Recognizer *face_recognizer = new lite::cv::faceid::Face_Recognizer(onnx_path);

    // Five-point face landmarks used to align the face before embedding.
    std::vector<cv::Point2f> face_landmark_5 = {
        cv::Point2f(568.2485f, 398.9512f),
        cv::Point2f(701.7346f, 399.64795f),
        cv::Point2f(634.2213f, 482.92694f),
        cv::Point2f(583.5656f, 543.10187f),
        cv::Point2f(684.52405f, 543.125f)
    };
    cv::Mat img_bgr = cv::imread(test_img_path);

    std::vector<float> source_image_embedding;
    face_recognizer->detect(img_bgr, face_landmark_5, source_image_embedding);

    std::cout << "face id detect done!" << std::endl;

    delete face_recognizer;
#endif
}

int main(__unused int argc, __unused char *argv[])
{
    test_default();
    return 0;
}
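Not part of the commit: the recognizer fills source_image_embedding with a feature vector, and a common way to compare two faces is cosine similarity between their embeddings. A minimal sketch; the helper name and the second embedding are illustrative only.

#include <cmath>
#include <vector>

// Cosine similarity between two embeddings of equal length (1.0 = identical direction).
static float cosine_similarity(const std::vector<float> &a, const std::vector<float> &b)
{
    float dot = 0.f, na = 0.f, nb = 0.f;
    for (size_t i = 0; i < a.size() && i < b.size(); ++i)
    {
        dot += a[i] * b[i];
        na  += a[i] * a[i];
        nb  += b[i] * b[i];
    }
    return dot / (std::sqrt(na) * std::sqrt(nb) + 1e-12f);
}

// Example usage, assuming a second embedding extracted the same way:
// float sim = cosine_similarity(source_image_embedding, target_image_embedding);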
@@ -0,0 +1,119 @@
//
// Created by wangzijian on 11/7/24.
//
#include <chrono>
#include <memory>
#include "lite/lite.h"
#include "lite/trt/cv/trt_face_restoration_mt.h"

static void test_default()
{
#ifdef ENABLE_ONNXRUNTIME
    std::string onnx_path = "/home/lite.ai.toolkit/examples/hub/onnx/cv/gfpgan_1.4.onnx";
    std::string test_img_path = "/home/lite.ai.toolkit/trt_result.jpg";
    std::string save_img_path = "/home/lite.ai.toolkit/trt_result_final.jpg";

    // 1. Test Default Engine ONNXRuntime
    lite::cv::face::restoration::GFPGAN *face_restoration = new lite::cv::face::restoration::GFPGAN(onnx_path);

    std::vector<cv::Point2f> face_landmark_5 = {
        cv::Point2f(569.092041f, 398.845886f),
        cv::Point2f(701.891724f, 399.156677f),
        cv::Point2f(634.767212f, 482.927216f),
        cv::Point2f(584.270996f, 543.294617f),
        cv::Point2f(684.877991f, 543.067078f)
    };
    cv::Mat img_bgr = cv::imread(test_img_path);

    face_restoration->detect(img_bgr, face_landmark_5, save_img_path);

    std::cout << "face restoration detect done!" << std::endl;

    delete face_restoration;
#endif
}

static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT
    std::string engine_path = "/home/lite.ai.toolkit/examples/hub/trt/gfpgan_1.4_fp32.engine";
    std::string test_img_path = "/home/lite.ai.toolkit/trt_result.jpg";
    std::string save_img_path = "/home/lite.ai.toolkit/trt_facerestoration_mt_test111.jpg";

    // 1. Test TensorRT Engine (multi-threaded)
    // Single-threaded alternative:
    // lite::trt::cv::face::restoration::TRTGFPGAN *face_restoration_trt = new lite::trt::cv::face::restoration::TRTGFPGAN(engine_path);
    const int num_threads = 4;  // use 4 worker threads
    auto face_restoration_trt = std::make_unique<trt_face_restoration_mt>(engine_path, num_threads);

    // 2. Prepare test data; here we process 4 images as an example.
    std::vector<std::string> test_img_paths = {
        "/home/lite.ai.toolkit/trt_result.jpg",
        "/home/lite.ai.toolkit/trt_result_2.jpg",
        "/home/lite.ai.toolkit/trt_result_3.jpg",
        "/home/lite.ai.toolkit/trt_result_4.jpg"
    };

    std::vector<std::string> save_img_paths = {
        "/home/lite.ai.toolkit/trt_facerestoration_mt_thread1.jpg",
        "/home/lite.ai.toolkit/trt_facerestoration_mt_thread2.jpg",
        "/home/lite.ai.toolkit/trt_facerestoration_mt_thread3.jpg",
        "/home/lite.ai.toolkit/trt_facerestoration_mt_thread4.jpg"
    };

    std::vector<cv::Point2f> face_landmark_5 = {
        cv::Point2f(569.092041f, 398.845886f),
        cv::Point2f(701.891724f, 399.156677f),
        cv::Point2f(634.767212f, 482.927216f),
        cv::Point2f(584.270996f, 543.294617f),
        cv::Point2f(684.877991f, 543.067078f)
    };

    auto start_time = std::chrono::high_resolution_clock::now();

    // 3. Submit all tasks asynchronously.
    for (size_t i = 0; i < test_img_paths.size(); ++i) {
        cv::Mat img_bgr = cv::imread(test_img_paths[i]);
        if (img_bgr.empty()) {
            std::cerr << "Failed to read image: " << test_img_paths[i] << std::endl;
            continue;
        }
        face_restoration_trt->detect_async(img_bgr, face_landmark_5, save_img_paths[i]);
        std::cout << "Submitted task " << i + 1 << " for processing" << std::endl;
    }

    // 4. Wait for all tasks to complete.
    std::cout << "Waiting for all tasks to complete..." << std::endl;
    face_restoration_trt->wait_for_completion();

    // 5. Compute and report the total elapsed time.
    auto end_time = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);

    std::cout << "All tasks completed!" << std::endl;
    std::cout << "Total processing time: " << duration.count() << "ms" << std::endl;
    std::cout << "Average time per image: " << duration.count() / test_img_paths.size() << "ms" << std::endl;
#endif
}

int main(__unused int argc, __unused char *argv[])
{
    // test_default();
    test_tensorrt();
    return 0;
}
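Not part of the commit: the tests hard-code absolute engine and image paths, so a missing file surfaces only as an opaque engine-load or imread failure. A small pre-flight check using just the standard library makes the failure explicit; the helper below is illustrative only.

#include <filesystem>
#include <iostream>
#include <string>
#include <vector>

// Return true only if every path exists; print the ones that are missing.
static bool check_paths_exist(const std::vector<std::string> &paths)
{
    bool ok = true;
    for (const auto &p : paths)
    {
        if (!std::filesystem::exists(p))
        {
            std::cerr << "Missing file: " << p << std::endl;
            ok = false;
        }
    }
    return ok;
}

// Example usage before constructing the TensorRT runner:
// if (!check_paths_exist({engine_path, "/home/lite.ai.toolkit/trt_result.jpg"})) return;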