From 68673167b6d99d4297581c55970c289c4685c745 Mon Sep 17 00:00:00 2001 From: John Date: Sat, 20 Jan 2024 12:56:01 -0500 Subject: [PATCH] Cleaning up --- src/perception/object_detector/inference.cu | 14 ++++++-------- src/perception/object_detector/inference.cuh | 4 ++-- .../object_detector/inference_wrapper.cu | 2 -- src/perception/object_detector/logger.cuh | 2 +- .../object_detector/object_detector.processing.cpp | 12 +++--------- 5 files changed, 12 insertions(+), 22 deletions(-) diff --git a/src/perception/object_detector/inference.cu b/src/perception/object_detector/inference.cu index 82ce6f622..2e9a40a60 100644 --- a/src/perception/object_detector/inference.cu +++ b/src/perception/object_detector/inference.cu @@ -1,8 +1,6 @@ #include "inference.cuh" #include -#include -#include #include #include #include @@ -26,7 +24,7 @@ namespace mrover { Inference::Inference(std::filesystem::path const& onnxModelPath) { //Create the engine object from either the file or from onnx file - mEngine = std::unique_ptr>{createCudaEngine(onnxModelPath)}; + mEngine = std::unique_ptr{createCudaEngine(onnxModelPath)}; if (!mEngine) throw std::runtime_error("Failed to create CUDA engine"); mLogger.log(ILogger::Severity::kINFO, "Created CUDA Engine"); @@ -48,22 +46,22 @@ namespace mrover { constexpr auto explicitBatch = 1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); //Init logger - std::unique_ptr> builder{createInferBuilder(mLogger)}; + std::unique_ptr builder{createInferBuilder(mLogger)}; if (!builder) throw std::runtime_error("Failed to create Infer Builder"); mLogger.log(ILogger::Severity::kINFO, "Created Infer Builder"); //Init Network - std::unique_ptr> network{builder->createNetworkV2(explicitBatch)}; + std::unique_ptr network{builder->createNetworkV2(explicitBatch)}; if (!network) throw std::runtime_error("Failed to create Network Definition"); mLogger.log(ILogger::Severity::kINFO, "Created Network Definition"); //Init the onnx file parser - std::unique_ptr> parser{nvonnxparser::createParser(*network, mLogger)}; + std::unique_ptr parser{nvonnxparser::createParser(*network, mLogger)}; if (!parser) throw std::runtime_error("Failed to create ONNX Parser"); mLogger.log(ILogger::Severity::kINFO, "Created ONNX Parser"); //Init the builder - std::unique_ptr> config{builder->createBuilderConfig()}; + std::unique_ptr config{builder->createBuilderConfig()}; if (!config) throw std::runtime_error("Failed to create Builder Config"); mLogger.log(ILogger::Severity::kINFO, "Created Builder Config"); @@ -75,7 +73,7 @@ namespace mrover { //Create runtime engine IRuntime* runtime = createInferRuntime(mLogger); - // TODO: use relative to package + //Define the engine file location relative to the mrover package std::filesystem::path packagePath = ros::package::getPath("mrover"); std::filesystem::path enginePath = packagePath / "data" / "tensorrt-engine-best.engine"; diff --git a/src/perception/object_detector/inference.cuh b/src/perception/object_detector/inference.cuh index 11df55c64..f7c113488 100644 --- a/src/perception/object_detector/inference.cuh +++ b/src/perception/object_detector/inference.cuh @@ -17,10 +17,10 @@ namespace mrover { nvinfer1::Logger mLogger; //Ptr to the engine - std::unique_ptr> mEngine{}; + std::unique_ptr mEngine{}; //Ptr to the context - std::unique_ptr> mContext{}; + std::unique_ptr mContext{}; //Input, output and reference tensors cv::Mat mInputTensor; diff --git a/src/perception/object_detector/inference_wrapper.cu b/src/perception/object_detector/inference_wrapper.cu index 271a51252..71cbb91b2 100644 --- a/src/perception/object_detector/inference_wrapper.cu +++ b/src/perception/object_detector/inference_wrapper.cu @@ -1,6 +1,5 @@ #include "inference_wrapper.hpp" -#include #include #include "inference.cuh" @@ -30,7 +29,6 @@ namespace mrover { } cv::Mat InferenceWrapper::getOutputTensor() const { - //Get the output tensor from the inference object return mInference->getOutputTensor(); } diff --git a/src/perception/object_detector/logger.cuh b/src/perception/object_detector/logger.cuh index 29db705c0..02bf62d1d 100644 --- a/src/perception/object_detector/logger.cuh +++ b/src/perception/object_detector/logger.cuh @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace nvinfer1 { diff --git a/src/perception/object_detector/object_detector.processing.cpp b/src/perception/object_detector/object_detector.processing.cpp index 3b253133c..64bd2ad7e 100644 --- a/src/perception/object_detector/object_detector.processing.cpp +++ b/src/perception/object_detector/object_detector.processing.cpp @@ -53,6 +53,7 @@ namespace mrover { int rows = output.rows; int dimensions = output.cols; bool yolov8 = false; + // yolov5 has an output of shape (batchSize, 25200, 85) (Num classes + box[x,y,w,h] + confidence[c]) // yolov8 has an output of shape (batchSize, 84, 8400) (Num classes + box[x,y,w,h]) if (dimensions > rows) // Check if the shape[2] is more than shape[1] (yolov8) @@ -123,7 +124,7 @@ namespace mrover { boxes.emplace_back(left, top, width, height); } } else { //YOLO V5 - throw std::runtime_error("Something is wrong with interpretation"); + throw std::runtime_error("Something is wrong Model with interpretation"); } data += dimensions; @@ -156,10 +157,7 @@ namespace mrover { //Get the first detection to locate in 3D Detection firstDetection = detections[0]; - //Get the associated confidence with the object float classConfidence = firstDetection.confidence; - - //Get the associated box with the object cv::Rect box = firstDetection.box; //Fill out the msgData information to be published to the topic @@ -170,10 +168,9 @@ namespace mrover { // msgData.width = static_cast(box.width); // msgData.height = static_cast(box.height); - //Calculate the center of the box std::pair center(box.x + box.width/2, box.y + box.height/2); - //Get the object's position in 3D from the point cloud + //Get the object's position in 3D from the point cloud and run this statement if the optional has a value if (std::optional objectLocation = getObjectInCamFromPixel(msg, center.first * static_cast(msg->width) / sizedImage.cols, center.second * static_cast(msg->height) / sizedImage.rows, box.width, box.height); objectLocation) { try{ //Publish Immediate @@ -233,9 +230,6 @@ namespace mrover { } - //DEBUG TODO REMOVE - ROS_INFO_STREAM(std::format("Hit count: {}", mHitCount)); - if (mDebugImgPub.getNumSubscribers() > 0 || true) { // Init sensor msg image sensor_msgs::Image newDebugImageMessage;//I chose regular msg not ptr so it can be used outside of this process