Skip to content

Commit

Permalink
Fixed classification loss normalization + support for matched/unmatch…
Browse files Browse the repository at this point in the history
…ed anchors
  • Loading branch information
ggaziv committed May 12, 2020
1 parent 02ef960 commit 0b5a328
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
27 changes: 17 additions & 10 deletions efficientdet/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ def calc_iou(a, b):


class FocalLoss(nn.Module):
def __init__(self):
def __init__(self, matched_threshold=0.5, unmatched_threshold=0.4, negatives_lower_than_unmatched=False):
super(FocalLoss, self).__init__()
self.matched_threshold = matched_threshold
self.unmatched_threshold = unmatched_threshold
self.negatives_lower_than_unmatched = negatives_lower_than_unmatched

def forward(self, classifications, regressions, anchors, annotations, **kwargs):
alpha = 0.25
Expand Down Expand Up @@ -69,9 +72,7 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
else:
regression_losses.append(torch.tensor(0).to(dtype))

# classification_losses.append(cls_loss.sum())
classification_losses.append(cls_loss.mean())

classification_losses.append(cls_loss.sum())
continue

IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
Expand All @@ -83,12 +84,19 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
if torch.cuda.is_available():
targets = targets.cuda()

targets[torch.lt(IoU_max, 0.4), :] = 0

positive_indices = torch.ge(IoU_max, 0.5)
if self.negatives_lower_than_unmatched:
# negative matches are the ones below the unmatched_threshold, whereas ignored matches are in between the matched and unmatched
targets[torch.lt(IoU_max, self.matched_threshold) & torch.ge(IoU_max, self.unmatched_threshold), :] = 0
else:
# Mark anchors whose best overlap is lower than unmatched_threshold as negatives (target 0)
targets[torch.lt(IoU_max, self.unmatched_threshold), :] = 0

num_positive_anchors = positive_indices.sum()
# Find all positives in a batch for normalization
positive_indices = torch.ge(IoU_max, self.matched_threshold)

# Avoid zero sum of num_positives, which would lead to inf loss during training
num_positive_anchors = positive_indices.sum() + 1
# print(num_positive_anchors)
assigned_annotations = bbox_annotation[IoU_argmax, :]

targets[positive_indices, :] = 0
Expand All @@ -111,8 +119,7 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
zeros = zeros.cuda()
cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)

# classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
classification_losses.append(cls_loss.mean() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))

if positive_indices.sum() > 0:
assigned_annotations = assigned_annotations[positive_indices, :]
Expand Down
18 changes: 13 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_args():
parser.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet')
parser.add_argument('-n', '--num_workers', type=int, default=12, help='num_workers of dataloader')
parser.add_argument('--batch_size', type=int, default=12, help='The number of images per batch among all devices')
parser.add_argument('--head_only', type=bool, default=False,
parser.add_argument('--head_only', type=boolean_string, default=False,
help='whether finetunes only the regressor and the classifier, '
'useful in early stage convergence or small/easy dataset')
parser.add_argument('--lr', type=float, default=1e-4)
Expand All @@ -56,17 +56,25 @@ def get_args():
parser.add_argument('-w', '--load_weights', type=str, default=None,
help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint')
parser.add_argument('--saved_path', type=str, default='logs/')
parser.add_argument('--debug', type=bool, default=False, help='whether visualize the predicted boxes of trainging, '
parser.add_argument('--debug', type=boolean_string, default=False, help='whether visualize the predicted boxes of trainging, '
'the output images will be in test/')
parser.add_argument('--matched_threshold', type=float, default=.5, help='Threshold for positive matches.')
parser.add_argument('--unmatched_threshold', type=float, default=.4, help='Threshold for negative matches.')

args = parser.parse_args()
return args


def boolean_string(s):
    """Convert the literal string ``'True'`` or ``'False'`` to a bool.

    Intended for use as an argparse ``type=`` callable. Plain ``type=bool``
    treats every non-empty string (including ``'False'``) as ``True``, so
    the accepted spellings are matched explicitly here.

    Args:
        s: The raw command-line value.

    Returns:
        ``True`` if ``s == 'True'``, ``False`` if ``s == 'False'``.

    Raises:
        ValueError: If ``s`` is anything other than ``'True'`` or ``'False'``
            (the match is case-sensitive).
    """
    if s not in {'False', 'True'}:
        # Include the rejected value so direct callers can see what was wrong.
        raise ValueError(f'Not a valid boolean string: {s!r}')
    return s == 'True'


class ModelWithLoss(nn.Module):
def __init__(self, model, debug=False):
def __init__(self, model, matched_threshold=0.5, unmatched_threshold=0.4, debug=False):
super().__init__()
self.criterion = FocalLoss()
self.criterion = FocalLoss(matched_threshold=matched_threshold, unmatched_threshold=unmatched_threshold)
self.model = model
self.debug = debug

Expand Down Expand Up @@ -175,7 +183,7 @@ def freeze_backbone(m):
writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

# warp the model with loss function, to reduce the memory usage on gpu0 and speedup
model = ModelWithLoss(model, debug=opt.debug)
model = ModelWithLoss(model, matched_threshold=opt.matched_threshold, unmatched_threshold=opt.unmatched_threshold, debug=opt.debug)

if params.num_gpus > 0:
model = model.cuda()
Expand Down

0 comments on commit 0b5a328

Please sign in to comment.