Skip to content

Commit

Permalink
Readability refactor + normalize classification loss
Browse files · Browse the repository at this point in the history
  • Loading branch information
ggaziv committed May 4, 2020
1 parent 3fe4552 commit 02ef960
Showing 1 changed file with 14 additions and 23 deletions.
37 changes: 14 additions & 23 deletions efficientdet/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,34 +53,24 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):

classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

if bbox_annotation.shape[0] == 0:
if torch.cuda.is_available():

alpha_factor = torch.ones_like(classification) * alpha
if len(bbox_annotation) == 0: # No annotations
alpha_factor = torch.ones_like(classification) * alpha
if torch.cuda.is_available():
alpha_factor = alpha_factor.cuda()
alpha_factor = 1. - alpha_factor
focal_weight = classification
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
alpha_factor = 1. - alpha_factor
focal_weight = classification
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

bce = -(torch.log(1.0 - classification))

cls_loss = focal_weight * bce
bce = -(torch.log(1.0 - classification))

cls_loss = focal_weight * bce
if torch.cuda.is_available():
regression_losses.append(torch.tensor(0).to(dtype).cuda())
classification_losses.append(cls_loss.sum())
else:

alpha_factor = torch.ones_like(classification) * alpha
alpha_factor = 1. - alpha_factor
focal_weight = classification
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

bce = -(torch.log(1.0 - classification))

cls_loss = focal_weight * bce

regression_losses.append(torch.tensor(0).to(dtype))
classification_losses.append(cls_loss.sum())

# classification_losses.append(cls_loss.sum())
classification_losses.append(cls_loss.mean())

continue

Expand Down Expand Up @@ -121,7 +111,8 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
zeros = zeros.cuda()
cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)

classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
# classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
classification_losses.append(cls_loss.mean() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))

if positive_indices.sum() > 0:
assigned_annotations = assigned_annotations[positive_indices, :]
Expand Down

0 comments on commit 02ef960

Please sign in to comment.