diff --git a/efficientdet/loss.py b/efficientdet/loss.py
index 4da864f78..600c20dcc 100644
--- a/efficientdet/loss.py
+++ b/efficientdet/loss.py
@@ -25,8 +25,20 @@ def calc_iou(a, b):
 
 
 class FocalLoss(nn.Module):
-    def __init__(self):
+    def __init__(self, matched_threshold=0.5, unmatched_threshold=0.4, negatives_lower_than_unmatched=True):
+        """Configure the IoU thresholds used to label anchors.
+
+        Anchors whose best IoU is >= matched_threshold are positives. If
+        negatives_lower_than_unmatched is True, anchors below
+        unmatched_threshold are negatives; otherwise anchors between the
+        two thresholds are negatives. Every other anchor is ignored.
+        """
         super(FocalLoss, self).__init__()
+        if unmatched_threshold > matched_threshold:
+            raise ValueError('unmatched_threshold must be <= matched_threshold')
+        self.matched_threshold = matched_threshold
+        self.unmatched_threshold = unmatched_threshold
+        self.negatives_lower_than_unmatched = negatives_lower_than_unmatched
 
     def forward(self, classifications, regressions, anchors, annotations, **kwargs):
         alpha = 0.25
@@ -69,9 +81,7 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
                 else:
                     regression_losses.append(torch.tensor(0).to(dtype))
 
-                # classification_losses.append(cls_loss.sum())
-                classification_losses.append(cls_loss.mean())
-
+                classification_losses.append(cls_loss.sum())
                 continue
 
             IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
@@ -79,16 +89,22 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
             IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
             IoU_max, IoU_argmax = torch.max(IoU, dim=1)
             # compute the loss for classification
-            targets = torch.ones_like(classification) * -1
+            targets = torch.ones_like(classification) * -1  # init by ignoring all targets
             if torch.cuda.is_available():
                 targets = targets.cuda()
 
-            targets[torch.lt(IoU_max, 0.4), :] = 0
-
-            positive_indices = torch.ge(IoU_max, 0.5)
+            if self.negatives_lower_than_unmatched:
+                # negative matches are the ones below the unmatched_threshold
+                targets[torch.lt(IoU_max, self.unmatched_threshold), :] = 0
+            else:
+                # negative matches are in between the matched and unmatched
+                targets[torch.lt(IoU_max, self.matched_threshold) & torch.ge(IoU_max, self.unmatched_threshold), :] = 0
 
-            num_positive_anchors = positive_indices.sum()
+            # Find all positives in a batch for normalization
+            positive_indices = torch.ge(IoU_max, self.matched_threshold)
+            # Avoid zero sum of num_positives, which would lead to inf loss during training
+            num_positive_anchors = positive_indices.sum() + 1
 
             assigned_annotations = bbox_annotation[IoU_argmax, :]
 
             targets[positive_indices, :] = 0
@@ -111,8 +127,7 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
                 zeros = zeros.cuda()
             cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)
 
-            # classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
-            classification_losses.append(cls_loss.mean() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
+            classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
 
             if positive_indices.sum() > 0:
                 assigned_annotations = assigned_annotations[positive_indices, :]
diff --git a/train.py b/train.py
index a25fe759d..126583da9 100644
--- a/train.py
+++ b/train.py
@@ -37,7 +37,7 @@ def get_args():
     parser.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet')
     parser.add_argument('-n', '--num_workers', type=int, default=12, help='num_workers of dataloader')
     parser.add_argument('--batch_size', type=int, default=12, help='The number of images per batch among all devices')
-    parser.add_argument('--head_only', type=bool, default=False,
+    parser.add_argument('--head_only', type=boolean_string, default=False,
                         help='whether finetunes only the regressor and the classifier, '
                              'useful in early stage convergence or small/easy dataset')
     parser.add_argument('--lr', type=float, default=1e-4)
@@ -56,17 +56,26 @@ def get_args():
     parser.add_argument('-w', '--load_weights', type=str, default=None,
                         help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint')
     parser.add_argument('--saved_path', type=str, default='logs/')
-    parser.add_argument('--debug', type=bool, default=False, help='whether visualize the predicted boxes of trainging, '
+    parser.add_argument('--debug', type=boolean_string, default=False, help='whether visualize the predicted boxes of trainging, '
                                                                   'the output images will be in test/')
+    parser.add_argument('--matched_threshold', type=float, default=.5, help='Threshold for positive matches.')
+    parser.add_argument('--unmatched_threshold', type=float, default=.4, help='Threshold for negative matches.')
 
     args = parser.parse_args()
     return args
 
 
+def boolean_string(s):
+    """Parse a literal 'True' or 'False' string for argparse; reject anything else."""
+    if s not in {'False', 'True'}:
+        raise ValueError('Not a valid boolean string')
+    return s == 'True'
+
+
 class ModelWithLoss(nn.Module):
-    def __init__(self, model, debug=False):
+    def __init__(self, model, matched_threshold=0.5, unmatched_threshold=0.4, debug=False):
         super().__init__()
-        self.criterion = FocalLoss()
+        self.criterion = FocalLoss(matched_threshold=matched_threshold, unmatched_threshold=unmatched_threshold)
         self.model = model
         self.debug = debug
 
@@ -175,7 +184,7 @@ def freeze_backbone(m):
     writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
 
     # warp the model with loss function, to reduce the memory usage on gpu0 and speedup
-    model = ModelWithLoss(model, debug=opt.debug)
+    model = ModelWithLoss(model, matched_threshold=opt.matched_threshold, unmatched_threshold=opt.unmatched_threshold, debug=opt.debug)
 
     if params.num_gpus > 0:
         model = model.cuda()