Skip to content

Commit

Permalink
stripped post processor for torch exporting
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanvdpalen committed Oct 2, 2024
1 parent 12146d6 commit 83dde05
Showing 1 changed file with 40 additions and 31 deletions.
71 changes: 40 additions & 31 deletions rtdetrv2_pytorch/src/nn/postprocessor/nms_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.distributed
import torchvision
from torch import Tensor
from typing import List, Dict

from ...core import register

Expand Down Expand Up @@ -34,10 +35,10 @@ def __init__(self, \
self.deploy_mode = False
self.image_dimensions = image_dimensions

def forward(self, outputs: Dict[str, Tensor], **kwargs):
iou_threshold = kwargs.get('iou_threshold', self.iou_threshold)
score_threshold = kwargs.get('score_threshold', self.score_threshold)
keep_topk = kwargs.get('keep_topk', self.keep_topk)
def forward(self, outputs: Dict[str, Tensor], **kwargs) -> List[Dict[str, Tensor]]:
#iou_threshold = kwargs.get('iou_threshold', self.iou_threshold)
#score_threshold = kwargs.get('score_threshold', self.score_threshold)
#keep_topk = kwargs.get('keep_topk', self.keep_topk)
logits, boxes = outputs['pred_logits'], outputs['pred_boxes']
patch_size = torch.tensor(self.image_dimensions, dtype=torch.float32).to(boxes.device)
patch_size = patch_size.repeat(boxes.size(0), 1) # Repeat for batch size
Expand All @@ -53,33 +54,41 @@ def forward(self, outputs: Dict[str, Tensor], **kwargs):
pred_scores = values

# TODO for onnx export
if self.deploy_mode:
blobs = {
'pred_labels': pred_labels,
'pred_boxes': pred_boxes,
'pred_scores': pred_scores
}
return blobs

results = []
for i in range(logits.shape[0]):
score_keep = pred_scores[i] > score_threshold
pred_box = pred_boxes[i][score_keep]
pred_label = pred_labels[i][score_keep]
pred_score = pred_scores[i][score_keep]

keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, iou_threshold)
keep = keep[:keep_topk]

blob = {
'labels': pred_label[keep],
'boxes': pred_box[keep],
'scores': pred_score[keep],
}

results.append(blob)

return results
blobs = {
'labels': pred_labels,
'boxes': pred_boxes,
'scores': pred_scores
}
return blobs

# results = []
# for i in range(logits.shape[0]):
# score_keep = pred_scores[i] > score_threshold
# pred_box = pred_boxes[i][score_keep]
# pred_label = pred_labels[i][score_keep]
# pred_score = pred_scores[i][score_keep]

# keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, iou_threshold)
# keep = keep[:keep_topk]

# blob = {
# 'labels': pred_label[keep],
# 'boxes': pred_box[keep],
# 'scores': pred_score[keep],
# }

# results.append(blob)

# # Add debug logs
# print("results")
# print(type(results))
# for idx, result in enumerate(results):
# print(f"Type of results[{idx}]:", type(result))
# print(f"Keys in results[{idx}]:", result.keys())
# for key in result:
# print(f"Type of results[{idx}]['{key}']:", type(result[key]))

# return blobs

def deploy(self, ):
self.eval()
Expand Down

0 comments on commit 83dde05

Please sign in to comment.