Skip to content

Commit

Permalink
fix pre-processing
Browse files Browse the repository at this point in the history
Signed-off-by: jagadeesh <[email protected]>
  • Loading branch information
jagadeesh committed Aug 9, 2023
1 parent a781796 commit ec7e7f6
Showing 1 changed file with 25 additions and 26 deletions.
51 changes: 25 additions & 26 deletions cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,45 +71,44 @@ std::vector<torch::jit::IValue> ResnetHandler::Preprocess(

cv::Mat image = cv::imdecode(data_it->second, cv::IMREAD_COLOR);

cv::cvtColor(image, image, cv::COLOR_BGR2RGB);

// Check if the image was successfully decoded
if (image.empty()) {
std::cerr << "Failed to decode the image." << std::endl;
}

// Resize
const int newWidth = 256, newHeight = 256;
cv::Mat resizedImage;
cv::resize(image, resizedImage, cv::Size(newWidth, newHeight));

// Crop image
const int cropSize = 224;
const int offsetW = (resizedImage.cols - cropSize) / 2;
const int offsetH = (resizedImage.rows - cropSize) / 2;
const int rows = image.rows;
const int cols = image.cols;

const int cropSize = std::min(rows, cols);
const int offsetW = (cols - cropSize) / 2;
const int offsetH = (rows - cropSize) / 2;

const cv::Rect roi(offsetW, offsetH, cropSize, cropSize);
cv::Mat croppedImage = resizedImage(roi).clone();
image = image(roi);

// Convert the OpenCV image to a torch tensor
// Drift in cropped image
// Vision Crop: 114, 118, 115, 102, 106, 97
// OpenCV Crop: 113, 118, 114, 100, 106, 97
torch::TensorOptions options(torch::kByte);
torch::Tensor tensorImage = torch::from_blob(
croppedImage.data,
{croppedImage.rows, croppedImage.cols, croppedImage.channels()},
options);
// Resize
cv::resize(image, image, cv::Size(224, 224));

// Convert BGR to RGB format
cv::cvtColor(image, image, cv::COLOR_BGR2RGB);

image.convertTo(image, CV_32FC3, 1 / 255.0);

// Convert the OpenCV image to a torch tensor
torch::Tensor tensorImage = torch::from_blob(image.data, {image.rows, image.cols, 3}, c10::kFloat);
tensorImage = tensorImage.permute({2, 0, 1});
tensorImage = tensorImage.to(torch::kFloat32) / 255.0;
tensorImage.unsqueeze_(0);

// Normalize
torch::Tensor normalizedTensorImage =
torch::data::transforms::Normalize<>(
{0.485, 0.456, 0.406}, {0.229, 0.224, 0.225})(tensorImage);
normalizedTensorImage.clone();
batch_tensors.emplace_back(normalizedTensorImage.to(*device));
std::vector<double> norm_mean = {0.485, 0.456, 0.406};
std::vector<double> norm_std = {0.229, 0.224, 0.225};

tensorImage =
torch::data::transforms::Normalize<>(norm_mean, norm_std)(tensorImage);

tensorImage.clone();
batch_tensors.emplace_back(tensorImage.to(*device));
idx_to_req_id.second[idx++] = request.request_id;
} else if (dtype_it->second == "List") {
// case3: the image is a list
Expand Down

0 comments on commit ec7e7f6

Please sign in to comment.