# loss_function.py (forked from wzmsltw/BSN-boundary-sensitive-network.pytorch)
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn.functional as F


def bi_loss(scores, anchors, opt):
    # Class-rebalanced binary logistic loss. `scores` are the ground-truth
    # match scores, `anchors` the predicted probabilities; both are flattened.
    scores = scores.view(-1).cuda()
    anchors = anchors.contiguous().view(-1)

    l1 = torch.mean(torch.abs(scores - anchors))

    pmask = (scores > opt["tem_match_thres"]).float().cuda()
    num_positive = torch.sum(pmask)
    num_entries = len(scores)
    # The +1 offsets and the 1e-6 keep the ratio finite when a batch has no
    # positive entries; without them we saw NaNs (e.g. on the gymnastics data).
    ratio = (num_entries + 1) / (num_positive + 1)
    ratio += 1e-6

    # coef_1 = 0.5 * ratio weights the positive term and coef_0 the negative
    # term, so each class contributes roughly half of the total loss.
    coef_0 = 0.5 * ratio / (ratio - 1)
    coef_1 = coef_0 * (ratio - 1)
    loss = coef_1 * pmask * torch.log(anchors + 0.00001) + coef_0 * (
        1.0 - pmask) * torch.log(1.0 - anchors + 0.00001)
    loss = -torch.mean(loss)
    num_sample = [num_positive, ratio, num_entries]
    return loss, num_sample, l1
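

# A hypothetical usage sketch, not from the original repo: `opt` is assumed to
# be a plain dict carrying the "tem_match_thres" key read above, and a CUDA
# device is assumed since bi_loss moves its tensors onto the GPU.
def _demo_bi_loss():
    opt = {"tem_match_thres": 0.5}
    scores = torch.rand(100)          # ground-truth match scores in [0, 1]
    anchors = torch.rand(100).cuda()  # predicted probabilities in (0, 1)
    loss, (num_positive, ratio, num_entries), l1 = bi_loss(scores, anchors, opt)
    print(loss.item(), num_positive.item(), num_entries)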


def TEM_loss_calc(anchors_action, anchors_start, anchors_end,
                  match_scores_action, match_scores_start, match_scores_end,
                  opt):
    # Apply the balanced logistic loss to each of the three TEM outputs.
    action_loss, num_sample_action, action_l1 = bi_loss(match_scores_action,
                                                        anchors_action, opt)
    start_loss, num_sample_start, start_l1 = bi_loss(match_scores_start,
                                                     anchors_start, opt)
    end_loss, num_sample_end, end_l1 = bi_loss(match_scores_end,
                                               anchors_end, opt)

    loss_dict = {
        "action_loss": action_loss,
        "action_positive": num_sample_action[0],
        "action_l1": action_l1,
        "start_loss": start_loss,
        "start_positive": num_sample_start[0],
        "start_l1": start_l1,
        "end_loss": end_loss,
        "end_positive": num_sample_end[0],
        "end_l1": end_l1,
        "entries": num_sample_action[2]
    }
    return loss_dict


def TEM_loss_function(y_action, y_start, y_end, TEM_output, opt):
    # TEM_output channels: 0 = action, 1 = start, 2 = end.
    anchors_action = TEM_output[:, 0, :]
    anchors_start = TEM_output[:, 1, :]
    anchors_end = TEM_output[:, 2, :]
    loss_dict = TEM_loss_calc(anchors_action, anchors_start, anchors_end,
                              y_action, y_start, y_end, opt)
    # The actionness term is weighted twice as heavily as the boundary terms.
    cost = 2 * loss_dict["action_loss"] + loss_dict["start_loss"] + loss_dict[
        "end_loss"]
    loss_dict["cost"] = cost
    return loss_dict
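

# Hypothetical TEM usage sketch: the shapes are assumptions inferred from the
# slicing above, i.e. TEM_output is (batch, 3, temporal_len) with channels
# ordered action / start / end, and the y_* targets are (batch, temporal_len).
def _demo_tem_loss():
    opt = {"tem_match_thres": 0.5}
    batch, temporal_len = 2, 100
    TEM_output = torch.sigmoid(torch.randn(batch, 3, temporal_len)).cuda()
    y_action = torch.rand(batch, temporal_len)
    y_start = torch.rand(batch, temporal_len)
    y_end = torch.rand(batch, temporal_len)
    loss_dict = TEM_loss_function(y_action, y_start, y_end, TEM_output, opt)
    print({k: float(v) for k, v in loss_dict.items()})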


def PEM_loss_function(anchors_iou, match_iou, opt):
    match_iou = match_iou.cuda()
    anchors_iou = anchors_iou.view(-1)

    # Stratify proposals by ground-truth IoU; `<=` on the low mask so the
    # three masks partition every entry.
    u_hmask = (match_iou > opt["pem_high_iou_thres"]).float()
    u_mmask = ((match_iou <= opt["pem_high_iou_thres"]) &
               (match_iou > opt["pem_low_iou_thres"])).float()
    u_lmask = (match_iou <= opt["pem_low_iou_thres"]).float()

    num_h = torch.sum(u_hmask)
    num_m = torch.sum(u_mmask)
    num_l = torch.sum(u_lmask)

    # Randomly subsample the medium- and low-IoU strata so their counts stay
    # proportional to the number of high-IoU proposals.
    r_m = opt["pem_u_ratio_m"] * num_h / num_m
    r_m = torch.min(r_m, torch.Tensor([1.0]).cuda())[0]
    u_smmask = torch.Tensor(np.random.rand(u_hmask.size()[0])).cuda()
    u_smmask = u_smmask * u_mmask
    u_smmask = (u_smmask > (1. - r_m)).float()

    r_l = opt["pem_u_ratio_l"] * num_h / num_l
    r_l = torch.min(r_l, torch.Tensor([1.0]).cuda())[0]
    u_slmask = torch.Tensor(np.random.rand(u_hmask.size()[0])).cuda()
    u_slmask = u_slmask * u_lmask
    u_slmask = (u_slmask > (1. - r_l)).float()

    iou_weights = u_hmask + u_smmask + u_slmask
    # reduction="none" keeps the per-proposal losses so the sampling weights
    # actually apply; the default mean reduction would collapse them to a
    # scalar before the weighting.
    iou_loss = F.smooth_l1_loss(anchors_iou, match_iou.squeeze(),
                                reduction="none")
    iou_loss = torch.sum(iou_loss * iou_weights) / (1e-6 + torch.sum(iou_weights))
    return {"iou_loss": iou_loss}
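

# Hypothetical PEM usage sketch: the threshold and ratio values below are
# assumptions standing in for whatever the training config supplies; only the
# dict keys match the lookups above. Running the demos requires a CUDA device.
def _demo_pem_loss():
    opt = {"pem_high_iou_thres": 0.6, "pem_low_iou_thres": 0.2,
           "pem_u_ratio_m": 1.0, "pem_u_ratio_l": 2.0}
    anchors_iou = torch.rand(256).cuda()  # predicted proposal confidences
    match_iou = torch.rand(256)           # ground-truth proposal IoUs
    print(PEM_loss_function(anchors_iou, match_iou, opt))


if __name__ == "__main__":
    _demo_bi_loss()
    _demo_tem_loss()
    _demo_pem_loss()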