Readability refactor + normalize classification loss #252

Open · wants to merge 2 commits into master
efficientdet/loss.py (56 changes: 27 additions & 29 deletions)
@@ -25,8 +25,11 @@ def calc_iou(a, b):
 
 
 class FocalLoss(nn.Module):
-    def __init__(self):
+    def __init__(self, matched_threshold=0.5, unmatched_threshold=0.4, negatives_lower_than_unmatched=True):
         super(FocalLoss, self).__init__()
+        self.matched_threshold = matched_threshold
+        self.unmatched_threshold = unmatched_threshold
+        self.negatives_lower_than_unmatched = negatives_lower_than_unmatched
 
     def forward(self, classifications, regressions, anchors, annotations, **kwargs):
         alpha = 0.25
@@ -53,52 +56,47 @@ def forward(self, classifications, regressions, anchors, annotations, **kwargs):
 
             classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
 
-            if bbox_annotation.shape[0] == 0:
-                if torch.cuda.is_available():
-
-                    alpha_factor = torch.ones_like(classification) * alpha
-                    alpha_factor = alpha_factor.cuda()
-                    alpha_factor = 1. - alpha_factor
-                    focal_weight = classification
-                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
-
-                    bce = -(torch.log(1.0 - classification))
-
-                    cls_loss = focal_weight * bce
-
-                    regression_losses.append(torch.tensor(0).to(dtype).cuda())
-                    classification_losses.append(cls_loss.sum())
-                else:
-
-                    alpha_factor = torch.ones_like(classification) * alpha
-                    alpha_factor = 1. - alpha_factor
-                    focal_weight = classification
-                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
-
-                    bce = -(torch.log(1.0 - classification))
-
-                    cls_loss = focal_weight * bce
-
-                    regression_losses.append(torch.tensor(0).to(dtype))
-                    classification_losses.append(cls_loss.sum())
-
+            if len(bbox_annotation) == 0:  # No annotations
+                alpha_factor = torch.ones_like(classification) * alpha
+                if torch.cuda.is_available():
+                    alpha_factor = alpha_factor.cuda()
+
+                alpha_factor = 1. - alpha_factor
+                focal_weight = classification
+                focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
+
+                bce = -(torch.log(1.0 - classification))
+
+                cls_loss = focal_weight * bce
+                if torch.cuda.is_available():
+                    regression_losses.append(torch.tensor(0).to(dtype).cuda())
+                else:
+                    regression_losses.append(torch.tensor(0).to(dtype))
+
+                classification_losses.append(cls_loss.sum())
                 continue
 
             IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
 
             IoU_max, IoU_argmax = torch.max(IoU, dim=1)
 
             # compute the loss for classification
-            targets = torch.ones_like(classification) * -1
+            targets = torch.ones_like(classification) * -1  # init by ignoring all targets
             if torch.cuda.is_available():
                 targets = targets.cuda()
 
-            targets[torch.lt(IoU_max, 0.4), :] = 0
+            if self.negatives_lower_than_unmatched:
+                # negative matches are the ones below the unmatched_threshold
+                targets[torch.lt(IoU_max, self.unmatched_threshold), :] = 0
+            else:
+                # negative matches are in between the matched and unmatched
+                targets[torch.lt(IoU_max, self.matched_threshold) & torch.ge(IoU_max, self.unmatched_threshold), :] = 0
 
-            positive_indices = torch.ge(IoU_max, 0.5)
+            # Find all positives in a batch for normalization
+            positive_indices = torch.ge(IoU_max, self.matched_threshold)
 
-            num_positive_anchors = positive_indices.sum()
+            # Avoid zero sum of num_positives, which would lead to inf loss during training
+            num_positive_anchors = positive_indices.sum() + 1
             # print(num_positive_anchors)
             assigned_annotations = bbox_annotation[IoU_argmax, :]
 
             targets[positive_indices, :] = 0
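For context on the matching change above, here is a minimal standalone sketch (not part of the diff; the IoU values are made up) of the two-threshold rule FocalLoss now uses: anchors at or above matched_threshold become positives, anchors below unmatched_threshold become negatives, and anything in between keeps the -1 "ignore" label.

import torch

matched_threshold, unmatched_threshold = 0.5, 0.4  # the PR's defaults

IoU_max = torch.tensor([0.1, 0.45, 0.6, 0.9])  # best IoU per anchor (toy values)
targets = torch.full((4,), -1.0)               # -1 = ignore, as in the loss above

targets[torch.lt(IoU_max, unmatched_threshold)] = 0       # negatives
positive_indices = torch.ge(IoU_max, matched_threshold)   # positives

print(targets)           # tensor([ 0., -1., -1., -1.])
print(positive_indices)  # tensor([False, False,  True,  True])

# The "+ 1" in the diff guards against images where no anchor matches,
# which would otherwise normalize the classification loss by zero.
num_positive_anchors = positive_indices.sum() + 1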
train.py (18 changes: 13 additions & 5 deletions)
@@ -37,7 +37,7 @@ def get_args():
     parser.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet')
     parser.add_argument('-n', '--num_workers', type=int, default=12, help='num_workers of dataloader')
     parser.add_argument('--batch_size', type=int, default=12, help='The number of images per batch among all devices')
-    parser.add_argument('--head_only', type=bool, default=False,
+    parser.add_argument('--head_only', type=boolean_string, default=False,
                         help='whether finetunes only the regressor and the classifier, '
                              'useful in early stage convergence or small/easy dataset')
     parser.add_argument('--lr', type=float, default=1e-4)
@@ -56,17 +56,25 @@ def get_args():
     parser.add_argument('-w', '--load_weights', type=str, default=None,
                         help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint')
     parser.add_argument('--saved_path', type=str, default='logs/')
-    parser.add_argument('--debug', type=bool, default=False, help='whether visualize the predicted boxes of trainging, '
+    parser.add_argument('--debug', type=boolean_string, default=False, help='whether to visualize the predicted boxes during training, '
                                                                   'the output images will be in test/')
+    parser.add_argument('--matched_threshold', type=float, default=.5, help='Threshold for positive matches.')
+    parser.add_argument('--unmatched_threshold', type=float, default=.4, help='Threshold for negative matches.')
 
     args = parser.parse_args()
     return args
 
 
+def boolean_string(s):
+    if s not in {'False', 'True'}:
+        raise ValueError('Not a valid boolean string')
+    return s == 'True'
+
+
 class ModelWithLoss(nn.Module):
-    def __init__(self, model, debug=False):
+    def __init__(self, model, matched_threshold=0.5, unmatched_threshold=0.4, debug=False):
         super().__init__()
-        self.criterion = FocalLoss()
+        self.criterion = FocalLoss(matched_threshold=matched_threshold, unmatched_threshold=unmatched_threshold)
         self.model = model
         self.debug = debug
 
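Why the diff swaps type=bool for boolean_string: argparse applies the type callable to the raw command-line string, and bool('False') is True because every non-empty string is truthy, so --debug False used to turn debug mode on. A quick self-contained demonstration (the flag names here are illustrative):

import argparse

def boolean_string(s):
    if s not in {'False', 'True'}:
        raise ValueError('Not a valid boolean string')
    return s == 'True'

parser = argparse.ArgumentParser()
parser.add_argument('--old_style', type=bool, default=False)
parser.add_argument('--new_style', type=boolean_string, default=False)

args = parser.parse_args(['--old_style', 'False', '--new_style', 'False'])
print(args.old_style)  # True  -- the bug: 'False' is a non-empty, truthy string
print(args.new_style)  # False -- parsed as intended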
@@ -175,7 +183,7 @@ def freeze_backbone(m):
     writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
 
     # wrap the model with the loss function, to reduce memory usage on gpu0 and speed up
-    model = ModelWithLoss(model, debug=opt.debug)
+    model = ModelWithLoss(model, matched_threshold=opt.matched_threshold, unmatched_threshold=opt.unmatched_threshold, debug=opt.debug)
 
     if params.num_gpus > 0:
         model = model.cuda()
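Taken together, the thresholds become tunable from the command line, e.g. python train.py -c 0 --matched_threshold 0.6 --unmatched_threshold 0.3 (values here are illustrative; the defaults of 0.5 and 0.4 match the previously hard-coded thresholds). Programmatically, a sketch of the same wiring, assuming the import path used by this repo:

from efficientdet.loss import FocalLoss

# Defaults reproduce the old hard-coded constants; pass other values to experiment.
# Note: negatives_lower_than_unmatched is only reachable through the constructor,
# not through a train.py flag.
criterion = FocalLoss(matched_threshold=0.5,
                      unmatched_threshold=0.4,
                      negatives_lower_than_unmatched=True)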