diff --git a/efficientdet/loss.py b/efficientdet/loss.py index 082416c2e..4da864f78 100644 --- a/efficientdet/loss.py +++ b/efficientdet/loss.py @@ -53,34 +53,24 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs): classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4) - if bbox_annotation.shape[0] == 0: - if torch.cuda.is_available(): - - alpha_factor = torch.ones_like(classification) * alpha + if len(bbox_annotation) == 0: # No annotations + alpha_factor = torch.ones_like(classification) * alpha + if torch.cuda.is_available(): alpha_factor = alpha_factor.cuda() - alpha_factor = 1. - alpha_factor - focal_weight = classification - focal_weight = alpha_factor * torch.pow(focal_weight, gamma) + alpha_factor = 1. - alpha_factor + focal_weight = classification + focal_weight = alpha_factor * torch.pow(focal_weight, gamma) - bce = -(torch.log(1.0 - classification)) - - cls_loss = focal_weight * bce + bce = -(torch.log(1.0 - classification)) + cls_loss = focal_weight * bce + if torch.cuda.is_available(): regression_losses.append(torch.tensor(0).to(dtype).cuda()) - classification_losses.append(cls_loss.sum()) else: - - alpha_factor = torch.ones_like(classification) * alpha - alpha_factor = 1. - alpha_factor - focal_weight = classification - focal_weight = alpha_factor * torch.pow(focal_weight, gamma) - - bce = -(torch.log(1.0 - classification)) - - cls_loss = focal_weight * bce - regression_losses.append(torch.tensor(0).to(dtype)) - classification_losses.append(cls_loss.sum()) + + # classification_losses.append(cls_loss.sum()) + classification_losses.append(cls_loss.mean()) continue @@ -121,7 +111,8 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs): zeros = zeros.cuda() cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros) - classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0)) + # classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0)) + classification_losses.append(cls_loss.mean() / torch.clamp(num_positive_anchors.to(dtype), min=1.0)) if positive_indices.sum() > 0: assigned_annotations = assigned_annotations[positive_indices, :]