TensorRT10 support for YOLOv10
mpj1234 committed Jul 24, 2024
1 parent ca507e8 commit 070c81a
Showing 27 changed files with 4,095 additions and 0 deletions.
47 changes: 47 additions & 0 deletions yolov10/CMakeLists.txt
@@ -0,0 +1,47 @@
cmake_minimum_required(VERSION 3.10)

project(yolov10)

add_definitions(-std=c++11)
add_definitions(-DAPI_EXPORTS)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)

set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
enable_language(CUDA)

include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_SOURCE_DIR}/plugin)

# include and link dirs of cuda and tensorrt, you need to adapt them if yours are different
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
message("embed_platform on")
include_directories(/usr/local/cuda/targets/aarch64-linux/include)
link_directories(/usr/local/cuda/targets/aarch64-linux/lib)
else()
message("embed_platform off")

# cuda
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)

# tensorrt
include_directories(/workspace/shared/TensorRT-10.2.0.19/include/)
link_directories(/workspace/shared/TensorRT-10.2.0.19/lib/)

# include_directories(/home/lindsay/TensorRT-7.2.3.4/include)
# link_directories(/home/lindsay/TensorRT-7.2.3.4/lib)
endif()

add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/plugin/yololayer.cu)
target_link_libraries(myplugins nvinfer cudart)

find_package(OpenCV)
include_directories(${OpenCV_INCLUDE_DIRS})

file(GLOB_RECURSE SRCS ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/*.cu)
add_executable(yolov10_det ${PROJECT_SOURCE_DIR}/yolov10_det.cpp ${SRCS})
target_link_libraries(yolov10_det nvinfer)
target_link_libraries(yolov10_det cudart)
target_link_libraries(yolov10_det myplugins)
target_link_libraries(yolov10_det ${OpenCV_LIBS})
71 changes: 71 additions & 0 deletions yolov10/README.md
@@ -0,0 +1,71 @@
## Introduction

This YOLOv10 implementation supports TensorRT-10.

## Environment

CUDA: 11.8
CUDNN: 8.9.1.23
TensorRT: TensorRT-10.2.0.19

## Support

* [x] YOLOv10-det support FP32/FP16/INT8 and Python/C++ API

## Config

* Choose the YOLOv10 sub-model n/s/m/b/l/x from the command-line arguments.
* For other configs, please check [src/config.h](src/config.h)

## Build and Run

1. Generate a .wts file from the PyTorch .pt weights, or download a .wts from the model zoo

```shell
git clone https://github.com/THU-MIG/yolov10.git
cd yolov10/
wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10n.pt

git clone -b trt10 https://github.com/wang-xinyu/tensorrtx.git
cp [PATH-TO-TENSORRTX]/yolov10/gen_wts.py .

python gen_wts.py -w yolov10n.pt -o yolov10n.wts
# A file 'yolov10n.wts' will be generated.
```

2. build tensorrtx/yolov10 and run

#### Detection

```shell
cd [PATH-TO-TENSORRTX]/yolov10
# Update kNumClass in src/config.h if your model is trained on a custom dataset
mkdir build
cd build
cp [PATH-TO-yolov10]/yolov10n.wts .
cmake ..
make

# Build and serialize TensorRT engine
./yolov10_det -s yolov10n.wts yolov10n.engine [n/s/m/b/l/x]

# Run inference
./yolov10_det -d yolov10n.engine ../images
# The results are displayed in the console
```

3. Optionally, load and run the TensorRT model in Python
```shell
# Install python-tensorrt, pycuda, etc.
# Make sure yolov10n.engine has been built (see step 2)
python yolov10_det_trt.py ./build/yolov10n.engine ./build/libmyplugins.so
```

## INT8 Quantization
1. Prepare calibration images. You can randomly select about 1000 images from your training set. For COCO, you can also download my calibration images `coco_calib` from [GoogleDrive](https://drive.google.com/drive/folders/1s7jE9DtOngZMzJC1uL307J2MiaGwdRSI?usp=sharing) or [BaiduPan](https://pan.baidu.com/s/1GOm_-JobpyLMAqZWCDUhKg) pwd: a9wh
2. Unzip it in yolov10/build
3. Set the macro `USE_INT8` in src/config.h and make again
4. Serialize the model and test

## More Information
See the README on the [home page](https://github.com/wang-xinyu/tensorrtx).
56 changes: 56 additions & 0 deletions yolov10/gen_wts.py
@@ -0,0 +1,56 @@
# -*- coding: UTF-8 -*-
"""
@Author: mpj
@Date : 2024/7/22 9:17 PM
@version V1.0
"""
import sys # noqa: F401
import argparse
import os
import struct
import torch


def parse_args():
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
parser.add_argument('-w', '--weights', default='./weights/yolov10n.pt',
help='Input weights (.pt) file path (required)')
parser.add_argument(
'-o', '--output', help='Output (.wts) file path (optional)')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit('Invalid input file')
if not args.output:
args.output = os.path.splitext(args.weights)[0] + '.wts'
elif os.path.isdir(args.output):
args.output = os.path.join(
args.output,
os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
return args.weights, args.output


pt_file, wts_file = parse_args()

# Load model
print(f'Loading {pt_file}')

# Initialize
device = 'cpu'

# Load model
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
# If training was interrupted before finishing, load the 'ema' weights instead:
# model = torch.load(pt_file, map_location=device)['ema'].float() # load to FP32

model.to(device).eval()

with open(wts_file, 'w') as f:
f.write('{}\n'.format(len(model.state_dict().keys())))
for k, v in model.state_dict().items():
vr = v.reshape(-1).cpu().numpy()
f.write('{} {} '.format(k, len(vr)))
for vv in vr:
f.write(' ')
f.write(struct.pack('>f', float(vv)).hex())
f.write('\n')
print(f'success {wts_file}!!!')
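
The `.wts` file written above is a plain-text format: the first line holds the tensor count, and each following line holds a tensor name, its element count, and every float32 value encoded as eight big-endian hex characters (matching `struct.pack('>f', v).hex()`). As a minimal, illustrative sketch of reading it back on the C++ side, something like the following could be used; `readWtsSketch` is a hypothetical helper, and the repo's actual parser is the `loadWeights()` declared in include/block.h.

```cpp
// Minimal sketch only: parse the .wts text produced by gen_wts.py.
// readWtsSketch is a hypothetical name; the repo's real parser is
// loadWeights() declared in include/block.h.
#include <cstdint>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>

std::map<std::string, std::vector<float>> readWtsSketch(const std::string& path) {
    std::ifstream in(path);
    int32_t count = 0;
    in >> count;  // first line: number of tensors
    std::map<std::string, std::vector<float>> weights;
    for (int32_t i = 0; i < count; ++i) {
        std::string name;
        uint32_t size = 0;
        in >> name >> std::dec >> size;   // "name element_count"
        std::vector<float> values(size);
        for (uint32_t j = 0; j < size; ++j) {
            uint32_t bits = 0;
            in >> std::hex >> bits;       // 8 hex chars = one float32 bit pattern
            float f;
            std::memcpy(&f, &bits, sizeof(f));
            values[j] = f;
        }
        weights[name] = values;
    }
    return weights;
}
```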
Binary file added yolov10/images/bus.jpg
Binary file added yolov10/images/cat.jpg
Binary file added yolov10/images/dog.jpg
Binary file added yolov10/images/zidane.jpg
44 changes: 44 additions & 0 deletions yolov10/include/block.h
@@ -0,0 +1,44 @@
#pragma once

#include <map>
#include <string>
#include <vector>
#include "NvInfer.h"

std::map<std::string, nvinfer1::Weights> loadWeights(const std::string file);

nvinfer1::IScaleLayer* addBatchNorm2d(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input,
std::string lname, float eps);

nvinfer1::IElementWiseLayer* convBnSiLU(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input,
int ch, int k, int s, std::string lname, int g = 1);

nvinfer1::IElementWiseLayer* C2F(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, int c1,
int c2, int n, bool shortcut, float e, std::string lname);

nvinfer1::IElementWiseLayer* C2(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights>& weightMap, nvinfer1::ITensor& input, int c1,
int c2, int n, bool shortcut, float e, std::string lname);

nvinfer1::IElementWiseLayer* SPPF(nvinfer1::INetworkDefinition* network,
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor& input, int c1,
int c2, int k, std::string lname);

nvinfer1::IShuffleLayer* DFL(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap,
nvinfer1::ITensor& input, int ch, int grid, int k, int s, int p, std::string lname);

nvinfer1::IPluginV2Layer* addYoLoLayer(nvinfer1::INetworkDefinition* network, std::vector<nvinfer1::ILayer*> dets,
const int* px_arry, int px_arry_num);

nvinfer1::ILayer* SCDown(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap,
nvinfer1::ITensor& input, int ch, int k, int s, std::string lname);

nvinfer1::ILayer* PSA(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap,
nvinfer1::ITensor& input, int ch, std::string lname);

nvinfer1::ILayer* C2fCIB(nvinfer1::INetworkDefinition* network, std::map<std::string, nvinfer1::Weights> weightMap,
nvinfer1::ITensor& input, int c1, int c2, int n, bool shortcut, bool lk, float e,
std::string lname);
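
To illustrate how these helpers map onto the TensorRT network-definition API, here is a hedged sketch of a Conv + BatchNorm + SiLU block. It is not the repo's src/block.cpp implementation; the weight-map key names, the eps value, and the k/2 padding rule are assumptions made only for this example.

```cpp
// Hedged sketch (not the repo's src/block.cpp): a Conv2d -> BatchNorm -> SiLU
// block built with the TensorRT network-definition API. Weight-map keys,
// eps, and the k/2 padding rule below are illustrative assumptions.
#include <map>
#include <string>
#include "NvInfer.h"
#include "block.h"  // for addBatchNorm2d()

nvinfer1::IElementWiseLayer* convBnSiLUSketch(nvinfer1::INetworkDefinition* network,
        std::map<std::string, nvinfer1::Weights>& weightMap, nvinfer1::ITensor& input,
        int ch, int k, int s, const std::string& lname, int g = 1) {
    nvinfer1::Weights emptyBias{nvinfer1::DataType::kFLOAT, nullptr, 0};
    // Convolution without bias; the bias term is folded into the batch norm.
    nvinfer1::IConvolutionLayer* conv = network->addConvolutionNd(
            input, ch, nvinfer1::DimsHW{k, k}, weightMap[lname + ".conv.weight"], emptyBias);
    conv->setStrideNd(nvinfer1::DimsHW{s, s});
    conv->setPaddingNd(nvinfer1::DimsHW{k / 2, k / 2});  // "same" padding for odd k
    conv->setNbGroups(g);

    // Batch norm folded into an IScaleLayer via the helper declared above.
    nvinfer1::IScaleLayer* bn =
            addBatchNorm2d(network, weightMap, *conv->getOutput(0), lname + ".bn", 1e-3f);

    // SiLU activation: x * sigmoid(x)
    nvinfer1::IActivationLayer* sig =
            network->addActivation(*bn->getOutput(0), nvinfer1::ActivationType::kSIGMOID);
    return network->addElementWise(*bn->getOutput(0), *sig->getOutput(0),
                                   nvinfer1::ElementWiseOperation::kPROD);
}
```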
39 changes: 39 additions & 0 deletions yolov10/include/calibrator.h
@@ -0,0 +1,39 @@
#ifndef ENTROPY_CALIBRATOR_H
#define ENTROPY_CALIBRATOR_H

#include <NvInfer.h>
#include <string>
#include <vector>
#include "macros.h"

//! \class Int8EntropyCalibrator2
//!
//! \brief Implements Entropy calibrator 2.
//! CalibrationAlgoType is kENTROPY_CALIBRATION_2.
//!
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
public:
Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name,
const char* input_blob_name, bool read_cache = true);
virtual ~Int8EntropyCalibrator2();
int getBatchSize() const TRT_NOEXCEPT override;
bool getBatch(void* bindings[], const char* names[], int nbBindings) TRT_NOEXCEPT override;
const void* readCalibrationCache(size_t& length) TRT_NOEXCEPT override;
void writeCalibrationCache(const void* cache, size_t length) TRT_NOEXCEPT override;

private:
int batchsize_;
int input_w_;
int input_h_;
int img_idx_;
std::string img_dir_;
std::vector<std::string> img_files_;
size_t input_count_;
std::string calib_table_name_;
const char* input_blob_name_;
bool read_cache_;
void* device_input_;
std::vector<char> calib_cache_;
};

#endif // ENTROPY_CALIBRATOR_H
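
For context, the sketch below shows one way a calibrator like this is typically attached to the builder config when `USE_INT8` is defined (see include/config.h below). The cache-file name and the exact call site are assumptions, not code from this commit.

```cpp
// Illustrative only: wiring Int8EntropyCalibrator2 into the builder config.
// "int8calib.table" is an assumed cache-file name; constants come from config.h.
#include <NvInfer.h>
#include "calibrator.h"
#include "config.h"

void enableInt8Sketch(nvinfer1::IBuilderConfig* config) {
    config->setFlag(nvinfer1::BuilderFlag::kINT8);
    // Streams preprocessed images from kInputQuantizationFolder and caches the
    // computed scales so later builds can skip calibration. The static local
    // keeps the calibrator alive while the builder runs after this call returns.
    static Int8EntropyCalibrator2 calibrator(kBatchSize, kInputW, kInputH,
            kInputQuantizationFolder, "int8calib.table", kInputTensorName);
    config->setInt8Calibrator(&calibrator);
}
```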
16 changes: 16 additions & 0 deletions yolov10/include/config.h
@@ -0,0 +1,16 @@
//#define USE_FP32
#define USE_FP16
// #define USE_INT8

const static char* kInputTensorName = "images";
const static char* kOutputTensorName = "output";
const static int kNumClass = 80;
const static int kBatchSize = 1;
const static int kGpuId = 0;
const static int kInputH = 640;
const static int kInputW = 640;
const static float kConfThresh = 0.5f;
const static int kMaxInputImageSize = 3000 * 3000;
const static int kMaxNumOutputBbox = 1000;
//Quantization input image folder path
const static char* kInputQuantizationFolder = "./coco_calib";
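
A minimal sketch of how the precision macros above are usually consumed when building the engine; the repo's real logic lives in the source files under src/ and may differ.

```cpp
// Sketch only: gating builder flags on the precision macros from config.h.
#include <NvInfer.h>
#include "config.h"

void setPrecisionFlagsSketch(nvinfer1::IBuilderConfig* config) {
#if defined(USE_FP16)
    config->setFlag(nvinfer1::BuilderFlag::kFP16);   // FP16 kernels where supported
#elif defined(USE_INT8)
    config->setFlag(nvinfer1::BuilderFlag::kINT8);   // requires an INT8 calibrator
#endif
    // USE_FP32 needs no flag: FP32 is TensorRT's default precision.
}
```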
17 changes: 17 additions & 0 deletions yolov10/include/cuda_utils.h
@@ -0,0 +1,17 @@
#ifndef TRTX_CUDA_UTILS_H_
#define TRTX_CUDA_UTILS_H_

#include <cuda_runtime_api.h>
#include <cassert>   // assert() used by CUDA_CHECK
#include <iostream>  // std::cerr used by CUDA_CHECK

#ifndef CUDA_CHECK
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#endif // CUDA_CHECK

#endif // TRTX_CUDA_UTILS_H_
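
A short usage example for the `CUDA_CHECK` macro (illustrative only, not part of this commit):

```cpp
// Example use of CUDA_CHECK: abort with file/line info if a CUDA call fails.
#include "cuda_utils.h"

int main() {
    float* device_buffer = nullptr;
    CUDA_CHECK(cudaMalloc(&device_buffer, 640 * 640 * 3 * sizeof(float)));
    CUDA_CHECK(cudaMemset(device_buffer, 0, 640 * 640 * 3 * sizeof(float)));
    CUDA_CHECK(cudaFree(device_buffer));
    return 0;
}
```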