From 674674efdabd10732db0a8a7a2d94ee7f979bf78 Mon Sep 17 00:00:00 2001 From: peijuema <78143172+mpj1234@users.noreply.github.com> Date: Fri, 12 Jul 2024 11:43:09 +0800 Subject: [PATCH] Trt10 (#1552) * The v5-cls model supports TensorRT10 * The v5-cls model supports TensorRT10 Python API * add YOLOv5-cls readme * pre-commit and modify trtx download branch * pre-commit --- yolov5/yolov5_trt10/CMakeLists.txt | 39 ++ yolov5/yolov5_trt10/README.md | 72 ++++ yolov5/yolov5_trt10/plugin/yololayer.cu | 297 +++++++++++++++ yolov5/yolov5_trt10/plugin/yololayer.h | 108 ++++++ yolov5/yolov5_trt10/src/calibrator.cpp | 99 +++++ yolov5/yolov5_trt10/src/calibrator.h | 39 ++ yolov5/yolov5_trt10/src/config.h | 54 +++ yolov5/yolov5_trt10/src/cuda_utils.h | 17 + yolov5/yolov5_trt10/src/logging.h | 456 ++++++++++++++++++++++++ yolov5/yolov5_trt10/src/macros.h | 29 ++ yolov5/yolov5_trt10/src/model.cpp | 331 +++++++++++++++++ yolov5/yolov5_trt10/src/model.h | 8 + yolov5/yolov5_trt10/src/postprocess.cpp | 195 ++++++++++ yolov5/yolov5_trt10/src/postprocess.h | 18 + yolov5/yolov5_trt10/src/preprocess.cu | 146 ++++++++ yolov5/yolov5_trt10/src/preprocess.h | 12 + yolov5/yolov5_trt10/src/types.h | 16 + yolov5/yolov5_trt10/src/utils.h | 68 ++++ yolov5/yolov5_trt10/yolov5_cls.cpp | 307 ++++++++++++++++ yolov5/yolov5_trt10/yolov5_cls_trt.py | 261 ++++++++++++++ 20 files changed, 2572 insertions(+) create mode 100644 yolov5/yolov5_trt10/CMakeLists.txt create mode 100644 yolov5/yolov5_trt10/README.md create mode 100644 yolov5/yolov5_trt10/plugin/yololayer.cu create mode 100644 yolov5/yolov5_trt10/plugin/yololayer.h create mode 100644 yolov5/yolov5_trt10/src/calibrator.cpp create mode 100644 yolov5/yolov5_trt10/src/calibrator.h create mode 100644 yolov5/yolov5_trt10/src/config.h create mode 100644 yolov5/yolov5_trt10/src/cuda_utils.h create mode 100644 yolov5/yolov5_trt10/src/logging.h create mode 100644 yolov5/yolov5_trt10/src/macros.h create mode 100644 yolov5/yolov5_trt10/src/model.cpp create mode 100644 yolov5/yolov5_trt10/src/model.h create mode 100644 yolov5/yolov5_trt10/src/postprocess.cpp create mode 100644 yolov5/yolov5_trt10/src/postprocess.h create mode 100644 yolov5/yolov5_trt10/src/preprocess.cu create mode 100644 yolov5/yolov5_trt10/src/preprocess.h create mode 100644 yolov5/yolov5_trt10/src/types.h create mode 100644 yolov5/yolov5_trt10/src/utils.h create mode 100644 yolov5/yolov5_trt10/yolov5_cls.cpp create mode 100644 yolov5/yolov5_trt10/yolov5_cls_trt.py diff --git a/yolov5/yolov5_trt10/CMakeLists.txt b/yolov5/yolov5_trt10/CMakeLists.txt new file mode 100644 index 00000000..c3505ede --- /dev/null +++ b/yolov5/yolov5_trt10/CMakeLists.txt @@ -0,0 +1,39 @@ +cmake_minimum_required(VERSION 3.10) + +project(yolov5) + +add_definitions(-std=c++11) +add_definitions(-DAPI_EXPORTS) +option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_BUILD_TYPE Debug) + +# TODO(Call for PR): make cmake compatible with Windows +set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc) +enable_language(CUDA) + +# include and link dirs of cuda and tensorrt, you need adapt them if yours are different +# cuda +include_directories(/usr/local/cuda/include) +link_directories(/usr/local/cuda/lib64) + +# tensorrt +# TODO(Call for PR): make TRT path configurable from command line +include_directories(/workspace/shared/TensorRT-10.2.0.19/include/) +link_directories(/workspace/shared/TensorRT-10.2.0.19/lib/) + +include_directories(${PROJECT_SOURCE_DIR}/src/) +include_directories(${PROJECT_SOURCE_DIR}/plugin/) +file(GLOB_RECURSE 
SRCS ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/*.cu)
+file(GLOB_RECURSE PLUGIN_SRCS ${PROJECT_SOURCE_DIR}/plugin/*.cu)
+
+add_library(myplugins SHARED ${PLUGIN_SRCS})
+target_link_libraries(myplugins nvinfer cudart)
+
+find_package(OpenCV)
+include_directories(${OpenCV_INCLUDE_DIRS})
+
+add_executable(yolov5_cls yolov5_cls.cpp ${SRCS})
+target_link_libraries(yolov5_cls nvinfer)
+target_link_libraries(yolov5_cls cudart)
+target_link_libraries(yolov5_cls ${OpenCV_LIBS})
diff --git a/yolov5/yolov5_trt10/README.md b/yolov5/yolov5_trt10/README.md
new file mode 100644
index 00000000..7a5cc5e0
--- /dev/null
+++ b/yolov5/yolov5_trt10/README.md
@@ -0,0 +1,72 @@
+## Introduction
+
+The YOLOv5 models support TensorRT-10.
+
+## Environment
+
+CUDA: 11.8
+CUDNN: 8.9.1.23
+TensorRT: TensorRT-10.2.0.19
+
+## Support
+
+* [x] YOLOv5-cls supports FP32/FP16/INT8 and the Python/C++ API
+
+## Config
+
+* Choose the YOLOv5 sub-model n/s/m/l/x from the command line arguments.
+* For other configs, please check [src/config.h](src/config.h)
+
+## Build and Run
+
+1. Generate .wts from PyTorch with .pt, or download .wts from the model zoo
+
+```shell
+git clone -b v7.0 https://github.com/ultralytics/yolov5.git
+git clone -b trt10 https://github.com/wang-xinyu/tensorrtx.git
+cd yolov5/
+wget https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n-cls.pt
+cp [PATH-TO-TENSORRTX]/yolov5/gen_wts.py .
+python gen_wts.py -w yolov5n-cls.pt -o yolov5n-cls.wts
+# A file 'yolov5n-cls.wts' will be generated.
+```
+
+2. Build tensorrtx/yolov5/yolov5_trt10 and run
+
+#### Classification
+
+```shell
+cd [PATH-TO-TENSORRTX]/yolov5/yolov5_trt10
+# Update kClsNumClass in src/config.h if your model is trained on a custom dataset
+mkdir build
+cd build
+cp [PATH-TO-ultralytics-yolov5]/yolov5n-cls.wts .
+cmake ..
+make
+
+# Download the ImageNet labels
+wget https://raw.githubusercontent.com/joannzhang00/ImageNet-dataset-classes-labels/main/imagenet_classes.txt
+
+# Build and serialize the TensorRT engine
+./yolov5_cls -s yolov5n-cls.wts yolov5n-cls.engine [n/s/m/l/x]
+
+# Run inference
+./yolov5_cls -d yolov5n-cls.engine ../../images
+# The results are displayed in the console
+```
+
+3. Optionally, load and run the TensorRT model in Python
+
+```shell
+# Install python-tensorrt, pycuda, etc.
+# Make sure yolov5n-cls.engine has been built first
+python yolov5_cls_trt.py
+```
+
+## INT8 Quantization
+
+1. Prepare calibration images; you can randomly select 1000+ images from your training set. For COCO, you can also download my calibration images `coco_calib` from [GoogleDrive](https://drive.google.com/drive/folders/1s7jE9DtOngZMzJC1uL307J2MiaGwdRSI?usp=sharing) or [BaiduPan](https://pan.baidu.com/s/1GOm_-JobpyLMAqZWCDUhKg) pwd: a9wh
+2. Unzip it in yolov5_trt10/build
+3. Set the macro `USE_INT8` in src/config.h and make again
+4. Serialize the model and test
+
+## More Information
+
+See the readme in the [home page](https://github.com/wang-xinyu/tensorrtx).
diff --git a/yolov5/yolov5_trt10/plugin/yololayer.cu b/yolov5/yolov5_trt10/plugin/yololayer.cu
new file mode 100644
index 00000000..afd31133
--- /dev/null
+++ b/yolov5/yolov5_trt10/plugin/yololayer.cu
@@ -0,0 +1,297 @@
+#include "cuda_utils.h"
+#include "yololayer.h"
+
+#include <cassert>
+#include <cstring>
+#include <vector>
+
+namespace Tn {
+template <typename T>
+void write(char*& buffer, const T& val) {
+    *reinterpret_cast<T*>(buffer) = val;
+    buffer += sizeof(T);
+}
+
+template <typename T>
+void read(const char*& buffer, T& val) {
+    val = *reinterpret_cast<const T*>(buffer);
+    buffer += sizeof(T);
+}
+}  // namespace Tn
+
+namespace nvinfer1 {
+YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation,
+                                 const std::vector<YoloKernel>& vYoloKernel) {
+    mClassCount = classCount;
+    mYoloV5NetWidth = netWidth;
+    mYoloV5NetHeight = netHeight;
+    mMaxOutObject = maxOut;
+    is_segmentation_ = is_segmentation;
+    mYoloKernel = vYoloKernel;
+    mKernelCount = vYoloKernel.size();
+
+    CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
+    size_t AnchorLen = sizeof(float) * kNumAnchor * 2;
+    for (int ii = 0; ii < mKernelCount; ii++) {
+        CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen));
+        const auto& yolo = mYoloKernel[ii];
+        CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
+    }
+}
+
+YoloLayerPlugin::~YoloLayerPlugin() {
+    for (int ii = 0; ii < mKernelCount; ii++) {
+        CUDA_CHECK(cudaFree(mAnchor[ii]));
+    }
+    CUDA_CHECK(cudaFreeHost(mAnchor));
+}
+
+// create the plugin at runtime from a byte stream
+YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length) {
+    using namespace Tn;
+    const char *d = reinterpret_cast<const char*>(data), *a = d;
+    read(d, mClassCount);
+    read(d, mThreadCount);
+    read(d, mKernelCount);
+    read(d, mYoloV5NetWidth);
+    read(d, mYoloV5NetHeight);
+    read(d, mMaxOutObject);
+    read(d, is_segmentation_);
+    mYoloKernel.resize(mKernelCount);
+    auto kernelSize = mKernelCount * sizeof(YoloKernel);
+    memcpy(mYoloKernel.data(), d, kernelSize);
+    d += kernelSize;
+    CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*)));
+    size_t AnchorLen = sizeof(float) * kNumAnchor * 2;
+    for (int ii = 0; ii < mKernelCount; ii++) {
+        CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen));
+        const auto& yolo = mYoloKernel[ii];
+        CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice));
+    }
+    assert(d == a + length);
+}
+
+void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
+    using namespace Tn;
+    char *d = static_cast<char*>(buffer), *a = d;
+    write(d, mClassCount);
+    write(d, mThreadCount);
+    write(d, mKernelCount);
+    write(d, mYoloV5NetWidth);
+    write(d, mYoloV5NetHeight);
+    write(d, mMaxOutObject);
+    write(d, is_segmentation_);
+    auto kernelSize = mKernelCount * sizeof(YoloKernel);
+    memcpy(d, mYoloKernel.data(), kernelSize);
+    d += kernelSize;
+
+    assert(d == a + getSerializationSize());
+}
+
+size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT {
+    size_t s = sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount);
+    s += sizeof(YoloKernel) * mYoloKernel.size();
+    s += sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight);
+    s += sizeof(mMaxOutObject) + sizeof(is_segmentation_);
+    return s;
+}
+
+int YoloLayerPlugin::initialize() TRT_NOEXCEPT {
+    return 0;
+}
+
+Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT {
+    // output the result to channel
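+    // Output layout per batch item: one float holding the number of detections written,
+    // followed by up to mMaxOutObject Detection structs; CalDetection below increments
+    // the count with atomicAdd as it appends detections.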
+    int totalsize = mMaxOutObject * sizeof(Detection) / sizeof(float);
+    return Dims3(totalsize + 1, 1, 1);
+}
+
+// Set plugin namespace
+void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT {
+    mPluginNamespace = pluginNamespace;
+}
+
+const char* YoloLayerPlugin::getPluginNamespace() const TRT_NOEXCEPT {
+    return mPluginNamespace;
+}
+
+// Return the DataType of the plugin output at the requested index
+DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes,
+                                            int nbInputs) const TRT_NOEXCEPT {
+    return DataType::kFLOAT;
+}
+
+// Return true if output tensor is broadcast across a batch.
+bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted,
+                                                   int nbInputs) const TRT_NOEXCEPT {
+    return false;
+}
+
+// Return true if plugin can use input that is broadcast across batch without replication.
+bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT {
+    return false;
+}
+
+void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out,
+                                      int nbOutput) TRT_NOEXCEPT {}
+
+// Attach the plugin object to an execution context and grant the plugin the access to some context resource.
+void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
+                                      IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}
+
+// Detach the plugin object from its execution context.
+void YoloLayerPlugin::detachFromContext() TRT_NOEXCEPT {}
+
+const char* YoloLayerPlugin::getPluginType() const TRT_NOEXCEPT {
+    return "YoloLayer_TRT";
+}
+
+const char* YoloLayerPlugin::getPluginVersion() const TRT_NOEXCEPT {
+    return "1";
+}
+
+void YoloLayerPlugin::destroy() TRT_NOEXCEPT {
+    delete this;
+}
+
+// Clone the plugin
+IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT {
+    YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject,
+                                             is_segmentation_, mYoloKernel);
+    p->setPluginNamespace(mPluginNamespace);
+    return p;
+}
+
+__device__ float Logist(float data) {
+    return 1.0f / (1.0f + expf(-data));
+};
+
+__global__ void CalDetection(const float* input, float* output, int noElements, const int netwidth,
+                             const int netheight, int maxoutobject, int yoloWidth, int yoloHeight,
+                             const float anchors[kNumAnchor * 2], int classes, int outputElem, bool is_segmentation) {
+
+    int idx = threadIdx.x + blockDim.x * blockIdx.x;
+    if (idx >= noElements)
+        return;
+
+    int total_grid = yoloWidth * yoloHeight;
+    int bnIdx = idx / total_grid;
+    idx = idx - total_grid * bnIdx;
+    int info_len_i = 5 + classes;
+    if (is_segmentation)
+        info_len_i += 32;
+    const float* curInput = input + bnIdx * (info_len_i * total_grid * kNumAnchor);
+
+    for (int k = 0; k < kNumAnchor; ++k) {
+        float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);
+        if (box_prob < kIgnoreThresh)
+            continue;
+        int class_id = 0;
+        float max_cls_prob = 0.0;
+        for (int i = 5; i < 5 + classes; ++i) {
+            float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);
+            if (p > max_cls_prob) {
+                max_cls_prob = p;
+                class_id = i - 5;
+            }
+        }
+        float* res_count = output + bnIdx * outputElem;
+        int count = (int)atomicAdd(res_count, 1);
+        if (count >= maxoutobject)
+            return;
+        char* data = (char*)res_count + sizeof(float) + count * sizeof(Detection);
+        Detection* det = (Detection*)(data);
+
+        int row = idx / yoloWidth;
+        int col = idx % yoloWidth;
+
+        det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 0 * total_grid])) *
+                       netwidth / yoloWidth;
+        det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 1 * total_grid])) *
+                       netheight / yoloHeight;
+
+        det->bbox[2] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 2 * total_grid]);
+        det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2 * k];
+        det->bbox[3] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 3 * total_grid]);
+        det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2 * k + 1];
+        det->conf = box_prob * max_cls_prob;
+        det->class_id = class_id;
+
+        for (int i = 0; is_segmentation && i < 32; i++) {
+            det->mask[i] = curInput[idx + k * info_len_i * total_grid + (i + 5 + classes) * total_grid];
+        }
+    }
+}
+
+void YoloLayerPlugin::forwardGpu(const float* const* inputs, float* output, cudaStream_t stream, int batchSize) {
+    int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);
+    for (int idx = 0; idx < batchSize; ++idx) {
+        CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream));
+    }
+    int numElem = 0;
+    for (unsigned int i = 0; i < mYoloKernel.size(); ++i) {
+        const auto& yolo = mYoloKernel[i];
+        numElem = yolo.width * yolo.height * batchSize;
+        if (numElem < mThreadCount)
+            mThreadCount = numElem;
+
+        CalDetection<<<(numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream>>>(
+                inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height,
+                (float*)mAnchor[i], mClassCount, outputElem, is_segmentation_);
+    }
+}
+
+int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void* TRT_CONST_ENQUEUE* outputs,
+                             void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
+    forwardGpu((const float* const*)inputs, (float*)outputs[0], stream, batchSize);
+    return 0;
+}
+
+PluginFieldCollection YoloPluginCreator::mFC{};
+std::vector<PluginField> YoloPluginCreator::mPluginAttributes;
+
+YoloPluginCreator::YoloPluginCreator() {
+    mPluginAttributes.clear();
+    mFC.nbFields = mPluginAttributes.size();
+    mFC.fields = mPluginAttributes.data();
+}
+
+const char* YoloPluginCreator::getPluginName() const TRT_NOEXCEPT {
+    return "YoloLayer_TRT";
+}
+
+const char* YoloPluginCreator::getPluginVersion() const TRT_NOEXCEPT {
+    return "1";
+}
+
+const PluginFieldCollection* YoloPluginCreator::getFieldNames() TRT_NOEXCEPT {
+    return &mFC;
+}
+
+IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT {
+    assert(fc->nbFields == 2);
+    assert(strcmp(fc->fields[0].name, "netinfo") == 0);
+    assert(strcmp(fc->fields[1].name, "kernels") == 0);
+    int* p_netinfo = (int*)(fc->fields[0].data);
+    int class_count = p_netinfo[0];
+    int input_w = p_netinfo[1];
+    int input_h = p_netinfo[2];
+    int max_output_object_count = p_netinfo[3];
+    bool is_segmentation = (bool)p_netinfo[4];
+    std::vector<YoloKernel> kernels(fc->fields[1].length);
+    memcpy(&kernels[0], fc->fields[1].data, kernels.size() * sizeof(YoloKernel));
+    YoloLayerPlugin* obj =
+            new YoloLayerPlugin(class_count, input_w, input_h, max_output_object_count, is_segmentation, kernels);
+    obj->setPluginNamespace(mNamespace.c_str());
+    return obj;
+}
+
+IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData,
+                                                     size_t serialLength) TRT_NOEXCEPT {
+    // This object will be deleted when the network is destroyed, which will
+    // call YoloLayerPlugin::destroy()
+    YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength);
+    obj->setPluginNamespace(mNamespace.c_str());
+    return obj;
+}
+}  // namespace nvinfer1
diff --git a/yolov5/yolov5_trt10/plugin/yololayer.h b/yolov5/yolov5_trt10/plugin/yololayer.h
new file mode 100644
index 00000000..94d88f3b
--- /dev/null
+++ b/yolov5/yolov5_trt10/plugin/yololayer.h
@@ -0,0 +1,108 @@
+#pragma once
+
+#include "macros.h"
+#include "types.h"
+
+#include <string>
+#include <vector>
+
+namespace nvinfer1 {
+class API YoloLayerPlugin : public IPluginV2IOExt {
+   public:
+    YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation,
+                    const std::vector<YoloKernel>& vYoloKernel);
+    YoloLayerPlugin(const void* data, size_t length);
+    ~YoloLayerPlugin();
+
+    int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
+
+    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override;
+
+    int initialize() TRT_NOEXCEPT override;
+
+    virtual void terminate() TRT_NOEXCEPT override{};
+
+    virtual size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0; }
+
+    virtual int enqueue(int batchSize, const void* const* inputs, void* TRT_CONST_ENQUEUE* outputs, void* workspace,
+                        cudaStream_t stream) TRT_NOEXCEPT override;
+
+    virtual size_t getSerializationSize() const TRT_NOEXCEPT override;
+
+    virtual void serialize(void* buffer) const TRT_NOEXCEPT override;
+
+    bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs,
+                                   int nbOutputs) const TRT_NOEXCEPT override {
+        return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
+    }
+
+    const char* getPluginType() const TRT_NOEXCEPT override;
+
+    const char* getPluginVersion() const TRT_NOEXCEPT override;
+
+    void destroy() TRT_NOEXCEPT override;
+
+    IPluginV2IOExt* clone() const TRT_NOEXCEPT override;
+
+    void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override;
+
+    const char* getPluginNamespace() const TRT_NOEXCEPT override;
+
+    DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes,
+                               int nbInputs) const TRT_NOEXCEPT override;
+
+    bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted,
+                                      int nbInputs) const TRT_NOEXCEPT override;
+
+    bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override;
+
+    void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
+                         IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
+
+    void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out,
+                         int nbOutput) TRT_NOEXCEPT override;
+
+    void detachFromContext() TRT_NOEXCEPT override;
+
+   private:
+    void forwardGpu(const float* const* inputs, float* output, cudaStream_t stream, int batchSize = 1);
+    int mThreadCount = 256;
+    const char* mPluginNamespace;
+    int mKernelCount;
+    int mClassCount;
+    int mYoloV5NetWidth;
+    int mYoloV5NetHeight;
+    int mMaxOutObject;
+    bool is_segmentation_;
+    std::vector<YoloKernel> mYoloKernel;
+    void** mAnchor;
+};
+
+class API YoloPluginCreator : public IPluginCreator {
+   public:
+    YoloPluginCreator();
+
+    ~YoloPluginCreator() override = default;
+
+    const char* getPluginName() const TRT_NOEXCEPT override;
+
+    const char* getPluginVersion() const TRT_NOEXCEPT override;
+
+    const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
+
+    IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override;
+
+    IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData,
+                                      size_t serialLength) TRT_NOEXCEPT override;
+
+    void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override { mNamespace = libNamespace; }
+
+    const char* getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); }
+
+   private:
+    std::string mNamespace;
+    static PluginFieldCollection mFC;
+    static std::vector<PluginField> mPluginAttributes;
+};
+REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
+};  // namespace nvinfer1
diff --git a/yolov5/yolov5_trt10/src/calibrator.cpp b/yolov5/yolov5_trt10/src/calibrator.cpp
new file mode 100644
index 00000000..4047064d
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/calibrator.cpp
@@ -0,0 +1,99 @@
+#include "calibrator.h"
+#include "cuda_utils.h"
+#include "utils.h"
+
+#include <fstream>
+#include <iostream>
+#include <iterator>
+#include <vector>
+#include <opencv2/dnn/dnn.hpp>
+
+cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h) {
+    int w, h, x, y;
+    float r_w = input_w / (img.cols * 1.0);
+    float r_h = input_h / (img.rows * 1.0);
+    if (r_h > r_w) {
+        w = input_w;
+        h = r_w * img.rows;
+        x = 0;
+        y = (input_h - h) / 2;
+    } else {
+        w = r_h * img.cols;
+        h = input_h;
+        x = (input_w - w) / 2;
+        y = 0;
+    }
+    cv::Mat re(h, w, CV_8UC3);
+    cv::resize(img, re, re.size(), 0, 0, cv::INTER_LINEAR);
+    cv::Mat out(input_h, input_w, CV_8UC3, cv::Scalar(128, 128, 128));
+    re.copyTo(out(cv::Rect(x, y, re.cols, re.rows)));
+    return out;
+}
+
+Int8EntropyCalibrator2::Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir,
+                                               const char* calib_table_name, const char* input_blob_name,
+                                               bool read_cache)
+    : batchsize_(batchsize),
+      input_w_(input_w),
+      input_h_(input_h),
+      img_idx_(0),
+      img_dir_(img_dir),
+      calib_table_name_(calib_table_name),
+      input_blob_name_(input_blob_name),
+      read_cache_(read_cache) {
+    input_count_ = 3 * input_w * input_h * batchsize;
+    CUDA_CHECK(cudaMalloc(&device_input_, input_count_ * sizeof(float)));
+    read_files_in_dir(img_dir, img_files_);
+}
+
+Int8EntropyCalibrator2::~Int8EntropyCalibrator2() {
+    CUDA_CHECK(cudaFree(device_input_));
+}
+
+int Int8EntropyCalibrator2::getBatchSize() const TRT_NOEXCEPT {
+    return batchsize_;
+}
+
+bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT {
+    if (img_idx_ + batchsize_ > (int)img_files_.size()) {
+        return false;
+    }
+
+    std::vector<cv::Mat> input_imgs_;
+    for (int i = img_idx_; i < img_idx_ + batchsize_; i++) {
+        std::cout << img_files_[i] << " " << i << std::endl;
+        cv::Mat temp = cv::imread(img_dir_ + img_files_[i]);
+        if (temp.empty()) {
+            std::cerr << "Fatal error: image cannot open!" << std::endl;
+            return false;
+        }
+        cv::Mat pr_img = preprocess_img(temp, input_w_, input_h_);
+        input_imgs_.push_back(pr_img);
+    }
+    img_idx_ += batchsize_;
+    cv::Mat blob = cv::dnn::blobFromImages(input_imgs_, 1.0 / 255.0, cv::Size(input_w_, input_h_), cv::Scalar(0, 0, 0),
+                                           true, false);
+
+    CUDA_CHECK(cudaMemcpy(device_input_, blob.ptr<float>(0), input_count_ * sizeof(float), cudaMemcpyHostToDevice));
+    assert(!strcmp(names[0], input_blob_name_));
+    bindings[0] = device_input_;
+    return true;
+}
+
+const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) TRT_NOEXCEPT {
+    std::cout << "reading calib cache: " << calib_table_name_ << std::endl;
+    calib_cache_.clear();
+    std::ifstream input(calib_table_name_, std::ios::binary);
+    input >> std::noskipws;
+    if (read_cache_ && input.good()) {
+        std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
+                  std::back_inserter(calib_cache_));
+    }
+    length = calib_cache_.size();
+    return length ? calib_cache_.data() : nullptr;
+}
+
+void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT {
+    std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl;
+    std::ofstream output(calib_table_name_, std::ios::binary);
+    output.write(reinterpret_cast<const char*>(cache), length);
+}
diff --git a/yolov5/yolov5_trt10/src/calibrator.h b/yolov5/yolov5_trt10/src/calibrator.h
new file mode 100644
index 00000000..c5e54b9a
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/calibrator.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <NvInfer.h>
+#include <string>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include "macros.h"
+
+cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h);
+
+//! \class Int8EntropyCalibrator2
+//!
+//! \brief Implements Entropy calibrator 2.
+//!  CalibrationAlgoType is kENTROPY_CALIBRATION_2.
+//!
+class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
+   public:
+    Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name,
+                           const char* input_blob_name, bool read_cache = true);
+
+    virtual ~Int8EntropyCalibrator2();
+    int getBatchSize() const TRT_NOEXCEPT override;
+    bool getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT override;
+    const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override;
+    void writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT override;
+
+   private:
+    int batchsize_;
+    int input_w_;
+    int input_h_;
+    int img_idx_;
+    std::string img_dir_;
+    std::vector<std::string> img_files_;
+    size_t input_count_;
+    std::string calib_table_name_;
+    const char* input_blob_name_;
+    bool read_cache_;
+    void* device_input_;
+    std::vector<char> calib_cache_;
+};
diff --git a/yolov5/yolov5_trt10/src/config.h b/yolov5/yolov5_trt10/src/config.h
new file mode 100644
index 00000000..7c6a097f
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/config.h
@@ -0,0 +1,54 @@
+#pragma once
+
+/* --------------------------------------------------------
+ * These configs are related to the tensorrt model, if these are changed,
+ * please re-compile and re-serialize the tensorrt model.
+ * --------------------------------------------------------*/
+
+// For INT8, you need to prepare the calibration dataset, please refer to
+// https://github.com/wang-xinyu/tensorrtx/tree/master/yolov5#int8-quantization
+#define USE_FP16  // set USE_INT8 or USE_FP16 or USE_FP32
+
+// These are used to define input/output tensor names,
+// you can set them to whatever you want.
+const static char* kInputTensorName = "data";
+const static char* kOutputTensorName = "prob";
+
+// Detection and segmentation models' number of classes
+constexpr static int kNumClass = 80;
+
+// Classification model's number of classes
+constexpr static int kClsNumClass = 1000;
+
+constexpr static int kBatchSize = 1;
+
+// Yolo's input width and height must be divisible by 32
+constexpr static int kInputH = 640;
+constexpr static int kInputW = 640;
+
+// Classification model's input shape
+constexpr static int kClsInputH = 224;
+constexpr static int kClsInputW = 224;
+
+// Maximum number of output bounding boxes from yololayer plugin.
+// That is the maximum number of output bounding boxes before NMS.
+constexpr static int kMaxNumOutputBbox = 1000;
+
+constexpr static int kNumAnchor = 3;
+
+// The bboxes whose confidence is lower than kIgnoreThresh will be ignored in yololayer plugin.
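+// Note: CalDetection applies a sigmoid to the raw objectness logit before this
+// comparison, so kIgnoreThresh is compared against a probability in [0, 1].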
+constexpr static float kIgnoreThresh = 0.1f;
+
+/* --------------------------------------------------------
+ * These configs are NOT related to the tensorrt model, if these are changed,
+ * please re-compile, but no need to re-serialize the tensorrt model.
+ * --------------------------------------------------------*/
+
+// NMS overlapping thresh and final detection confidence thresh
+const static float kNmsThresh = 0.45f;
+const static float kConfThresh = 0.5f;
+
+const static int kGpuId = 0;
+
+// If your image size is larger than 4096 * 3112, please increase this value
+const static int kMaxInputImageSize = 4096 * 3112;
diff --git a/yolov5/yolov5_trt10/src/cuda_utils.h b/yolov5/yolov5_trt10/src/cuda_utils.h
new file mode 100644
index 00000000..35d50d84
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/cuda_utils.h
@@ -0,0 +1,17 @@
+#ifndef TRTX_CUDA_UTILS_H_
+#define TRTX_CUDA_UTILS_H_
+
+#include <cuda_runtime_api.h>
+#include <cassert>
+#include <iostream>
+
+#ifndef CUDA_CHECK
+#define CUDA_CHECK(callstr)                                                                    \
+    {                                                                                          \
+        cudaError_t error_code = callstr;                                                      \
+        if (error_code != cudaSuccess) {                                                       \
+            std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
+            assert(0);                                                                         \
+        }                                                                                      \
+    }
+#endif  // CUDA_CHECK
+
+#endif  // TRTX_CUDA_UTILS_H_
diff --git a/yolov5/yolov5_trt10/src/logging.h b/yolov5/yolov5_trt10/src/logging.h
new file mode 100644
index 00000000..3a25d975
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/logging.h
@@ -0,0 +1,456 @@
+/*
+ * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TENSORRT_LOGGING_H
+#define TENSORRT_LOGGING_H
+
+#include <cassert>
+#include <ctime>
+#include <iomanip>
+#include <iostream>
+#include <ostream>
+#include <sstream>
+#include <string>
+#include "NvInferRuntimeCommon.h"
+#include "macros.h"
+
+using Severity = nvinfer1::ILogger::Severity;
+
+class LogStreamConsumerBuffer : public std::stringbuf {
+   public:
+    LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
+        : mOutput(stream), mPrefix(prefix), mShouldLog(shouldLog) {}
+
+    LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) : mOutput(other.mOutput) {}
+
+    ~LogStreamConsumerBuffer() {
+        // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
+        // std::streambuf::pptr() gives a pointer to the current position of the output sequence
+        // if the pointer to the beginning is not equal to the pointer to the current position,
+        // call putOutput() to log the output to the stream
+        if (pbase() != pptr()) {
+            putOutput();
+        }
+    }
+
+    // synchronizes the stream buffer and returns 0 on success
+    // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
+    // resetting the buffer and flushing the stream
+    virtual int sync() {
+        putOutput();
+        return 0;
+    }
+
+    void putOutput() {
+        if (mShouldLog) {
+            // prepend timestamp
+            std::time_t timestamp = std::time(nullptr);
+            tm* tm_local = std::localtime(&timestamp);
+            std::cout << "[";
+            std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
+            std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
+            // std::stringbuf::str() gets the string contents of the buffer
+            // insert the buffer contents pre-appended by the appropriate prefix into the stream
+            mOutput << mPrefix << str();
+            // set the buffer to empty
+            str("");
+            // flush the stream
+            mOutput.flush();
+        }
+    }
+
+    void setShouldLog(bool shouldLog) { mShouldLog = shouldLog; }
+
+   private:
+    std::ostream& mOutput;
+    std::string mPrefix;
+    bool mShouldLog;
+};
+
+//!
+//! \class LogStreamConsumerBase
+//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
+//!
+class LogStreamConsumerBase {
+   public:
+    LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
+        : mBuffer(stream, prefix, shouldLog) {}
+
+   protected:
+    LogStreamConsumerBuffer mBuffer;
+};
+
+//!
+//! \class LogStreamConsumer
+//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
+//!  Order of base classes is LogStreamConsumerBase and then std::ostream.
+//!  This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
+//!  in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
+//!  This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
+//!  Please do not change the order of the parent classes.
+//!
+class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream {
+   public:
+    //! \brief Creates a LogStreamConsumer which logs messages with level severity.
+    //!  Reportable severity determines if the messages are severe enough to be logged.
+ LogStreamConsumer(Severity reportableSeverity, Severity severity) + : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(severity <= reportableSeverity), + mSeverity(severity) {} + + LogStreamConsumer(LogStreamConsumer&& other) + : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog), + std::ostream(&mBuffer) // links the stream buffer with the stream + , + mShouldLog(other.mShouldLog), + mSeverity(other.mSeverity) {} + + void setReportableSeverity(Severity reportableSeverity) { + mShouldLog = mSeverity <= reportableSeverity; + mBuffer.setShouldLog(mShouldLog); + } + + private: + static std::ostream& severityOstream(Severity severity) { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + static std::string severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; + } + } + + bool mShouldLog; + Severity mSeverity; +}; + +//! \class Logger +//! +//! \brief Class which manages logging of TensorRT tools and samples +//! +//! \details This class provides a common interface for TensorRT tools and samples to log information to the console, +//! and supports logging two types of messages: +//! +//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) +//! - Test pass/fail messages +//! +//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is +//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. +//! +//! In the future, this class could be extended to support dumping test results to a file in some standard format +//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). +//! +//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger +//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT +//! library and messages coming from the sample. +//! +//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the +//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger +//! object. + +class Logger : public nvinfer1::ILogger { + public: + Logger(Severity severity = Severity::kWARNING) : mReportableSeverity(severity) {} + + //! + //! \enum TestResult + //! \brief Represents the state of a given test + //! + enum class TestResult { + kRUNNING, //!< The test is running + kPASSED, //!< The test passed + kFAILED, //!< The test failed + kWAIVED //!< The test was waived + }; + + //! + //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger + //! \return The nvinfer1::ILogger associated with this Logger + //! + //! TODO Once all samples are updated to use this method to register the logger with TensorRT, + //! we can eliminate the inheritance of Logger from ILogger + //! + nvinfer1::ILogger& getTRTLogger() { return *this; } + + //! + //! 
\brief Implementation of the nvinfer1::ILogger::log() virtual method + //! + //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the + //! inheritance from nvinfer1::ILogger + //! + void log(Severity severity, const char* msg) TRT_NOEXCEPT override { + LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; + } + + //! + //! \brief Method for controlling the verbosity of logging output + //! + //! \param severity The logger will only emit messages that have severity of this level or higher. + //! + void setReportableSeverity(Severity severity) { mReportableSeverity = severity; } + + //! + //! \brief Opaque handle that holds logging information for a particular test + //! + //! This object is an opaque handle to information used by the Logger to print test results. + //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used + //! with Logger::reportTest{Start,End}(). + //! + class TestAtom { + public: + TestAtom(TestAtom&&) = default; + + private: + friend class Logger; + + TestAtom(bool started, const std::string& name, const std::string& cmdline) + : mStarted(started), mName(name), mCmdline(cmdline) {} + + bool mStarted; + std::string mName; + std::string mCmdline; + }; + + //! + //! \brief Define a test for logging + //! + //! \param[in] name The name of the test. This should be a string starting with + //! "TensorRT" and containing dot-separated strings containing + //! the characters [A-Za-z0-9_]. + //! For example, "TensorRT.sample_googlenet" + //! \param[in] cmdline The command line used to reproduce the test + // + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + //! + static TestAtom defineTest(const std::string& name, const std::string& cmdline) { + return TestAtom(false, name, cmdline); + } + + //! + //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments + //! as input + //! + //! \param[in] name The name of the test + //! \param[in] argc The number of command-line arguments + //! \param[in] argv The array of command-line arguments (given as C strings) + //! + //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). + static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) { + auto cmdline = genCmdlineString(argc, argv); + return defineTest(name, cmdline); + } + + //! + //! \brief Report that a test has started. + //! + //! \pre reportTestStart() has not been called yet for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has started + //! + static void reportTestStart(TestAtom& testAtom) { + reportTestResult(testAtom, TestResult::kRUNNING); + assert(!testAtom.mStarted); + testAtom.mStarted = true; + } + + //! + //! \brief Report that a test has ended. + //! + //! \pre reportTestStart() has been called for the given testAtom + //! + //! \param[in] testAtom The handle to the test that has ended + //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, + //! TestResult::kFAILED, TestResult::kWAIVED + //! 
+ static void reportTestEnd(const TestAtom& testAtom, TestResult result) { + assert(result != TestResult::kRUNNING); + assert(testAtom.mStarted); + reportTestResult(testAtom, result); + } + + static int reportPass(const TestAtom& testAtom) { + reportTestEnd(testAtom, TestResult::kPASSED); + return EXIT_SUCCESS; + } + + static int reportFail(const TestAtom& testAtom) { + reportTestEnd(testAtom, TestResult::kFAILED); + return EXIT_FAILURE; + } + + static int reportWaive(const TestAtom& testAtom) { + reportTestEnd(testAtom, TestResult::kWAIVED); + return EXIT_SUCCESS; + } + + static int reportTest(const TestAtom& testAtom, bool pass) { + return pass ? reportPass(testAtom) : reportFail(testAtom); + } + + Severity getReportableSeverity() const { return mReportableSeverity; } + + private: + //! + //! \brief returns an appropriate string for prefixing a log message with the given severity + //! + static const char* severityPrefix(Severity severity) { + switch (severity) { + case Severity::kINTERNAL_ERROR: + return "[F] "; + case Severity::kERROR: + return "[E] "; + case Severity::kWARNING: + return "[W] "; + case Severity::kINFO: + return "[I] "; + case Severity::kVERBOSE: + return "[V] "; + default: + assert(0); + return ""; + } + } + + //! + //! \brief returns an appropriate string for prefixing a test result message with the given result + //! + static const char* testResultString(TestResult result) { + switch (result) { + case TestResult::kRUNNING: + return "RUNNING"; + case TestResult::kPASSED: + return "PASSED"; + case TestResult::kFAILED: + return "FAILED"; + case TestResult::kWAIVED: + return "WAIVED"; + default: + assert(0); + return ""; + } + } + + //! + //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity + //! + static std::ostream& severityOstream(Severity severity) { + return severity >= Severity::kINFO ? std::cout : std::cerr; + } + + //! + //! \brief method that implements logging test results + //! + static void reportTestResult(const TestAtom& testAtom, TestResult result) { + severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " + << testAtom.mCmdline << std::endl; + } + + //! + //! \brief generate a command line string from the given (argc, argv) values + //! + static std::string genCmdlineString(int argc, char const* const* argv) { + std::stringstream ss; + for (int i = 0; i < argc; i++) { + if (i > 0) + ss << " "; + ss << argv[i]; + } + return ss.str(); + } + + Severity mReportableSeverity; +}; + +namespace { + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE +//! +//! Example usage: +//! +//! LOG_VERBOSE(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO +//! +//! Example usage: +//! +//! LOG_INFO(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_INFO(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING +//! +//! Example usage: +//! +//! LOG_WARN(logger) << "hello world" << std::endl; +//! 
+inline LogStreamConsumer LOG_WARN(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR +//! +//! Example usage: +//! +//! LOG_ERROR(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_ERROR(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); +} + +//! +//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR +// ("fatal" severity) +//! +//! Example usage: +//! +//! LOG_FATAL(logger) << "hello world" << std::endl; +//! +inline LogStreamConsumer LOG_FATAL(const Logger& logger) { + return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); +} + +} // anonymous namespace + +#endif // TENSORRT_LOGGING_H diff --git a/yolov5/yolov5_trt10/src/macros.h b/yolov5/yolov5_trt10/src/macros.h new file mode 100644 index 00000000..17339a24 --- /dev/null +++ b/yolov5/yolov5_trt10/src/macros.h @@ -0,0 +1,29 @@ +#ifndef __MACROS_H +#define __MACROS_H + +#include + +#ifdef API_EXPORTS +#if defined(_MSC_VER) +#define API __declspec(dllexport) +#else +#define API __attribute__((visibility("default"))) +#endif +#else + +#if defined(_MSC_VER) +#define API __declspec(dllimport) +#else +#define API +#endif +#endif // API_EXPORTS + +#if NV_TENSORRT_MAJOR >= 8 +#define TRT_NOEXCEPT noexcept +#define TRT_CONST_ENQUEUE const +#else +#define TRT_NOEXCEPT +#define TRT_CONST_ENQUEUE +#endif + +#endif // __MACROS_H diff --git a/yolov5/yolov5_trt10/src/model.cpp b/yolov5/yolov5_trt10/src/model.cpp new file mode 100644 index 00000000..606dec59 --- /dev/null +++ b/yolov5/yolov5_trt10/src/model.cpp @@ -0,0 +1,331 @@ +#include "model.h" +#include "calibrator.h" +#include "config.h" + +#include +#include +#include +#include +#include +#include + +using namespace nvinfer1; + +// TensorRT weight files have a simple space delimited format: +// [type] [size] +static std::map loadWeights(const std::string file) { + std::cout << "Loading weights: " << file << std::endl; + std::map weightMap; + + // Open weights file + std::ifstream input(file); + assert(input.is_open() && "Unable to load weight file. 
please check if the .wts file path is right!!!!!!"); + + // Read number of weight blobs + int32_t count; + input >> count; + assert(count > 0 && "Invalid weight map file."); + + while (count--) { + Weights wt{DataType::kFLOAT, nullptr, 0}; + uint32_t size; + + // Read name and type of blob + std::string name; + input >> name >> std::dec >> size; + wt.type = DataType::kFLOAT; + + // Load blob + uint32_t* val = reinterpret_cast(malloc(sizeof(val) * size)); + for (uint32_t x = 0, y = size; x < y; ++x) { + input >> std::hex >> val[x]; + } + wt.values = val; + + wt.count = size; + weightMap[name] = wt; + } + + return weightMap; +} + +static int get_width(int x, float gw, int divisor = 8) { + return int(ceil((x * gw) / divisor)) * divisor; +} + +static int get_depth(int x, float gd) { + if (x == 1) + return 1; + int r = round(x * gd); + if (x * gd - int(x * gd) == 0.5 && (int(x * gd) % 2) == 0) { + --r; + } + return std::max(r, 1); +} + +static IScaleLayer* addBatchNorm2d(INetworkDefinition* network, std::map& weightMap, + ITensor& input, std::string lname, float eps) { + float* gamma = (float*)weightMap[lname + ".weight"].values; + float* beta = (float*)weightMap[lname + ".bias"].values; + float* mean = (float*)weightMap[lname + ".running_mean"].values; + float* var = (float*)weightMap[lname + ".running_var"].values; + int len = weightMap[lname + ".running_var"].count; + + float* scval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + scval[i] = gamma[i] / sqrt(var[i] + eps); + } + Weights scale{DataType::kFLOAT, scval, len}; + + float* shval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + shval[i] = beta[i] - mean[i] * gamma[i] / sqrt(var[i] + eps); + } + Weights shift{DataType::kFLOAT, shval, len}; + + float* pval = reinterpret_cast(malloc(sizeof(float) * len)); + for (int i = 0; i < len; i++) { + pval[i] = 1.0; + } + Weights power{DataType::kFLOAT, pval, len}; + + weightMap[lname + ".scale"] = scale; + weightMap[lname + ".shift"] = shift; + weightMap[lname + ".power"] = power; + IScaleLayer* scale_1 = network->addScale(input, ScaleMode::kCHANNEL, shift, scale, power); + assert(scale_1); + return scale_1; +} + +static ILayer* convBlock(INetworkDefinition* network, std::map& weightMap, ITensor& input, + int outch, int ksize, int s, int g, std::string lname) { + Weights emptywts{DataType::kFLOAT, nullptr, 0}; + int p = ksize / 3; + IConvolutionLayer* conv1 = + network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap[lname + ".conv.weight"], emptywts); + assert(conv1); + conv1->setStrideNd(DimsHW{s, s}); + conv1->setPaddingNd(DimsHW{p, p}); + conv1->setNbGroups(g); + conv1->setName((lname + ".conv").c_str()); + IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + ".bn", 1e-3); + + // silu = x * sigmoid + auto sig = network->addActivation(*bn1->getOutput(0), ActivationType::kSIGMOID); + assert(sig); + auto ew = network->addElementWise(*bn1->getOutput(0), *sig->getOutput(0), ElementWiseOperation::kPROD); + assert(ew); + return ew; +} + +static ILayer* focus(INetworkDefinition* network, std::map& weightMap, ITensor& input, int inch, + int outch, int ksize, std::string lname) { + ISliceLayer* s1 = network->addSlice(input, Dims3{0, 0, 0}, Dims3{inch, kInputH / 2, kInputW / 2}, Dims3{1, 2, 2}); + ISliceLayer* s2 = network->addSlice(input, Dims3{0, 1, 0}, Dims3{inch, kInputH / 2, kInputW / 2}, Dims3{1, 2, 2}); + ISliceLayer* s3 = network->addSlice(input, Dims3{0, 0, 1}, Dims3{inch, 
kInputH / 2, kInputW / 2}, Dims3{1, 2, 2}); + ISliceLayer* s4 = network->addSlice(input, Dims3{0, 1, 1}, Dims3{inch, kInputH / 2, kInputW / 2}, Dims3{1, 2, 2}); + ITensor* inputTensors[] = {s1->getOutput(0), s2->getOutput(0), s3->getOutput(0), s4->getOutput(0)}; + auto cat = network->addConcatenation(inputTensors, 4); + auto conv = convBlock(network, weightMap, *cat->getOutput(0), outch, ksize, 1, 1, lname + ".conv"); + return conv; +} + +static ILayer* bottleneck(INetworkDefinition* network, std::map& weightMap, ITensor& input, + int c1, int c2, bool shortcut, int g, float e, std::string lname) { + auto cv1 = convBlock(network, weightMap, input, (int)((float)c2 * e), 1, 1, 1, lname + ".cv1"); + auto cv2 = convBlock(network, weightMap, *cv1->getOutput(0), c2, 3, 1, g, lname + ".cv2"); + if (shortcut && c1 == c2) { + auto ew = network->addElementWise(input, *cv2->getOutput(0), ElementWiseOperation::kSUM); + return ew; + } + return cv2; +} + +static ILayer* bottleneckCSP(INetworkDefinition* network, std::map& weightMap, ITensor& input, + int c1, int c2, int n, bool shortcut, int g, float e, std::string lname) { + Weights emptywts{DataType::kFLOAT, nullptr, 0}; + int c_ = (int)((float)c2 * e); + auto cv1 = convBlock(network, weightMap, input, c_, 1, 1, 1, lname + ".cv1"); + auto cv2 = network->addConvolutionNd(input, c_, DimsHW{1, 1}, weightMap[lname + ".cv2.weight"], emptywts); + ITensor* y1 = cv1->getOutput(0); + for (int i = 0; i < n; i++) { + auto b = bottleneck(network, weightMap, *y1, c_, c_, shortcut, g, 1.0, lname + ".m." + std::to_string(i)); + y1 = b->getOutput(0); + } + auto cv3 = network->addConvolutionNd(*y1, c_, DimsHW{1, 1}, weightMap[lname + ".cv3.weight"], emptywts); + + ITensor* inputTensors[] = {cv3->getOutput(0), cv2->getOutput(0)}; + auto cat = network->addConcatenation(inputTensors, 2); + + IScaleLayer* bn = addBatchNorm2d(network, weightMap, *cat->getOutput(0), lname + ".bn", 1e-4); + auto lr = network->addActivation(*bn->getOutput(0), ActivationType::kLEAKY_RELU); + lr->setAlpha(0.1); + + auto cv4 = convBlock(network, weightMap, *lr->getOutput(0), c2, 1, 1, 1, lname + ".cv4"); + return cv4; +} + +static ILayer* C3(INetworkDefinition* network, std::map& weightMap, ITensor& input, int c1, + int c2, int n, bool shortcut, int g, float e, std::string lname) { + int c_ = (int)((float)c2 * e); + auto cv1 = convBlock(network, weightMap, input, c_, 1, 1, 1, lname + ".cv1"); + auto cv2 = convBlock(network, weightMap, input, c_, 1, 1, 1, lname + ".cv2"); + ITensor* y1 = cv1->getOutput(0); + for (int i = 0; i < n; i++) { + auto b = bottleneck(network, weightMap, *y1, c_, c_, shortcut, g, 1.0, lname + ".m." 
+ std::to_string(i)); + y1 = b->getOutput(0); + } + + ITensor* inputTensors[] = {y1, cv2->getOutput(0)}; + auto cat = network->addConcatenation(inputTensors, 2); + + auto cv3 = convBlock(network, weightMap, *cat->getOutput(0), c2, 1, 1, 1, lname + ".cv3"); + return cv3; +} + +static ILayer* SPP(INetworkDefinition* network, std::map& weightMap, ITensor& input, int c1, + int c2, int k1, int k2, int k3, std::string lname) { + int c_ = c1 / 2; + auto cv1 = convBlock(network, weightMap, input, c_, 1, 1, 1, lname + ".cv1"); + + auto pool1 = network->addPoolingNd(*cv1->getOutput(0), PoolingType::kMAX, DimsHW{k1, k1}); + pool1->setPaddingNd(DimsHW{k1 / 2, k1 / 2}); + pool1->setStrideNd(DimsHW{1, 1}); + auto pool2 = network->addPoolingNd(*cv1->getOutput(0), PoolingType::kMAX, DimsHW{k2, k2}); + pool2->setPaddingNd(DimsHW{k2 / 2, k2 / 2}); + pool2->setStrideNd(DimsHW{1, 1}); + auto pool3 = network->addPoolingNd(*cv1->getOutput(0), PoolingType::kMAX, DimsHW{k3, k3}); + pool3->setPaddingNd(DimsHW{k3 / 2, k3 / 2}); + pool3->setStrideNd(DimsHW{1, 1}); + + ITensor* inputTensors[] = {cv1->getOutput(0), pool1->getOutput(0), pool2->getOutput(0), pool3->getOutput(0)}; + auto cat = network->addConcatenation(inputTensors, 4); + + auto cv2 = convBlock(network, weightMap, *cat->getOutput(0), c2, 1, 1, 1, lname + ".cv2"); + return cv2; +} + +static ILayer* SPPF(INetworkDefinition* network, std::map& weightMap, ITensor& input, int c1, + int c2, int k, std::string lname) { + int c_ = c1 / 2; + auto cv1 = convBlock(network, weightMap, input, c_, 1, 1, 1, lname + ".cv1"); + + auto pool1 = network->addPoolingNd(*cv1->getOutput(0), PoolingType::kMAX, DimsHW{k, k}); + pool1->setPaddingNd(DimsHW{k / 2, k / 2}); + pool1->setStrideNd(DimsHW{1, 1}); + auto pool2 = network->addPoolingNd(*pool1->getOutput(0), PoolingType::kMAX, DimsHW{k, k}); + pool2->setPaddingNd(DimsHW{k / 2, k / 2}); + pool2->setStrideNd(DimsHW{1, 1}); + auto pool3 = network->addPoolingNd(*pool2->getOutput(0), PoolingType::kMAX, DimsHW{k, k}); + pool3->setPaddingNd(DimsHW{k / 2, k / 2}); + pool3->setStrideNd(DimsHW{1, 1}); + ITensor* inputTensors[] = {cv1->getOutput(0), pool1->getOutput(0), pool2->getOutput(0), pool3->getOutput(0)}; + auto cat = network->addConcatenation(inputTensors, 4); + auto cv2 = convBlock(network, weightMap, *cat->getOutput(0), c2, 1, 1, 1, lname + ".cv2"); + return cv2; +} + +static ILayer* Proto(INetworkDefinition* network, std::map& weightMap, ITensor& input, int c_, + int c2, std::string lname) { + auto cv1 = convBlock(network, weightMap, input, c_, 3, 1, 1, lname + ".cv1"); + + auto upsample = network->addResize(*cv1->getOutput(0)); + assert(upsample); + upsample->setResizeMode(nvinfer1::InterpolationMode::kNEAREST); + const float scales[] = {1, 1, 2, 2}; + upsample->setScales(scales, 4); + + auto cv2 = convBlock(network, weightMap, *upsample->getOutput(0), c_, 3, 1, 1, lname + ".cv2"); + auto cv3 = convBlock(network, weightMap, *cv2->getOutput(0), c2, 1, 1, 1, lname + ".cv3"); + assert(cv3); + return cv3; +} + +static std::vector> getAnchors(std::map& weightMap, std::string lname) { + std::vector> anchors; + Weights wts = weightMap[lname + ".anchor_grid"]; + int anchor_len = kNumAnchor * 2; + for (int i = 0; i < wts.count / anchor_len; i++) { + auto* p = (const float*)wts.values + i * anchor_len; + std::vector anchor(p, p + anchor_len); + anchors.push_back(anchor); + } + return anchors; +} + +nvinfer1::IHostMemory* build_cls_engine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, + DataType dt, float& gd, 
float& gw, std::string& wts_name) { + // INetworkDefinition *network = builder->createNetworkV2(0U); + INetworkDefinition* network = + builder->createNetworkV2(1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)); + + // Create input tensor + ITensor* data = network->addInput(kInputTensorName, dt, Dims4{maxBatchSize, 3, kClsInputH, kClsInputW}); + assert(data); + std::map weightMap = loadWeights(wts_name); + + // Backbone + auto conv0 = convBlock(network, weightMap, *data, get_width(64, gw), 6, 2, 1, "model.0"); + assert(conv0); + auto conv1 = convBlock(network, weightMap, *conv0->getOutput(0), get_width(128, gw), 3, 2, 1, "model.1"); + auto bottleneck_CSP2 = C3(network, weightMap, *conv1->getOutput(0), get_width(128, gw), get_width(128, gw), + get_depth(3, gd), true, 1, 0.5, "model.2"); + auto conv3 = convBlock(network, weightMap, *bottleneck_CSP2->getOutput(0), get_width(256, gw), 3, 2, 1, "model.3"); + auto bottleneck_csp4 = C3(network, weightMap, *conv3->getOutput(0), get_width(256, gw), get_width(256, gw), + get_depth(6, gd), true, 1, 0.5, "model.4"); + auto conv5 = convBlock(network, weightMap, *bottleneck_csp4->getOutput(0), get_width(512, gw), 3, 2, 1, "model.5"); + auto bottleneck_csp6 = C3(network, weightMap, *conv5->getOutput(0), get_width(512, gw), get_width(512, gw), + get_depth(9, gd), true, 1, 0.5, "model.6"); + auto conv7 = convBlock(network, weightMap, *bottleneck_csp6->getOutput(0), get_width(1024, gw), 3, 2, 1, "model.7"); + auto bottleneck_csp8 = C3(network, weightMap, *conv7->getOutput(0), get_width(1024, gw), get_width(1024, gw), + get_depth(3, gd), true, 1, 0.5, "model.8"); + + // Head + auto conv_class = convBlock(network, weightMap, *bottleneck_csp8->getOutput(0), 1280, 1, 1, 1, "model.9.conv"); + int k = kClsInputH / 32; + IPoolingLayer* pool2 = network->addPoolingNd(*conv_class->getOutput(0), PoolingType::kAVERAGE, DimsHW{k, k}); + assert(pool2); + auto shuffle_0 = network->addShuffle(*pool2->getOutput(0)); + shuffle_0->setReshapeDimensions(nvinfer1::Dims2{kBatchSize, 1280}); + auto linear_weight = weightMap["model.9.linear.weight"]; + auto constant_weight = network->addConstant(nvinfer1::Dims2{kClsNumClass, 1280}, linear_weight); + auto constant_bias = + network->addConstant(nvinfer1::Dims2{kBatchSize, kClsNumClass}, weightMap["model.9.linear.bias"]); + auto linear_matrix_multipy = + network->addMatrixMultiply(*shuffle_0->getOutput(0), nvinfer1::MatrixOperation::kNONE, + *constant_weight->getOutput(0), nvinfer1::MatrixOperation::kTRANSPOSE); + auto yolo = network->addElementWise(*linear_matrix_multipy->getOutput(0), *constant_bias->getOutput(0), + nvinfer1::ElementWiseOperation::kSUM); + assert(yolo); + + yolo->getOutput(0)->setName(kOutputTensorName); + network->markOutput(*yolo->getOutput(0)); + + // Engine config + config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 16 * (1 << 20)); + +#if defined(USE_FP16) + config->setFlag(BuilderFlag::kFP16); +#elif defined(USE_INT8) + std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; + assert(builder->platformHasFastInt8()); + config->setFlag(BuilderFlag::kINT8); + Int8EntropyCalibrator2* calibrator = + new Int8EntropyCalibrator2(1, kClsInputW, kClsInputW, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setInt8Calibrator(calibrator); +#endif + + std::cout << "Building engine, please wait for a while..." 
+    std::cout << "Building engine, please wait for a while..." << std::endl;
+    nvinfer1::IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config);
+    std::cout << "Build engine successfully!" << std::endl;
+
+    // Don't need the network any more
+    delete network;
+
+    // Release host memory
+    for (auto& mem : weightMap) {
+        free((void*)(mem.second.values));
+    }
+
+    return serialized_model;
+}
diff --git a/yolov5/yolov5_trt10/src/model.h b/yolov5/yolov5_trt10/src/model.h
new file mode 100644
index 00000000..3bfbd93c
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/model.h
@@ -0,0 +1,8 @@
+#pragma once
+
+#include <NvInfer.h>
+#include <string>
+
+nvinfer1::IHostMemory* build_cls_engine(unsigned int maxBatchSize, nvinfer1::IBuilder* builder,
+                                        nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, float& gd, float& gw,
+                                        std::string& wts_name);
diff --git a/yolov5/yolov5_trt10/src/postprocess.cpp b/yolov5/yolov5_trt10/src/postprocess.cpp
new file mode 100644
index 00000000..c9d08376
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/postprocess.cpp
@@ -0,0 +1,195 @@
+#include "postprocess.h"
+#include "utils.h"
+
+cv::Rect get_rect(cv::Mat& img, float bbox[4]) {
+    float l, r, t, b;
+    float r_w = kInputW / (img.cols * 1.0);
+    float r_h = kInputH / (img.rows * 1.0);
+    if (r_h > r_w) {
+        l = bbox[0] - bbox[2] / 2.f;
+        r = bbox[0] + bbox[2] / 2.f;
+        t = bbox[1] - bbox[3] / 2.f - (kInputH - r_w * img.rows) / 2;
+        b = bbox[1] + bbox[3] / 2.f - (kInputH - r_w * img.rows) / 2;
+        l = l / r_w;
+        r = r / r_w;
+        t = t / r_w;
+        b = b / r_w;
+    } else {
+        l = bbox[0] - bbox[2] / 2.f - (kInputW - r_h * img.cols) / 2;
+        r = bbox[0] + bbox[2] / 2.f - (kInputW - r_h * img.cols) / 2;
+        t = bbox[1] - bbox[3] / 2.f;
+        b = bbox[1] + bbox[3] / 2.f;
+        l = l / r_h;
+        r = r / r_h;
+        t = t / r_h;
+        b = b / r_h;
+    }
+    return cv::Rect(round(l), round(t), round(r - l), round(b - t));
+}
+
+static float iou(float lbox[4], float rbox[4]) {
+    float interBox[] = {
+            (std::max)(lbox[0] - lbox[2] / 2.f, rbox[0] - rbox[2] / 2.f),  //left
+            (std::min)(lbox[0] + lbox[2] / 2.f, rbox[0] + rbox[2] / 2.f),  //right
+            (std::max)(lbox[1] - lbox[3] / 2.f, rbox[1] - rbox[3] / 2.f),  //top
+            (std::min)(lbox[1] + lbox[3] / 2.f, rbox[1] + rbox[3] / 2.f),  //bottom
+    };
+
+    if (interBox[2] > interBox[3] || interBox[0] > interBox[1])
+        return 0.0f;
+
+    float interBoxS = (interBox[1] - interBox[0]) * (interBox[3] - interBox[2]);
+    return interBoxS / (lbox[2] * lbox[3] + rbox[2] * rbox[3] - interBoxS);
+}
+
+static bool cmp(const Detection& a, const Detection& b) {
+    return a.conf > b.conf;
+}
+
+void nms(std::vector<Detection>& res, float* output, float conf_thresh, float nms_thresh) {
+    int det_size = sizeof(Detection) / sizeof(float);
+    std::map<float, std::vector<Detection>> m;
+    for (int i = 0; i < output[0] && i < kMaxNumOutputBbox; i++) {
+        if (output[1 + det_size * i + 4] <= conf_thresh)
+            continue;
+        Detection det;
+        memcpy(&det, &output[1 + det_size * i], det_size * sizeof(float));
+        if (m.count(det.class_id) == 0)
+            m.emplace(det.class_id, std::vector<Detection>());
+        m[det.class_id].push_back(det);
+    }
+    for (auto it = m.begin(); it != m.end(); it++) {
+        auto& dets = it->second;
+        std::sort(dets.begin(), dets.end(), cmp);
+        for (size_t m = 0; m < dets.size(); ++m) {
+            auto& item = dets[m];
+            res.push_back(item);
+            for (size_t n = m + 1; n < dets.size(); ++n) {
+                if (iou(item.bbox, dets[n].bbox) > nms_thresh) {
+                    dets.erase(dets.begin() + n);
+                    --n;
+                }
+            }
+        }
+    }
+}
+
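+// Layout of the raw detection output consumed above: output[0] holds the
+// number of candidate boxes, followed by up to kMaxNumOutputBbox fixed-size
+// Detection records (det_size floats each). batch_nms below slices this
+// buffer per image and runs class-wise NMS on each slice.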
+void batch_nms(std::vector<std::vector<Detection>>& res_batch, float* output, int batch_size, int output_size,
+               float conf_thresh, float nms_thresh) {
+    res_batch.resize(batch_size);
+    for (int i = 0; i < batch_size; i++) {
+        nms(res_batch[i], &output[i * output_size], conf_thresh, nms_thresh);
+    }
+}
+
+void draw_bbox(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch) {
+    for (size_t i = 0; i < img_batch.size(); i++) {
+        auto& res = res_batch[i];
+        cv::Mat img = img_batch[i];
+        for (size_t j = 0; j < res.size(); j++) {
+            cv::Rect r = get_rect(img, res[j].bbox);
+            cv::rectangle(img, r, cv::Scalar(0x27, 0xC1, 0x36), 2);
+            cv::putText(img, std::to_string((int)res[j].class_id), cv::Point(r.x, r.y - 1), cv::FONT_HERSHEY_PLAIN, 1.2,
+                        cv::Scalar(0xFF, 0xFF, 0xFF), 2);
+        }
+    }
+}
+
+static cv::Rect get_downscale_rect(float bbox[4], float scale) {
+    float left = bbox[0] - bbox[2] / 2;
+    float top = bbox[1] - bbox[3] / 2;
+    float right = bbox[0] + bbox[2] / 2;
+    float bottom = bbox[1] + bbox[3] / 2;
+    left /= scale;
+    top /= scale;
+    right /= scale;
+    bottom /= scale;
+    return cv::Rect(round(left), round(top), round(right - left), round(bottom - top));
+}
+
+std::vector<cv::Mat> process_mask(const float* proto, int proto_size, std::vector<Detection>& dets) {
+    std::vector<cv::Mat> masks;
+    for (size_t i = 0; i < dets.size(); i++) {
+        cv::Mat mask_mat = cv::Mat::zeros(kInputH / 4, kInputW / 4, CV_32FC1);
+        auto r = get_downscale_rect(dets[i].bbox, 4);
+        for (int x = r.x; x < r.x + r.width; x++) {
+            for (int y = r.y; y < r.y + r.height; y++) {
+                float e = 0.0f;
+                for (int j = 0; j < 32; j++) {
+                    e += dets[i].mask[j] * proto[j * proto_size / 32 + y * mask_mat.cols + x];
+                }
+                e = 1.0f / (1.0f + expf(-e));
+                mask_mat.at<float>(y, x) = e;
+            }
+        }
+        cv::resize(mask_mat, mask_mat, cv::Size(kInputW, kInputH));
+        masks.push_back(mask_mat);
+    }
+    return masks;
+}
+
+cv::Mat scale_mask(cv::Mat mask, cv::Mat img) {
+    int x, y, w, h;
+    float r_w = kInputW / (img.cols * 1.0);
+    float r_h = kInputH / (img.rows * 1.0);
+    if (r_h > r_w) {
+        w = kInputW;
+        h = r_w * img.rows;
+        x = 0;
+        y = (kInputH - h) / 2;
+    } else {
+        w = r_h * img.cols;
+        h = kInputH;
+        x = (kInputW - w) / 2;
+        y = 0;
+    }
+    cv::Rect r(x, y, w, h);
+    cv::Mat res;
+    cv::resize(mask(r), res, img.size());
+    return res;
+}
+
+void draw_mask_bbox(cv::Mat& img, std::vector<Detection>& dets, std::vector<cv::Mat>& masks,
+                    std::unordered_map<int, std::string>& labels_map) {
+    static std::vector<uint32_t> colors = {0xFF3838, 0xFF9D97, 0xFF701F, 0xFFB21D, 0xCFD231, 0x48F90A, 0x92CC17,
+                                           0x3DDB86, 0x1A9334, 0x00D4BB, 0x2C99A8, 0x00C2FF, 0x344593, 0x6473FF,
+                                           0x0018EC, 0x8438FF, 0x520085, 0xCB38FF, 0xFF95C8, 0xFF37C7};
+    for (size_t i = 0; i < dets.size(); i++) {
+        cv::Mat img_mask = scale_mask(masks[i], img);
+        auto color = colors[(int)dets[i].class_id % colors.size()];
+        auto bgr = cv::Scalar(color & 0xFF, color >> 8 & 0xFF, color >> 16 & 0xFF);
+
+        cv::Rect r = get_rect(img, dets[i].bbox);
+        for (int x = r.x; x < r.x + r.width; x++) {
+            for (int y = r.y; y < r.y + r.height; y++) {
+                float val = img_mask.at<float>(y, x);
+                if (val <= 0.5)
+                    continue;
+                img.at<cv::Vec3b>(y, x)[0] = img.at<cv::Vec3b>(y, x)[0] / 2 + bgr[0] / 2;
+                img.at<cv::Vec3b>(y, x)[1] = img.at<cv::Vec3b>(y, x)[1] / 2 + bgr[1] / 2;
+                img.at<cv::Vec3b>(y, x)[2] = img.at<cv::Vec3b>(y, x)[2] / 2 + bgr[2] / 2;
+            }
+        }
+
+        cv::rectangle(img, r, bgr, 2);
+
+        // Get the size of the text
+        cv::Size textSize =
+                cv::getTextSize(labels_map[(int)dets[i].class_id] + " " + to_string_with_precision(dets[i].conf),
+                                cv::FONT_HERSHEY_PLAIN, 1.2, 2, NULL);
+        // Set the top left corner of the rectangle
+        cv::Point topLeft(r.x, r.y - textSize.height);
+
+        // Set the bottom right corner of the rectangle
+        cv::Point bottomRight(r.x + textSize.width, r.y + textSize.height);
+
+        // Set the thickness of the rectangle lines
+        int lineThickness = 2;
+
+        // Draw the rectangle on the image
+        cv::rectangle(img, topLeft, bottomRight, bgr, -1);
+
+        cv::putText(img, labels_map[(int)dets[i].class_id] + " " + to_string_with_precision(dets[i].conf),
+                    cv::Point(r.x, r.y + 4), cv::FONT_HERSHEY_PLAIN, 1.2, cv::Scalar::all(0xFF), 2);
+    }
+}
diff --git a/yolov5/yolov5_trt10/src/postprocess.h b/yolov5/yolov5_trt10/src/postprocess.h
new file mode 100644
index 00000000..43f83c50
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/postprocess.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <opencv2/opencv.hpp>
+#include "types.h"
+
+cv::Rect get_rect(cv::Mat& img, float bbox[4]);
+
+void nms(std::vector<Detection>& res, float* output, float conf_thresh, float nms_thresh = 0.5);
+
+void batch_nms(std::vector<std::vector<Detection>>& batch_res, float* output, int batch_size, int output_size,
+               float conf_thresh, float nms_thresh = 0.5);
+
+void draw_bbox(std::vector<cv::Mat>& img_batch, std::vector<std::vector<Detection>>& res_batch);
+
+std::vector<cv::Mat> process_mask(const float* proto, int proto_size, std::vector<Detection>& dets);
+
+void draw_mask_bbox(cv::Mat& img, std::vector<Detection>& dets, std::vector<cv::Mat>& masks,
+                    std::unordered_map<int, std::string>& labels_map);
diff --git a/yolov5/yolov5_trt10/src/preprocess.cu b/yolov5/yolov5_trt10/src/preprocess.cu
new file mode 100644
index 00000000..7b046e06
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/preprocess.cu
@@ -0,0 +1,146 @@
+#include "cuda_utils.h"
+#include "preprocess.h"
+
+static uint8_t* img_buffer_host = nullptr;
+static uint8_t* img_buffer_device = nullptr;
+
+struct AffineMatrix {
+    float value[6];
+};
+
+__global__ void warpaffine_kernel(uint8_t* src, int src_line_size, int src_width, int src_height, float* dst,
+                                  int dst_width, int dst_height, uint8_t const_value_st, AffineMatrix d2s, int edge) {
+    int position = blockDim.x * blockIdx.x + threadIdx.x;
+    if (position >= edge)
+        return;
+
+    float m_x1 = d2s.value[0];
+    float m_y1 = d2s.value[1];
+    float m_z1 = d2s.value[2];
+    float m_x2 = d2s.value[3];
+    float m_y2 = d2s.value[4];
+    float m_z2 = d2s.value[5];
+
+    int dx = position % dst_width;
+    int dy = position / dst_width;
+    float src_x = m_x1 * dx + m_y1 * dy + m_z1 + 0.5f;
+    float src_y = m_x2 * dx + m_y2 * dy + m_z2 + 0.5f;
+    float c0, c1, c2;
+
+    if (src_x <= -1 || src_x >= src_width || src_y <= -1 || src_y >= src_height) {
+        // out of range
+        c0 = const_value_st;
+        c1 = const_value_st;
+        c2 = const_value_st;
+    } else {
+        int y_low = floorf(src_y);
+        int x_low = floorf(src_x);
+        int y_high = y_low + 1;
+        int x_high = x_low + 1;
+
+        uint8_t const_value[] = {const_value_st, const_value_st, const_value_st};
+        float ly = src_y - y_low;
+        float lx = src_x - x_low;
+        float hy = 1 - ly;
+        float hx = 1 - lx;
+        float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+        uint8_t* v1 = const_value;
+        uint8_t* v2 = const_value;
+        uint8_t* v3 = const_value;
+        uint8_t* v4 = const_value;
+
+        if (y_low >= 0) {
+            if (x_low >= 0)
+                v1 = src + y_low * src_line_size + x_low * 3;
+
+            if (x_high < src_width)
+                v2 = src + y_low * src_line_size + x_high * 3;
+        }
+
+        if (y_high < src_height) {
+            if (x_low >= 0)
+                v3 = src + y_high * src_line_size + x_low * 3;
+
+            if (x_high < src_width)
+                v4 = src + y_high * src_line_size + x_high * 3;
+        }
+
+        c0 = w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0];
+        c1 = w1 * v1[1] + w2 * v2[1] + w3 * v3[1] + w4 * v4[1];
+        c2 = w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2];
+    }
+
+    // bgr to rgb
+    float t = c2;
+    c2 = c0;
+    c0 = t;
+
+    // normalization
+    c0 = c0 / 255.0f;
+    c1 = c1 / 255.0f;
+    c2 = c2 / 255.0f;
+
+    // rgbrgbrgb to rrrgggbbb
+    int area = dst_width * dst_height;
+    float* pdst_c0 = dst + dy * dst_width + dx;
+    float* pdst_c1 = pdst_c0 + area;
+    float* pdst_c2 = pdst_c1 + area;
+    *pdst_c0 = c0;
+    *pdst_c1 = c1;
+    *pdst_c2 = c2;
+}
+
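+// cuda_preprocess builds the source-to-destination letterbox affine s2d
+// (uniform scale plus centering), inverts it with cv::invertAffineTransform,
+// and passes the inverse d2s to the kernel above so that every destination
+// pixel can be sampled bilinearly from the source image.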
+void cuda_preprocess(uint8_t* src, int src_width, int src_height, float* dst, int dst_width, int dst_height,
+                     cudaStream_t stream) {
+
+    int img_size = src_width * src_height * 3;
+    // copy data to pinned memory
+    memcpy(img_buffer_host, src, img_size);
+    // copy data to device memory
+    CUDA_CHECK(cudaMemcpyAsync(img_buffer_device, img_buffer_host, img_size, cudaMemcpyHostToDevice, stream));
+
+    AffineMatrix s2d, d2s;
+    float scale = std::min(dst_height / (float)src_height, dst_width / (float)src_width);
+
+    s2d.value[0] = scale;
+    s2d.value[1] = 0;
+    s2d.value[2] = -scale * src_width * 0.5 + dst_width * 0.5;
+    s2d.value[3] = 0;
+    s2d.value[4] = scale;
+    s2d.value[5] = -scale * src_height * 0.5 + dst_height * 0.5;
+
+    cv::Mat m2x3_s2d(2, 3, CV_32F, s2d.value);
+    cv::Mat m2x3_d2s(2, 3, CV_32F, d2s.value);
+    cv::invertAffineTransform(m2x3_s2d, m2x3_d2s);
+
+    memcpy(d2s.value, m2x3_d2s.ptr<float>(0), sizeof(d2s.value));
+
+    int jobs = dst_height * dst_width;
+    int threads = 256;
+    int blocks = ceil(jobs / (float)threads);
+
+    warpaffine_kernel<<<blocks, threads, 0, stream>>>(img_buffer_device, src_width * 3, src_width, src_height, dst,
+                                                      dst_width, dst_height, 128, d2s, jobs);
+}
+
+void cuda_batch_preprocess(std::vector<cv::Mat>& img_batch, float* dst, int dst_width, int dst_height,
+                           cudaStream_t stream) {
+    int dst_size = dst_width * dst_height * 3;
+    for (size_t i = 0; i < img_batch.size(); i++) {
+        cuda_preprocess(img_batch[i].ptr(), img_batch[i].cols, img_batch[i].rows, &dst[dst_size * i], dst_width,
+                        dst_height, stream);
+        CUDA_CHECK(cudaStreamSynchronize(stream));
+    }
+}
+
+void cuda_preprocess_init(int max_image_size) {
+    // prepare input data in pinned memory
+    CUDA_CHECK(cudaMallocHost((void**)&img_buffer_host, max_image_size * 3));
+    // prepare input data in device memory
+    CUDA_CHECK(cudaMalloc((void**)&img_buffer_device, max_image_size * 3));
+}
+
+void cuda_preprocess_destroy() {
+    CUDA_CHECK(cudaFree(img_buffer_device));
+    CUDA_CHECK(cudaFreeHost(img_buffer_host));
+}
diff --git a/yolov5/yolov5_trt10/src/preprocess.h b/yolov5/yolov5_trt10/src/preprocess.h
new file mode 100644
index 00000000..ee81cf16
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/preprocess.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <cstdint>
+#include <cuda_runtime.h>
+#include <opencv2/opencv.hpp>
+
+void cuda_preprocess_init(int max_image_size);
+void cuda_preprocess_destroy();
+void cuda_preprocess(uint8_t* src, int src_width, int src_height, float* dst, int dst_width, int dst_height,
+                     cudaStream_t stream);
+void cuda_batch_preprocess(std::vector<cv::Mat>& img_batch, float* dst, int dst_width, int dst_height,
+                           cudaStream_t stream);
diff --git a/yolov5/yolov5_trt10/src/types.h b/yolov5/yolov5_trt10/src/types.h
new file mode 100644
index 00000000..4fc576c9
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/types.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "config.h"
+
+struct YoloKernel {
+    int width;
+    int height;
+    float anchors[kNumAnchor * 2];
+};
+
+struct alignas(float) Detection {
+    float bbox[4];  // center_x center_y w h
+    float conf;     // bbox_conf * cls_conf
+    float class_id;
+    float mask[32];
+};
diff --git a/yolov5/yolov5_trt10/src/utils.h b/yolov5/yolov5_trt10/src/utils.h
new file mode 100644
index 00000000..436950e5
--- /dev/null
+++ b/yolov5/yolov5_trt10/src/utils.h
@@ -0,0 +1,68 @@
+#pragma once
+
+#include <dirent.h>
+#include <cstring>
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+static inline int read_files_in_dir(const char* p_dir_name, std::vector<std::string>& file_names) {
+    DIR* p_dir = opendir(p_dir_name);
+    if (p_dir == nullptr) {
+        return -1;
+    }
+
+    struct dirent* p_file = nullptr;
+    while ((p_file = readdir(p_dir)) != nullptr) {
+        if (strcmp(p_file->d_name, ".") != 0 && strcmp(p_file->d_name, "..") != 0) {
+            //std::string cur_file_name(p_dir_name);
+            //cur_file_name += "/";
+            //cur_file_name += p_file->d_name;
+            std::string cur_file_name(p_file->d_name);
+            file_names.push_back(cur_file_name);
+        }
+    }
+
+    closedir(p_dir);
+    return 0;
+}
+
+// Function to trim leading and trailing whitespace from a string
+static inline std::string trim_leading_whitespace(const std::string& str) {
+    size_t first = str.find_first_not_of(' ');
+    if (std::string::npos == first) {
+        return str;
+    }
+    size_t last = str.find_last_not_of(' ');
+    return str.substr(first, (last - first + 1));
+}
+
+// Src: https://stackoverflow.com/questions/16605967
+static inline std::string to_string_with_precision(const float a_value, const int n = 2) {
+    std::ostringstream out;
+    out.precision(n);
+    out << std::fixed << a_value;
+    return out.str();
+}
+
+static inline int read_labels(const std::string labels_filename, std::unordered_map<int, std::string>& labels_map) {
+
+    std::ifstream file(labels_filename);
+    // Read each line of the file
+    std::string line;
+    int index = 0;
+    while (std::getline(file, line)) {
+        // Strip the line of any leading or trailing whitespace
+        line = trim_leading_whitespace(line);
+
+        // Add the stripped line to the labels_map, using the loop index as the key
+        labels_map[index] = line;
+        index++;
+    }
+    // Close the file
+    file.close();
+
+    return 0;
+}
diff --git a/yolov5/yolov5_trt10/yolov5_cls.cpp b/yolov5/yolov5_trt10/yolov5_cls.cpp
new file mode 100644
index 00000000..be30487c
--- /dev/null
+++ b/yolov5/yolov5_trt10/yolov5_cls.cpp
@@ -0,0 +1,307 @@
+#include "calibrator.h"
+#include "config.h"
+#include "cuda_utils.h"
+#include "logging.h"
+#include "model.h"
+#include "utils.h"
+
+#include <chrono>
+#include <fstream>
+#include <iostream>
+#include <numeric>
+#include <opencv2/opencv.hpp>
+
+using namespace nvinfer1;
+
+static Logger gLogger;
+const static int kOutputSize = kClsNumClass;
+
+cv::Mat cls_preprocess_img(cv::Mat& src, int target_w, int target_h) {
+    //imh, imw = im.shape[:2]
+    auto imh = src.rows;
+    auto imw = src.cols;
+    // m = min(imh, imw) # min dimension
+    auto m = std::min(imh, imw);
+    // top, left = (imh - m) // 2, (imw - m) // 2
+    auto top = (imh - m) / 2;
+    auto left = (imw - m) / 2;
+    // return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
+    auto crop = src(cv::Rect(left, top, m, m));
+    cv::Mat dst;
+    cv::resize(crop, dst, cv::Size(target_w, target_h), 0, 0, cv::INTER_LINEAR);
+    return dst;
+}
+
+void batch_preprocess(std::vector<cv::Mat>& imgs, float* output) {
+    for (size_t b = 0; b < imgs.size(); b++) {
+        cv::Mat img;
+        // cv::resize(imgs[b], img, cv::Size(kClsInputW, kClsInputH));
+        img = cls_preprocess_img(imgs[b], kClsInputW, kClsInputH);
+        int i = 0;
+        for (int row = 0; row < img.rows; ++row) {
+            uchar* uc_pixel = img.data + row * img.step;
+            for (int col = 0; col < img.cols; ++col) {
+                output[b * 3 * img.rows * img.cols + i] = ((float)uc_pixel[2] / 255.0 - 0.485) / 0.229;  // R - 0.485
+                output[b * 3 * img.rows * img.cols + i + img.rows * img.cols] =
+                        ((float)uc_pixel[1] / 255.0 - 0.456) / 0.224;
+                output[b * 3 * img.rows * img.cols + i + 2 * img.rows * img.cols] =
+                        ((float)uc_pixel[0] / 255.0 - 0.406) / 0.225;
+                uc_pixel += 3;
+                ++i;
+            }
+        }
+    }
+}
+
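+// The constants in batch_preprocess above are the standard ImageNet
+// normalization (mean 0.485/0.456/0.406, std 0.229/0.224/0.225) applied in
+// RGB order; the 2/1/0 channel indices reorder OpenCV's BGR pixels.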
+std::vector<float> softmax(float* prob, int n) {
+    std::vector<float> res;
+    float sum = 0.0f;
+    float t;
+    for (int i = 0; i < n; i++) {
+        t = expf(prob[i]);
+        res.push_back(t);
+        sum += t;
+    }
+    for (int i = 0; i < n; i++) {
+        res[i] /= sum;
+    }
+    return res;
+}
+
+std::vector<int> topk(const std::vector<float>& vec, int k) {
+    std::vector<int> topk_index;
+    std::vector<size_t> vec_index(vec.size());
+    std::iota(vec_index.begin(), vec_index.end(), 0);
+
+    std::sort(vec_index.begin(), vec_index.end(),
+              [&vec](size_t index_1, size_t index_2) { return vec[index_1] > vec[index_2]; });
+
+    int k_num = std::min<int>(vec.size(), k);
+
+    for (int i = 0; i < k_num; ++i) {
+        topk_index.push_back(vec_index[i]);
+    }
+
+    return topk_index;
+}
+
+std::vector<std::string> read_classes(std::string file_name) {
+    std::vector<std::string> classes;
+    std::ifstream ifs(file_name, std::ios::in);
+    if (!ifs.is_open()) {
+        std::cerr << file_name << " is not found, pls refer to README and download it." << std::endl;
+        assert(0);
+    }
+    std::string s;
+    while (std::getline(ifs, s)) {
+        classes.push_back(s);
+    }
+    ifs.close();
+    return classes;
+}
+
+bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, float& gd, float& gw,
+                std::string& img_dir) {
+    if (argc < 4)
+        return false;
+    if (std::string(argv[1]) == "-s" && (argc == 5 || argc == 7)) {
+        wts = std::string(argv[2]);
+        engine = std::string(argv[3]);
+        auto net = std::string(argv[4]);
+        if (net[0] == 'n') {
+            gd = 0.33;
+            gw = 0.25;
+        } else if (net[0] == 's') {
+            gd = 0.33;
+            gw = 0.50;
+        } else if (net[0] == 'm') {
+            gd = 0.67;
+            gw = 0.75;
+        } else if (net[0] == 'l') {
+            gd = 1.0;
+            gw = 1.0;
+        } else if (net[0] == 'x') {
+            gd = 1.33;
+            gw = 1.25;
+        } else if (net[0] == 'c' && argc == 7) {
+            gd = atof(argv[5]);
+            gw = atof(argv[6]);
+        } else {
+            return false;
+        }
+    } else if (std::string(argv[1]) == "-d" && argc == 4) {
+        engine = std::string(argv[2]);
+        img_dir = std::string(argv[3]);
+    } else {
+        return false;
+    }
+    return true;
+}
+
+void prepare_buffers(ICudaEngine* engine, float** gpu_input_buffer, float** gpu_output_buffer, float** cpu_input_buffer,
+                     float** cpu_output_buffer) {
+    assert(engine->getNbIOTensors() == 2);
+    // Create GPU buffers on device
+    CUDA_CHECK(cudaMalloc((void**)gpu_input_buffer, kBatchSize * 3 * kClsInputH * kClsInputW * sizeof(float)));
+    CUDA_CHECK(cudaMalloc((void**)gpu_output_buffer, kBatchSize * kOutputSize * sizeof(float)));
+
+    *cpu_input_buffer = new float[kBatchSize * 3 * kClsInputH * kClsInputW];
+    *cpu_output_buffer = new float[kBatchSize * kOutputSize];
+}
+
+void infer(IExecutionContext& context, cudaStream_t& stream, void** buffers, float* input, float* output,
+           int batchSize) {
+    CUDA_CHECK(cudaMemcpyAsync(buffers[0], input, batchSize * 3 * kClsInputH * kClsInputW * sizeof(float),
+                               cudaMemcpyHostToDevice, stream));
+    // input, output
+    context.setInputTensorAddress(kInputTensorName, buffers[0]);
+    context.setOutputTensorAddress(kOutputTensorName, buffers[1]);
+    context.enqueueV3(stream);
+    CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchSize * kOutputSize * sizeof(float), cudaMemcpyDeviceToHost,
+                               stream));
+    cudaStreamSynchronize(stream);
+}
+
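+// Note: infer() above uses the TensorRT 10 tensor-name API
+// (setInputTensorAddress/setOutputTensorAddress plus enqueueV3) in place of
+// the binding-index bindings array and enqueueV2 used with earlier releases.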
+void serialize_engine(unsigned int max_batchsize, float& gd, float& gw, std::string& wts_name,
+                      std::string& engine_name) {
+    // Create builder
+    IBuilder* builder = createInferBuilder(gLogger);
+    IBuilderConfig* config = builder->createBuilderConfig();
+
+    // Create model to populate the network, then set the outputs and create an engine
+    IHostMemory* serialized_engine = nullptr;
+    //engine = buildEngineYolov8Cls(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name);
+    serialized_engine = build_cls_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name);
+
+    // Save engine to file
+    std::ofstream p(engine_name, std::ios::binary);
+    if (!p) {
+        std::cerr << "Could not open plan output file" << std::endl;
+        assert(false);
+    }
+    p.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());
+
+    delete serialized_engine;
+    delete config;
+    delete builder;
+}
+
+void deserialize_engine(std::string& engine_name, IRuntime** runtime, ICudaEngine** engine,
+                        IExecutionContext** context) {
+    std::ifstream file(engine_name, std::ios::binary);
+    if (!file.good()) {
+        std::cerr << "read " << engine_name << " error!" << std::endl;
+        assert(false);
+    }
+    size_t size = 0;
+    file.seekg(0, file.end);
+    size = file.tellg();
+    file.seekg(0, file.beg);
+    char* serialized_engine = new char[size];
+    assert(serialized_engine);
+    file.read(serialized_engine, size);
+    file.close();
+
+    *runtime = createInferRuntime(gLogger);
+    assert(*runtime);
+    *engine = (*runtime)->deserializeCudaEngine(serialized_engine, size);
+    assert(*engine);
+    *context = (*engine)->createExecutionContext();
+    assert(*context);
+    delete[] serialized_engine;
+}
+
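+// TensorRT 10 removes the destroy() methods of earlier releases; runtime,
+// engine and context are released with plain `delete` at the end of main().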
+int main(int argc, char** argv) {
+    // -s ../models/yolov5n-cls.wts ../models/yolov5n-cls.fp32.trt n
+    // -d ../models/yolov5n-cls.fp32.trt ../images
+    cudaSetDevice(kGpuId);
+
+    std::string wts_name = "";
+    std::string engine_name = "";
+    float gd = 0.0f, gw = 0.0f;
+    std::string img_dir;
+
+    if (!parse_args(argc, argv, wts_name, engine_name, gd, gw, img_dir)) {
+        std::cerr << "arguments not right!" << std::endl;
+        std::cerr << "./yolov5_cls -s [.wts] [.engine] [n/s/m/l/x or c gd gw]  // serialize model to plan file"
+                  << std::endl;
+        std::cerr << "./yolov5_cls -d [.engine] ../images  // deserialize plan file and run inference" << std::endl;
+        return -1;
+    }
+
+    // Create a model using the API directly and serialize it to a file
+    if (!wts_name.empty()) {
+        serialize_engine(kBatchSize, gd, gw, wts_name, engine_name);
+        return 0;
+    }
+
+    // Deserialize the engine from file
+    IRuntime* runtime = nullptr;
+    ICudaEngine* engine = nullptr;
+    IExecutionContext* context = nullptr;
+    deserialize_engine(engine_name, &runtime, &engine, &context);
+    cudaStream_t stream;
+    CUDA_CHECK(cudaStreamCreate(&stream));
+
+    // Prepare cpu and gpu buffers
+    float* gpu_buffers[2];
+    float* cpu_input_buffer = nullptr;
+    float* cpu_output_buffer = nullptr;
+    prepare_buffers(engine, &gpu_buffers[0], &gpu_buffers[1], &cpu_input_buffer, &cpu_output_buffer);
+
+    // Read images from directory
+    std::vector<std::string> file_names;
+    if (read_files_in_dir(img_dir.c_str(), file_names) < 0) {
+        std::cerr << "read_files_in_dir failed." << std::endl;
+        return -1;
+    }
+
+    // Read imagenet labels
+    auto classes = read_classes("./imagenet_classes.txt");
+
+    // batch predict
+    for (size_t i = 0; i < file_names.size(); i += kBatchSize) {
+        // Get a batch of images
+        std::vector<cv::Mat> img_batch;
+        std::vector<std::string> img_name_batch;
+        for (size_t j = i; j < i + kBatchSize && j < file_names.size(); j++) {
+            cv::Mat img = cv::imread(img_dir + "/" + file_names[j]);
+            img_batch.push_back(img);
+            img_name_batch.push_back(file_names[j]);
+        }
+
+        // Preprocess
+        batch_preprocess(img_batch, cpu_input_buffer);
+
+        // Run inference
+        auto start = std::chrono::system_clock::now();
+        infer(*context, stream, (void**)gpu_buffers, cpu_input_buffer, cpu_output_buffer, kBatchSize);
+        auto end = std::chrono::system_clock::now();
+        std::cout << "inference time: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
+                  << "ms" << std::endl;
+
+        // Postprocess and get top-k result
+        for (size_t b = 0; b < img_name_batch.size(); b++) {
+            float* p = &cpu_output_buffer[b * kOutputSize];
+            auto res = softmax(p, kOutputSize);
+            auto topk_idx = topk(res, 3);
+            std::cout << img_name_batch[b] << std::endl;
+            for (auto idx : topk_idx) {
+                std::cout << "  " << classes[idx] << " " << res[idx] << std::endl;
+            }
+        }
+    }
+
+    // Release stream and buffers
+    cudaStreamDestroy(stream);
+    CUDA_CHECK(cudaFree(gpu_buffers[0]));
+    CUDA_CHECK(cudaFree(gpu_buffers[1]));
+    delete[] cpu_input_buffer;
+    delete[] cpu_output_buffer;
+    // Destroy the engine
+    delete context;
+    delete engine;
+    delete runtime;
+
+    return 0;
+}
diff --git a/yolov5/yolov5_trt10/yolov5_cls_trt.py b/yolov5/yolov5_trt10/yolov5_cls_trt.py
new file mode 100644
index 00000000..c36e25cd
--- /dev/null
+++ b/yolov5/yolov5_trt10/yolov5_cls_trt.py
@@ -0,0 +1,261 @@
+"""
+An example that uses TensorRT's Python API to run inference.
+"""
+import os
+import shutil
+import sys
+import threading
+import time
+import cv2
+import numpy as np
+import torch
+import pycuda.autoinit  # noqa: F401  (initializes the CUDA driver)
+import pycuda.driver as cuda
+import tensorrt as trt
+
+
+def get_img_path_batches(batch_size, img_dir):
+    ret = []
+    batch = []
+    for root, dirs, files in os.walk(img_dir):
+        for name in files:
+            if len(batch) == batch_size:
+                ret.append(batch)
+                batch = []
+            batch.append(os.path.join(root, name))
+    if len(batch) > 0:
+        ret.append(batch)
+    return ret
+
+
+with open("build/imagenet_classes.txt") as f:
+    classes = [line.strip() for line in f.readlines()]
+
+
+class YoLov5TRT(object):
+    """
+    description: A YOLOv5 class that wraps TensorRT ops, preprocess and postprocess ops.
+    """
+
+    def __init__(self, engine_file_path):
+        # Create a Context on this device,
+        self.ctx = cuda.Device(0).make_context()
+        stream = cuda.Stream()
+        TRT_LOGGER = trt.Logger(trt.Logger.INFO)
+        runtime = trt.Runtime(TRT_LOGGER)
+
+        # Deserialize the engine from file
+        with open(engine_file_path, "rb") as f:
+            engine = runtime.deserialize_cuda_engine(f.read())
+        context = engine.create_execution_context()
+
+        host_inputs = []
+        cuda_inputs = []
+        host_outputs = []
+        cuda_outputs = []
+        input_binding_names = []
+        output_binding_names = []
+        self.mean = (0.485, 0.456, 0.406)
+        self.std = (0.229, 0.224, 0.225)
+
+        for binding_name in engine:
+            shape = engine.get_tensor_shape(binding_name)
+            print('binding_name:', binding_name, shape)
+            size = trt.volume(shape)
+            dtype = trt.nptype(engine.get_tensor_dtype(binding_name))
+            # Allocate host and device buffers
+            host_mem = cuda.pagelocked_empty(size, dtype)
+            cuda_mem = cuda.mem_alloc(host_mem.nbytes)
+            # Sort each buffer into the input or output list by tensor I/O mode.
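+            # TensorRT 10 note: iterating the engine yields tensor names
+            # rather than binding indices; get_tensor_mode() reports whether
+            # a named tensor is an input or an output.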
+            if engine.get_tensor_mode(binding_name) == trt.TensorIOMode.INPUT:
+                input_binding_names.append(binding_name)
+                self.input_w = shape[-1]
+                self.input_h = shape[-2]
+                host_inputs.append(host_mem)
+                cuda_inputs.append(cuda_mem)
+            elif engine.get_tensor_mode(binding_name) == trt.TensorIOMode.OUTPUT:
+                output_binding_names.append(binding_name)
+                host_outputs.append(host_mem)
+                cuda_outputs.append(cuda_mem)
+            else:
+                print('unknown:', binding_name)
+
+        # Store
+        self.stream = stream
+        self.context = context
+        self.engine = engine
+        self.host_inputs = host_inputs
+        self.cuda_inputs = cuda_inputs
+        self.host_outputs = host_outputs
+        self.cuda_outputs = cuda_outputs
+        self.input_binding_names = input_binding_names
+        self.output_binding_names = output_binding_names
+        self.batch_size = engine.get_tensor_shape(input_binding_names[0])[0]
+        print('batch_size:', self.batch_size)
+
+    def infer(self, raw_image_generator):
+        # Make self the active context, pushing it on top of the context stack.
+        self.ctx.push()
+        # Restore
+        stream = self.stream
+        context = self.context
+        host_inputs = self.host_inputs
+        cuda_inputs = self.cuda_inputs
+        host_outputs = self.host_outputs
+        cuda_outputs = self.cuda_outputs
+        input_binding_names = self.input_binding_names
+        output_binding_names = self.output_binding_names
+        # Do image preprocess
+        batch_image_raw = []
+        batch_input_image = np.empty(
+            shape=[self.batch_size, 3, self.input_h, self.input_w])
+        for i, image_raw in enumerate(raw_image_generator):
+            batch_image_raw.append(image_raw)
+            input_image = self.preprocess_cls_image(image_raw)
+            np.copyto(batch_input_image[i], input_image)
+        batch_input_image = np.ascontiguousarray(batch_input_image)
+
+        # Copy input image to host buffer
+        np.copyto(host_inputs[0], batch_input_image.ravel())
+        start = time.time()
+        # Transfer input data to the GPU.
+        cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
+        # Run inference.
+        context.set_tensor_address(input_binding_names[0], cuda_inputs[0])
+        context.set_tensor_address(output_binding_names[0], cuda_outputs[0])
+        context.execute_async_v3(stream_handle=stream.handle)
+        # context.execute_async(batch_size=self.batch_size,
+        #                       bindings=bindings, stream_handle=stream.handle)
+        # Transfer predictions back from the GPU.
+        cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
+        # Synchronize the stream
+        stream.synchronize()
+        end = time.time()
+        # Remove any context from the top of the context stack, deactivating it.
+        self.ctx.pop()
+        # Here we use the first row of output since batch_size = 1
+        output = host_outputs[0]
+        # Do postprocess
+        for i in range(self.batch_size):
+            classes_ls, predicted_conf_ls, category_id_ls = self.postprocess_cls(
+                output)
+            cv2.putText(batch_image_raw[i], str(
+                classes_ls), (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 1, cv2.LINE_AA)
+            print(classes_ls, predicted_conf_ls)
+        return batch_image_raw, end - start
+
+    def destroy(self):
+        # Remove any context from the top of the context stack, deactivating it.
+        self.ctx.pop()
+
+    def get_raw_image(self, image_path_batch):
+        """
+        description: Read images from a batch of image paths
+        """
+        for img_path in image_path_batch:
+            yield cv2.imread(img_path)
+
+    def get_raw_image_zeros(self, image_path_batch=None):
+        """
+        description: Generate zeroed images for warmup
+        """
+        for _ in range(self.batch_size):
+            yield np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8)
+
+    def preprocess_cls_image(self, input_img):
+        im = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
+        imh, imw = im.shape[:2]
+        m = min(imh, imw)
+        top = (imh - m) // 2
+        left = (imw - m) // 2
+        crop = im[top:top + m, left:left + m]
+        im = cv2.resize(crop, (self.input_w, self.input_h))
+        im = np.float32(im)
+        im /= 255.0
+        im -= self.mean
+        im /= self.std
+        im = im.transpose(2, 0, 1)
+        # prepare batch
+        batch_data = np.expand_dims(im, axis=0)
+        return batch_data
+
+    def postprocess_cls(self, output_data):
+        classes_ls = []
+        predicted_conf_ls = []
+        category_id_ls = []
+        output_data = output_data.reshape(self.batch_size, -1)
+        output_data = torch.Tensor(output_data)
+        p = torch.nn.functional.softmax(output_data, dim=1)
+        score, index = torch.topk(p, 3)
+        for ind in range(index.shape[0]):
+            input_category_id = index[ind][0].item()
+            category_id_ls.append(input_category_id)
+            predicted_confidence = score[ind][0].item()
+            predicted_conf_ls.append(predicted_confidence)
+            classes_ls.append(classes[input_category_id])
+        return classes_ls, predicted_conf_ls, category_id_ls
+
+
+class inferThread(threading.Thread):
+    def __init__(self, yolov5_wrapper, image_path_batch):
+        threading.Thread.__init__(self)
+        self.yolov5_wrapper = yolov5_wrapper
+        self.image_path_batch = image_path_batch
+
+    def run(self):
+        batch_image_raw, use_time = self.yolov5_wrapper.infer(
+            self.yolov5_wrapper.get_raw_image(self.image_path_batch))
+        for i, img_path in enumerate(self.image_path_batch):
+            parent, filename = os.path.split(img_path)
+            save_name = os.path.join('output', filename)
+            # Save image
+            cv2.imwrite(save_name, batch_image_raw[i])
+        print('input->{}, time->{:.2f}ms, saving into output/'.format(
+            self.image_path_batch, use_time * 1000))
+
+
+class warmUpThread(threading.Thread):
+    def __init__(self, yolov5_wrapper):
+        threading.Thread.__init__(self)
+        self.yolov5_wrapper = yolov5_wrapper
+
+    def run(self):
+        batch_image_raw, use_time = self.yolov5_wrapper.infer(
+            self.yolov5_wrapper.get_raw_image_zeros())
+        print(
+            'warm_up->{}, time->{:.2f}ms'.format(batch_image_raw[0].shape, use_time * 1000))
+
+
+if __name__ == "__main__":
+    # load custom plugin and engine
+    engine_file_path = "build/yolov5n-cls.fp16.trt"
+
+    if len(sys.argv) > 1:
+        engine_file_path = sys.argv[1]
+
+    if os.path.exists('output/'):
+        shutil.rmtree('output/')
+    os.makedirs('output/')
+    # a YoLov5TRT instance
+    yolov5_wrapper = YoLov5TRT(engine_file_path)
+    try:
+        print('batch size is', yolov5_wrapper.batch_size)
+
+        image_dir = "../images/"
+        image_path_batches = get_img_path_batches(
+            yolov5_wrapper.batch_size, image_dir)
+
+        for i in range(10):
+            # create a new thread to do warm_up
+            thread1 = warmUpThread(yolov5_wrapper)
+            thread1.start()
+            thread1.join()
+        for batch in image_path_batches:
+            # create a new thread to do inference
+            thread1 = inferThread(yolov5_wrapper, batch)
+            thread1.start()
+            thread1.join()
+    finally:
+        # destroy the instance
+        yolov5_wrapper.destroy()