Skip to content

Commit

Permalink
Cleaning up
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrhm committed Jan 20, 2024
1 parent 656f357 commit 6867316
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 22 deletions.
14 changes: 6 additions & 8 deletions src/perception/object_detector/inference.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#include "inference.cuh"

#include <NvInfer.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeBase.h>
#include <NvOnnxParser.h>
#include <cstdio>
#include <cuda_runtime_api.h>
Expand All @@ -26,7 +24,7 @@ namespace mrover {

Inference::Inference(std::filesystem::path const& onnxModelPath) {
//Create the engine object from either the file or from onnx file
mEngine = std::unique_ptr<ICudaEngine, Destroy<ICudaEngine>>{createCudaEngine(onnxModelPath)};
mEngine = std::unique_ptr<ICudaEngine>{createCudaEngine(onnxModelPath)};
if (!mEngine) throw std::runtime_error("Failed to create CUDA engine");

mLogger.log(ILogger::Severity::kINFO, "Created CUDA Engine");
Expand All @@ -48,22 +46,22 @@ namespace mrover {
constexpr auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);

//Init logger
std::unique_ptr<IBuilder, Destroy<IBuilder>> builder{createInferBuilder(mLogger)};
std::unique_ptr<IBuilder> builder{createInferBuilder(mLogger)};
if (!builder) throw std::runtime_error("Failed to create Infer Builder");
mLogger.log(ILogger::Severity::kINFO, "Created Infer Builder");

//Init Network
std::unique_ptr<INetworkDefinition, Destroy<INetworkDefinition>> network{builder->createNetworkV2(explicitBatch)};
std::unique_ptr<INetworkDefinition> network{builder->createNetworkV2(explicitBatch)};
if (!network) throw std::runtime_error("Failed to create Network Definition");
mLogger.log(ILogger::Severity::kINFO, "Created Network Definition");

//Init the onnx file parser
std::unique_ptr<nvonnxparser::IParser, Destroy<nvonnxparser::IParser>> parser{nvonnxparser::createParser(*network, mLogger)};
std::unique_ptr<nvonnxparser::IParser> parser{nvonnxparser::createParser(*network, mLogger)};
if (!parser) throw std::runtime_error("Failed to create ONNX Parser");
mLogger.log(ILogger::Severity::kINFO, "Created ONNX Parser");

//Init the builder
std::unique_ptr<IBuilderConfig, Destroy<IBuilderConfig>> config{builder->createBuilderConfig()};
std::unique_ptr<IBuilderConfig> config{builder->createBuilderConfig()};
if (!config) throw std::runtime_error("Failed to create Builder Config");
mLogger.log(ILogger::Severity::kINFO, "Created Builder Config");

Expand All @@ -75,7 +73,7 @@ namespace mrover {
//Create runtime engine
IRuntime* runtime = createInferRuntime(mLogger);

// TODO: use relative to package
//Define the engine file location relative to the mrover package
std::filesystem::path packagePath = ros::package::getPath("mrover");
std::filesystem::path enginePath = packagePath / "data" / "tensorrt-engine-best.engine";

Expand Down
4 changes: 2 additions & 2 deletions src/perception/object_detector/inference.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ namespace mrover {
nvinfer1::Logger mLogger;

//Ptr to the engine
std::unique_ptr<ICudaEngine, nvinfer1::Destroy<ICudaEngine>> mEngine{};
std::unique_ptr<ICudaEngine> mEngine{};

//Ptr to the context
std::unique_ptr<IExecutionContext, nvinfer1::Destroy<IExecutionContext>> mContext{};
std::unique_ptr<IExecutionContext> mContext{};

//Input, output and reference tensors
cv::Mat mInputTensor;
Expand Down
2 changes: 0 additions & 2 deletions src/perception/object_detector/inference_wrapper.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "inference_wrapper.hpp"

#include <NvInferRuntimeBase.h>
#include <opencv4/opencv2/core/mat.hpp>

#include "inference.cuh"
Expand Down Expand Up @@ -30,7 +29,6 @@ namespace mrover {
}

cv::Mat InferenceWrapper::getOutputTensor() const {
//Get the output tensor from the inference object
return mInference->getOutputTensor();
}

Expand Down
2 changes: 1 addition & 1 deletion src/perception/object_detector/logger.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <NvInferRuntimeBase.h>
#include <NvInfer.h>

namespace nvinfer1 {

Expand Down
12 changes: 3 additions & 9 deletions src/perception/object_detector/object_detector.processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ namespace mrover {
int rows = output.rows;
int dimensions = output.cols;
bool yolov8 = false;

// yolov5 has an output of shape (batchSize, 25200, 85) (Num classes + box[x,y,w,h] + confidence[c])
// yolov8 has an output of shape (batchSize, 84, 8400) (Num classes + box[x,y,w,h])
if (dimensions > rows) // Check if the shape[2] is more than shape[1] (yolov8)
Expand Down Expand Up @@ -123,7 +124,7 @@ namespace mrover {
boxes.emplace_back(left, top, width, height);
}
} else { //YOLO V5
throw std::runtime_error("Something is wrong with interpretation");
throw std::runtime_error("Something is wrong Model with interpretation");
}

data += dimensions;
Expand Down Expand Up @@ -156,10 +157,7 @@ namespace mrover {
//Get the first detection to locate in 3D
Detection firstDetection = detections[0];

//Get the associated confidence with the object
float classConfidence = firstDetection.confidence;

//Get the associated box with the object
cv::Rect box = firstDetection.box;

//Fill out the msgData information to be published to the topic
Expand All @@ -170,10 +168,9 @@ namespace mrover {
// msgData.width = static_cast<float>(box.width);
// msgData.height = static_cast<float>(box.height);

//Calculate the center of the box
std::pair center(box.x + box.width/2, box.y + box.height/2);

//Get the object's position in 3D from the point cloud
//Get the object's position in 3D from the point cloud and run this statement if the optional has a value
if (std::optional<SE3> objectLocation = getObjectInCamFromPixel(msg, center.first * static_cast<float>(msg->width) / sizedImage.cols, center.second * static_cast<float>(msg->height) / sizedImage.rows, box.width, box.height); objectLocation) {
try{
//Publish Immediate
Expand Down Expand Up @@ -233,9 +230,6 @@ namespace mrover {

}

//DEBUG TODO REMOVE
ROS_INFO_STREAM(std::format("Hit count: {}", mHitCount));

if (mDebugImgPub.getNumSubscribers() > 0 || true) {
// Init sensor msg image
sensor_msgs::Image newDebugImageMessage;//I chose regular msg not ptr so it can be used outside of this process
Expand Down

0 comments on commit 6867316

Please sign in to comment.