-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmlosses.py
executable file
·183 lines (165 loc) · 8.27 KB
/
mlosses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python
import torch
import torch.nn as nn
import sys
def distance_matrix_vector(anchor, positive):
    """Return the matrix of Euclidean distances between two descriptor batches.

    Entry ``[i, j]`` is ``sqrt(max(||anchor[i] - positive[j]||^2, 0) + eps)``
    with ``eps = 1e-6``. The squared distance is obtained from the expansion
    ``(a - b)^2 = a^2 + b^2 - 2ab``.

    Args:
        anchor: 2-D tensor, one descriptor per row.
        positive: 2-D tensor with the same number of columns as ``anchor``.

    Returns:
        Tensor of shape ``(anchor.size(0), positive.size(0))``.
    """
    eps = 1e-6
    anchor_sq = torch.sum(anchor * anchor, dim=1).unsqueeze(-1)  # [Na, 1]
    positive_sq = torch.sum(positive * positive, dim=1).unsqueeze(0)  # [1, Np]
    cross = torch.mm(anchor, torch.t(positive))  # [Na, Np]
    # Broadcasting [Na, 1] + [1, Np] replaces the explicit repeat() calls.
    sq_dist = anchor_sq + positive_sq - 2.0 * cross
    # Rounding in the expansion can push the squared distance of (nearly)
    # identical rows below -eps, which would make sqrt return NaN.
    return torch.sqrt(torch.clamp(sq_dist, min=0.0) + eps)
def distance_vectors_pairwise(anchor, positive, negative=None):
    """Return row-wise Euclidean distances between matching descriptors.

    Unlike a full distance matrix, only row ``i`` of one input is compared
    with row ``i`` of the other. Each distance is
    ``sqrt(max(||x - y||^2, 0) + eps)`` with ``eps = 1e-8``.

    Args:
        anchor: 2-D tensor, one descriptor per row.
        positive: tensor with the same shape as ``anchor``.
        negative: optional tensor with the same shape as ``anchor``.

    Returns:
        ``d_a_p`` when ``negative`` is None, otherwise the tuple
        ``(d_a_p, d_a_n, d_p_n)``; every element is a 1-D tensor with one
        distance per row.
    """
    eps = 1e-8

    def _row_distance(x, x_sq, y, y_sq):
        # (x - y)^2 = x^2 + y^2 - 2xy; clamp because rounding can make the
        # result slightly negative, which would turn sqrt into NaN.
        sq_dist = x_sq + y_sq - 2 * torch.sum(x * y, dim=1)
        return torch.sqrt(torch.clamp(sq_dist, min=0.0) + eps)

    a_sq = torch.sum(anchor * anchor, dim=1)
    p_sq = torch.sum(positive * positive, dim=1)
    d_a_p = _row_distance(anchor, a_sq, positive, p_sq)
    if negative is None:
        return d_a_p

    n_sq = torch.sum(negative * negative, dim=1)
    d_a_n = _row_distance(anchor, a_sq, negative, n_sq)
    d_p_n = _row_distance(positive, p_sq, negative, n_sq)
    return d_a_p, d_a_n, d_p_n
def loss_random_sampling(anchor, positive, negative, anchor_swap=False, margin=1.0, loss_type="triplet_margin"):
    """Loss over explicit (anchor, positive, negative) triplets.

    No in-batch hard-negative mining is done: row ``i`` of ``negative`` is
    the negative for row ``i`` of ``anchor`` and ``positive``.

    Args:
        anchor: 2-D tensor, one descriptor per row.
        positive: tensor with the same shape as ``anchor``.
        negative: tensor with the same shape as ``anchor``.
        anchor_swap: if True, the negative distance is the smaller of
            anchor-negative and positive-negative.
        margin: margin used by the ``triplet_margin`` and ``contrastive``
            losses.
        loss_type: ``"triplet_margin"``, ``"softmax"`` or ``"contrastive"``.

    Returns:
        Scalar tensor: the mean loss over the batch.

    Raises:
        ValueError: if ``loss_type`` is not one of the supported names.
    """
    assert anchor.size() == positive.size(), "Input sizes between anchor and positive must be equal."
    assert anchor.size() == negative.size(), "Input sizes between anchor and negative must be equal."
    assert anchor.dim() == 2, "Inputs must be a 2D matrix."
    # Fail with an exception rather than terminating the interpreter, so a
    # caller can handle a bad configuration.
    if loss_type not in ("triplet_margin", "softmax", "contrastive"):
        raise ValueError('Unknown loss type. Try triplet_margin, softmax or contrastive')

    eps = 1e-8
    pos, d_a_n, d_p_n = distance_vectors_pairwise(anchor, positive, negative)
    min_neg = torch.min(d_a_n, d_p_n) if anchor_swap else d_a_n

    if loss_type == "triplet_margin":
        loss = torch.clamp(margin + pos - min_neg, min=0.0)
    elif loss_type == 'softmax':
        exp_pos = torch.exp(2.0 - pos)
        exp_den = exp_pos + torch.exp(2.0 - min_neg) + eps
        loss = -torch.log(exp_pos / exp_den)
    else:  # 'contrastive'
        loss = torch.clamp(margin - min_neg, min=0.0) + pos
    return torch.mean(loss)
def loss_L2Net(anchor, positive, anchor_swap=False, margin=1.0, loss_type="triplet_margin"):
    """L2Net-style softmax loss using the whole batch as negatives.

    For each row ``i`` the loss is
    ``-log(exp(2 - d[i, i]) / (sum_j exp(2 - d[i, j]) + eps))`` where ``d`` is
    the anchor/positive distance matrix. With ``anchor_swap`` the same term
    normalised over columns is added.

    Args:
        anchor: 2-D tensor, one descriptor per row.
        positive: tensor with the same shape as ``anchor``.
        anchor_swap: if True, add the column-normalised term.
        margin: accepted for signature parity with the other losses; unused.
        loss_type: must be ``"softmax"``. NOTE: the default value is not
            accepted, so callers have to pass ``loss_type="softmax"``.

    Returns:
        Scalar tensor: the mean loss over the batch.

    Raises:
        ValueError: if ``loss_type`` is not ``"softmax"``.
    """
    assert anchor.size() == positive.size(), "Input sizes between anchor and positive must be equal."
    assert anchor.dim() == 2, "Inputs must be a 2D matrix."
    if loss_type != 'softmax':
        raise ValueError('Only softmax loss works with L2Net sampling')

    eps = 1e-8
    dist_matrix = distance_matrix_vector(anchor, positive)
    pos1 = torch.diag(dist_matrix)
    # The softmax loss works on the raw distance matrix, so no diagonal
    # masking is built here; this also keeps the function device-agnostic.
    exp_pos = torch.exp(2.0 - pos1)
    exp_all = torch.exp(2.0 - dist_matrix)
    loss = -torch.log(exp_pos / (torch.sum(exp_all, 1) + eps))
    if anchor_swap:
        loss = loss - torch.log(exp_pos / (torch.sum(exp_all, 0) + eps))
    return torch.mean(loss)
def loss_HardNet(anchor, positive, anchor_swap=False, anchor_ave=False,
                 margin=1.0, batch_reduce='min', loss_type="triplet_margin"):
    """HardNet loss built from the in-batch anchor/positive distance matrix.

    The diagonal of the distance matrix holds the positive distances. Every
    off-diagonal entry is a candidate negative; ``batch_reduce`` selects
    which negatives enter the loss.

    Args:
        anchor: 2-D tensor, one descriptor per row.
        positive: tensor with the same shape as ``anchor``.
        anchor_swap: if True, also consider the transposed distance matrix
            and keep the smaller negative distance.
        anchor_ave: accepted for backward compatibility; unused.
        margin: margin used by the ``triplet_margin`` and ``contrastive``
            losses.
        batch_reduce: ``'min'`` (closest negative per row), ``'average'``
            (every matrix entry) or ``'random'`` (one randomly chosen column
            per row).
        loss_type: ``"triplet_margin"``, ``"softmax"`` or ``"contrastive"``.

    Returns:
        Scalar tensor: the mean loss.

    Raises:
        ValueError: if ``batch_reduce`` or ``loss_type`` is not supported.
    """
    assert anchor.size() == positive.size(), "Input sizes between anchor and positive must be equal."
    assert anchor.dim() == 2, "Inputs must be a 2D matrix."
    # Fail with an exception rather than terminating the interpreter.
    if batch_reduce not in ('min', 'average', 'random'):
        raise ValueError('Unknown batch reduce mode. Try min, average or random')
    if loss_type not in ("triplet_margin", "softmax", "contrastive"):
        raise ValueError('Unknown loss type. Try triplet_margin, softmax or contrastive')

    eps = 1e-8
    dist_matrix = distance_matrix_vector(anchor, positive) + eps
    # Build helper tensors on the same device/dtype as the distances so the
    # function works on CPU and GPU alike.
    eye = torch.eye(dist_matrix.size(1), dtype=dist_matrix.dtype, device=dist_matrix.device)
    pos1 = torch.diag(dist_matrix)

    # Push the diagonal (the positives) out of reach of the min search.
    dist_without_min_on_diag = dist_matrix + eye * 10
    # Entries still below 0.008 are treated as the same patch appearing as a
    # negative; they get the same +10 penalty so they are never selected.
    too_close = 1.0 - dist_without_min_on_diag.ge(0.008).type_as(dist_without_min_on_diag)
    dist_without_min_on_diag = dist_without_min_on_diag + too_close * 10

    if batch_reduce == 'min':
        min_neg = torch.min(dist_without_min_on_diag, 1)[0]  # hardest per row
        if anchor_swap:
            min_neg2 = torch.min(dist_without_min_on_diag, 0)[0]  # hardest per column
            min_neg = torch.min(min_neg, min_neg2)
        pos = pos1
    elif batch_reduce == 'average':
        # Flatten the matrix so every entry contributes one loss term.
        pos = pos1.repeat(anchor.size(0)).view(-1, 1).squeeze(0)
        min_neg = dist_without_min_on_diag.view(-1, 1)
        if anchor_swap:
            min_neg2 = torch.t(dist_without_min_on_diag).contiguous().view(-1, 1)
            min_neg = torch.min(min_neg, min_neg2)
        min_neg = min_neg.squeeze(0)
    else:  # 'random'
        idxs = torch.randperm(anchor.size(0)).to(dist_matrix.device).view(-1, 1)
        min_neg = dist_without_min_on_diag.gather(1, idxs)
        if anchor_swap:
            min_neg2 = torch.t(dist_without_min_on_diag).gather(1, idxs)
            min_neg = torch.min(min_neg, min_neg2)
        min_neg = torch.t(min_neg).squeeze(0)
        pos = pos1

    if loss_type == "triplet_margin":
        loss = torch.clamp(margin + pos - min_neg, min=0.0)
    elif loss_type == 'softmax':
        exp_pos = torch.exp(2.0 - pos)
        exp_den = exp_pos + torch.exp(2.0 - min_neg) + eps
        loss = -torch.log(exp_pos / exp_den)
    else:  # 'contrastive'
        loss = torch.clamp(margin - min_neg, min=0.0) + pos
    return torch.mean(loss)
def global_orthogonal_regularization(anchor, negative):
    """Return the global orthogonal regularization term for a batch.

    With ``s`` the row-wise dot products of ``anchor`` and ``negative`` and
    ``dim = anchor.size(1)``, the result is
    ``mean(s)^2 + max(mean(s^2) - 1/dim, 0)``.

    Args:
        anchor: tensor with at least two dimensions; products are summed
            over dimension 1.
        negative: tensor multiplied element-wise with ``anchor``.

    Returns:
        The regularization value as a tensor.
    """
    descriptor_dim = anchor.size(1)
    dots = (anchor * negative).sum(1)
    first_moment_sq = dots.mean().pow(2)
    second_moment_excess = (dots.pow(2).mean() - 1.0 / descriptor_dim).clamp(min=0.0)
    return first_moment_sq + second_moment_excess