Skip to content

Commit

Permalink
configuration of RT-detr
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanvdpalen committed Sep 24, 2024
1 parent 1497151 commit d6186dc
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 23 deletions.
11 changes: 5 additions & 6 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def _load_checkpoint(path: str, map_location='cpu'):
return state


def _build_model(args, ):
def _build_model(args, num_classes=80):
"""main
"""
cfg = YAMLConfig(args.config)
cfg = YAMLConfig(args.config, num_classes=num_classes)

if args.resume:
checkpoint = _load_checkpoint(args.resume, map_location='cpu')
Expand All @@ -47,14 +47,13 @@ class Model(nn.Module):
def __init__(self, ) -> None:
super().__init__()
self.model = cfg.model.deploy()
self.postprocessor = cfg.postprocessor.deploy()
self.postprocessor = cfg.postprocessor

def forward(self, images, orig_target_sizes):
def forward(self, images):
outputs = self.model(images)
outputs = self.postprocessor(outputs, orig_target_sizes)
return outputs

return Model()
return cfg.model


CONFIG = {
Expand Down
5 changes: 5 additions & 0 deletions rtdetrv2_pytorch/src/core/yaml_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ def __init__(self, cfg_path: str, **kwargs) -> None:
super().__init__()

cfg = load_config(cfg_path)

if 'num_classes' in kwargs:
cfg['num_classes'] = kwargs['num_classes']
print(f"Overriding num_classes in cfg with {kwargs['num_classes']}")

cfg = merge_dict(cfg, kwargs)

self.yaml_cfg = copy.deepcopy(cfg)
Expand Down
22 changes: 15 additions & 7 deletions rtdetrv2_pytorch/src/nn/postprocessor/nms_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
class DetNMSPostProcessor(torch.nn.Module):
def __init__(self, \
iou_threshold=0.7,
score_threshold=0.01,
score_threshold=0.1,
keep_topk=300,
box_fmt='cxcywh',
logit_fmt='sigmoid') -> None:
logit_fmt='sigmoid',
image_dimensions=(640,640)) -> None:
super().__init__()
self.iou_threshold = iou_threshold
self.score_threshold = score_threshold
Expand All @@ -31,11 +32,18 @@ def __init__(self, \
self.logit_fmt = logit_fmt.lower()
self.logit_func = getattr(F, self.logit_fmt, None)
self.deploy_mode = False
self.image_dimensions = image_dimensions

def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor):
def forward(self, outputs: Dict[str, Tensor], **kwargs):
iou_threshold = kwargs.get('iou_threshold', self.iou_threshold)
score_threshold = kwargs.get('score_threshold', self.score_threshold)
keep_topk = kwargs.get('keep_topk', self.keep_topk)
logits, boxes = outputs['pred_logits'], outputs['pred_boxes']
patch_size = torch.tensor(self.image_dimensions, dtype=torch.float32).to(boxes.device)
patch_size = patch_size.repeat(boxes.size(0), 1) # Repeat for batch size

pred_boxes = torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt='xyxy')
pred_boxes *= orig_target_sizes.repeat(1, 2).unsqueeze(1)
pred_boxes *= patch_size.repeat(1, 2).unsqueeze(1)

values, pred_labels = torch.max(logits, dim=-1)

Expand All @@ -55,13 +63,13 @@ def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor):

results = []
for i in range(logits.shape[0]):
score_keep = pred_scores[i] > self.score_threshold
score_keep = pred_scores[i] > score_threshold
pred_box = pred_boxes[i][score_keep]
pred_label = pred_labels[i][score_keep]
pred_score = pred_scores[i][score_keep]

keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, self.iou_threshold)
keep = keep[:self.keep_topk]
keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, iou_threshold)
keep = keep[:keep_topk]

blob = {
'labels': pred_label[keep],
Expand Down
22 changes: 16 additions & 6 deletions rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,27 @@ def __init__(
num_classes=80,
use_focal_loss=True,
num_top_queries=300,
remap_mscoco_category=False
remap_mscoco_category=False,
image_dimensions=(640, 640)
) -> None:
super().__init__()
self.use_focal_loss = use_focal_loss
self.num_top_queries = num_top_queries
self.num_classes = int(num_classes)
self.remap_mscoco_category = remap_mscoco_category
self.deploy_mode = False
self.image_dimensions = image_dimensions

def extra_repr(self) -> str:
return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}'

# def forward(self, outputs, orig_target_sizes):
def forward(self, outputs, orig_target_sizes: torch.Tensor):
def forward(self, outputs):
logits, boxes = outputs['pred_logits'], outputs['pred_boxes']
# orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
patch_size = torch.tensor(self.image_dimensions, dtype=torch.float32).to(boxes.device)
patch_size = patch_size.repeat(boxes.size(0), 1) # Repeat for batch size

bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)
bbox_pred *= patch_size.repeat(1, 2).unsqueeze(1)

if self.use_focal_loss:
scores = F.sigmoid(logits)
Expand All @@ -71,7 +73,15 @@ def forward(self, outputs, orig_target_sizes: torch.Tensor):

# TODO for onnx export
if self.deploy_mode:
return labels, boxes, scores
results = []
for lab, box, sco in zip(labels, boxes, scores):
result = {
"labels": lab,
"boxes": box,
"scores": sco
}
results.append(result)
return results

# TODO
if self.remap_mscoco_category:
Expand Down
13 changes: 9 additions & 4 deletions rtdetrv2_pytorch/src/zoo/rtdetr/rtdetrv2_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
from ...misc.dist_utils import get_world_size, is_dist_available_and_initialized
from ...core import register
from .matcher import HungarianMatcher


@register()
Expand All @@ -25,10 +26,14 @@ class RTDETRCriterionv2(nn.Module):
__inject__ = ['matcher', ]

def __init__(self, \
matcher,
weight_dict,
losses,
alpha=0.2,
matcher = HungarianMatcher(
weight_dict={'cost_class': 2, 'cost_bbox': 5, 'cost_giou': 2},
use_focal_loss=True,
alpha=0.25,
gamma=2.0),
weight_dict={'loss_vfl': 1, 'loss_bbox': 5, 'loss_giou': 2},
losses = ['vfl', 'boxes'],
alpha=0.75,
gamma=2.0,
num_classes=80,
boxes_weight_format=None,
Expand Down

0 comments on commit d6186dc

Please sign in to comment.