From d6186dc3a439a528e5324f821562e8b4939c71c3 Mon Sep 17 00:00:00 2001 From: Stefan Date: Tue, 24 Sep 2024 16:01:45 +0200 Subject: [PATCH] configuration of RT-detr --- hubconf.py | 11 +++++----- rtdetrv2_pytorch/src/core/yaml_config.py | 5 +++++ .../src/nn/postprocessor/nms_postprocessor.py | 22 +++++++++++++------ .../src/zoo/rtdetr/rtdetr_postprocessor.py | 22 ++++++++++++++----- .../src/zoo/rtdetr/rtdetrv2_criterion.py | 13 +++++++---- 5 files changed, 50 insertions(+), 23 deletions(-) diff --git a/hubconf.py b/hubconf.py index 27ff7925..903a09f0 100644 --- a/hubconf.py +++ b/hubconf.py @@ -27,10 +27,10 @@ def _load_checkpoint(path: str, map_location='cpu'): return state -def _build_model(args, ): +def _build_model(args, num_classes=80): """main """ - cfg = YAMLConfig(args.config) + cfg = YAMLConfig(args.config, num_classes=num_classes) if args.resume: checkpoint = _load_checkpoint(args.resume, map_location='cpu') @@ -47,14 +47,13 @@ class Model(nn.Module): def __init__(self, ) -> None: super().__init__() self.model = cfg.model.deploy() - self.postprocessor = cfg.postprocessor.deploy() + self.postprocessor = cfg.postprocessor - def forward(self, images, orig_target_sizes): + def forward(self, images): outputs = self.model(images) - outputs = self.postprocessor(outputs, orig_target_sizes) return outputs - return Model() + return cfg.model CONFIG = { diff --git a/rtdetrv2_pytorch/src/core/yaml_config.py b/rtdetrv2_pytorch/src/core/yaml_config.py index 3b6a46ee..4d07b1bd 100644 --- a/rtdetrv2_pytorch/src/core/yaml_config.py +++ b/rtdetrv2_pytorch/src/core/yaml_config.py @@ -18,6 +18,11 @@ def __init__(self, cfg_path: str, **kwargs) -> None: super().__init__() cfg = load_config(cfg_path) + + if 'num_classes' in kwargs: + cfg['num_classes'] = kwargs['num_classes'] + print(f"Overriding num_classes in cfg with {kwargs['num_classes']}") + cfg = merge_dict(cfg, kwargs) self.yaml_cfg = copy.deepcopy(cfg) diff --git a/rtdetrv2_pytorch/src/nn/postprocessor/nms_postprocessor.py b/rtdetrv2_pytorch/src/nn/postprocessor/nms_postprocessor.py index b0945946..e846895c 100644 --- a/rtdetrv2_pytorch/src/nn/postprocessor/nms_postprocessor.py +++ b/rtdetrv2_pytorch/src/nn/postprocessor/nms_postprocessor.py @@ -19,10 +19,11 @@ class DetNMSPostProcessor(torch.nn.Module): def __init__(self, \ iou_threshold=0.7, - score_threshold=0.01, + score_threshold=0.1, keep_topk=300, box_fmt='cxcywh', - logit_fmt='sigmoid') -> None: + logit_fmt='sigmoid', + image_dimensions=(640,640)) -> None: super().__init__() self.iou_threshold = iou_threshold self.score_threshold = score_threshold @@ -31,11 +32,18 @@ def __init__(self, \ self.logit_fmt = logit_fmt.lower() self.logit_func = getattr(F, self.logit_fmt, None) self.deploy_mode = False + self.image_dimensions = image_dimensions - def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor): + def forward(self, outputs: Dict[str, Tensor], **kwargs): + iou_threshold = kwargs.get('iou_threshold', self.iou_threshold) + score_threshold = kwargs.get('score_threshold', self.score_threshold) + keep_topk = kwargs.get('keep_topk', self.keep_topk) logits, boxes = outputs['pred_logits'], outputs['pred_boxes'] + patch_size = torch.tensor(self.image_dimensions, dtype=torch.float32).to(boxes.device) + patch_size = patch_size.repeat(boxes.size(0), 1) # Repeat for batch size + pred_boxes = torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt='xyxy') - pred_boxes *= orig_target_sizes.repeat(1, 2).unsqueeze(1) + pred_boxes *= patch_size.repeat(1, 2).unsqueeze(1) values, pred_labels = torch.max(logits, dim=-1) @@ -55,13 +63,13 @@ def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor): results = [] for i in range(logits.shape[0]): - score_keep = pred_scores[i] > self.score_threshold + score_keep = pred_scores[i] > score_threshold pred_box = pred_boxes[i][score_keep] pred_label = pred_labels[i][score_keep] pred_score = pred_scores[i][score_keep] - keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, self.iou_threshold) - keep = keep[:self.keep_topk] + keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, iou_threshold) + keep = keep[:keep_topk] blob = { 'labels': pred_label[keep], diff --git a/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py b/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py index dcac0df2..a090a80b 100644 --- a/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py +++ b/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py @@ -32,7 +32,8 @@ def __init__( num_classes=80, use_focal_loss=True, num_top_queries=300, - remap_mscoco_category=False + remap_mscoco_category=False, + image_dimensions=(640, 640) ) -> None: super().__init__() self.use_focal_loss = use_focal_loss @@ -40,17 +41,18 @@ def __init__( self.num_classes = int(num_classes) self.remap_mscoco_category = remap_mscoco_category self.deploy_mode = False + self.image_dimensions = image_dimensions def extra_repr(self) -> str: return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}' - # def forward(self, outputs, orig_target_sizes): - def forward(self, outputs, orig_target_sizes: torch.Tensor): + def forward(self, outputs): logits, boxes = outputs['pred_logits'], outputs['pred_boxes'] - # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + patch_size = torch.tensor(self.image_dimensions, dtype=torch.float32).to(boxes.device) + patch_size = patch_size.repeat(boxes.size(0), 1) # Repeat for batch size bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy') - bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1) + bbox_pred *= patch_size.repeat(1, 2).unsqueeze(1) if self.use_focal_loss: scores = F.sigmoid(logits) @@ -71,7 +73,15 @@ def forward(self, outputs, orig_target_sizes: torch.Tensor): # TODO for onnx export if self.deploy_mode: - return labels, boxes, scores + results = [] + for lab, box, sco in zip(labels, boxes, scores): + result = { + "labels": lab, + "boxes": box, + "scores": sco + } + results.append(result) + return results # TODO if self.remap_mscoco_category: diff --git a/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetrv2_criterion.py b/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetrv2_criterion.py index c69e3684..788af711 100644 --- a/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetrv2_criterion.py +++ b/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetrv2_criterion.py @@ -12,6 +12,7 @@ from .box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou from ...misc.dist_utils import get_world_size, is_dist_available_and_initialized from ...core import register +from .matcher import HungarianMatcher @register() @@ -25,10 +26,14 @@ class RTDETRCriterionv2(nn.Module): __inject__ = ['matcher', ] def __init__(self, \ - matcher, - weight_dict, - losses, - alpha=0.2, + matcher = HungarianMatcher( + weight_dict={'cost_class': 2, 'cost_bbox': 5, 'cost_giou': 2}, + use_focal_loss=True, + alpha=0.25, + gamma=2.0), + weight_dict={'loss_vfl': 1, 'loss_bbox': 5, 'loss_giou': 2}, + losses = ['vfl', 'boxes'], + alpha=0.75, gamma=2.0, num_classes=80, boxes_weight_format=None,