# self_supervised_losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimSiamLoss(nn.Module):
    def __init__(self, version="simplified"):
        super().__init__()
        self.ver = version

    def criterion(self, p, z):
        if self.ver == "original":
            z = z.detach()  # stop gradient
            p = nn.functional.normalize(p, dim=1)
            z = nn.functional.normalize(z, dim=1)
            return -(p * z).sum(dim=1).mean()
        elif self.ver == "simplified":
            z = z.detach()  # stop gradient
            return -nn.functional.cosine_similarity(p, z, dim=-1).mean()
        else:
            raise ValueError(f"Unknown SimSiam loss version: {self.ver}")

    def forward(self, out):
        p1, p2 = out["p"]
        z1, z2 = out["z"]
        # Symmetrized loss: each predictor output is pulled toward the
        # (stop-gradient) projection of the other view.
        loss1 = self.criterion(p1, z2)
        loss2 = self.criterion(p2, z1)
        return 0.5 * loss1 + 0.5 * loss2
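
# Usage sketch added for illustration; shapes are hypothetical and the
# out-dict layout is inferred from forward() above:
#   b, d = 32, 128
#   out = {"p": (torch.randn(b, d), torch.randn(b, d)),
#          "z": (torch.randn(b, d), torch.randn(b, d))}
#   loss = SimSiamLoss()(out)  # scalar tensor; reaches -1 at perfect alignment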


class BarlowTwinsLoss(nn.Module):
    def __init__(self, lambd=5e-3, scale_loss=0.025, device="cuda"):
        super().__init__()
        self.lambd = lambd
        self.scale_loss = scale_loss
        self.device = device

    def forward(self, out):
        z1, z2 = out["z"]
        # Standardize each embedding dimension across the batch.
        z1_norm = (z1 - z1.mean(0)) / z1.std(0)  # NxD
        z2_norm = (z2 - z2.mean(0)) / z2.std(0)  # NxD
        N = z1.size(0)
        D = z1.size(1)
        # Empirical cross-correlation matrix between the two views.
        corr = torch.einsum("bi, bj -> ij", z1_norm, z2_norm) / N
        diag = torch.eye(D, device=corr.device)
        cdif = (corr - diag).pow(2)
        # Push the diagonal toward 1 (invariance) and down-weight the
        # off-diagonal terms by lambd (redundancy reduction).
        cdif[~diag.bool()] *= self.lambd
        loss = self.scale_loss * cdif.sum()
        return loss
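
# Usage sketch added for illustration (hypothetical tensors; the unused
# device argument is set to "cpu" only so the example runs anywhere):
#   z1, z2 = torch.randn(64, 32), torch.randn(64, 32)
#   loss = BarlowTwinsLoss(device="cpu")({"z": (z1, z2)})
# The loss approaches 0 when the cross-correlation matrix is the identity.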


class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, out):
        z1, z2 = out["z"]
        device = z1.device
        b = z1.size(0)
        z = torch.cat((z1, z2), dim=0)
        z = F.normalize(z, dim=-1)
        logits = torch.einsum("if, jf -> ij", z, z) / self.temperature
        # Subtract the per-row maximum for numerical stability.
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()
        # Positives are the pairs (i, i + b) and (i + b, i): the same image
        # under the two augmentations.
        pos_mask = torch.zeros((2 * b, 2 * b), dtype=torch.bool, device=device)
        pos_mask[:, b:].fill_diagonal_(True)
        pos_mask[b:, :].fill_diagonal_(True)
        # All pairs except self-similarity on the main diagonal.
        logit_mask = torch.ones_like(pos_mask).fill_diagonal_(0)
        exp_logits = torch.exp(logits) * logit_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        # Mean log-likelihood over the positives.
        mean_log_prob_pos = (pos_mask * log_prob).sum(1) / pos_mask.sum(1)
        loss = -mean_log_prob_pos.mean()
        return loss
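
# Usage sketch added for illustration (hypothetical shapes): with identical
# views the positives carry the highest similarity, so the loss is near its
# minimum; independent random views give a loss near log(2 * b - 1).
#   z1, z2 = torch.randn(16, 64), torch.randn(16, 64)
#   loss = NTXentLoss(temperature=0.5)({"z": (z1, z2)})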


class ContrastiveDistillLoss(nn.Module):
    def __init__(self, temperature: float = 0.2):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        p1: torch.Tensor,
        p2: torch.Tensor,
        z1: torch.Tensor,
        z2: torch.Tensor,
    ) -> torch.Tensor:
        device = z1.device
        b = z1.size(0)
        # Normalize and concatenate predictions and representations.
        p = F.normalize(torch.cat([p1, p2]), dim=-1)  # (2*b, feature_dim)
        z = F.normalize(torch.cat([z1, z2]), dim=-1)  # (2*b, feature_dim)
        # Compute similarity logits scaled by temperature.
        logits = torch.einsum("if, jf -> ij", p, z) / self.temperature
        # For numerical stability subtract the maximum logit per sample.
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()
        # Positives sit on the main diagonal: p_i should match z_i.
        pos_mask = torch.zeros((2 * b, 2 * b), dtype=torch.bool, device=device)
        pos_mask.fill_diagonal_(True)
        # Keep every pair in the denominator (no self-pair exclusion):
        # predictions and targets come from different branches, so the
        # diagonal is the genuine positive rather than a trivial self-match.
        logit_mask = torch.ones_like(pos_mask)
        exp_logits = torch.exp(logits) * logit_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        # Mean log-likelihood over the positives.
        mean_log_prob_pos = (pos_mask * log_prob).sum(1) / pos_mask.sum(1)
        loss = -mean_log_prob_pos.mean()
        return loss
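
# Usage sketch added for illustration. In a distillation setting, p1/p2 would
# typically be predicted features from the current model and z1/z2 targets
# from a frozen (past) model; random tensors here only exercise the shapes.
#   b, d = 16, 64
#   loss = ContrastiveDistillLoss(temperature=0.2)(
#       torch.randn(b, d), torch.randn(b, d),
#       torch.randn(b, d), torch.randn(b, d),
#   )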


class BYOLLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, out) -> torch.Tensor:
        p1, p2 = out["p"]
        z1, z2 = out["z"]
        p1 = F.normalize(p1, dim=-1)
        p2 = F.normalize(p2, dim=-1)
        z1 = F.normalize(z1, dim=-1)
        z2 = F.normalize(z2, dim=-1)
        # Mean squared error between normalized p and the stop-gradient
        # target z' from the other view: ||p - z'||^2 = 2 - 2 * <p, z'>.
        loss1 = 2 - 2 * (p1 * z2.detach()).sum(dim=-1)
        loss2 = 2 - 2 * (p2 * z1.detach()).sum(dim=-1)
        return (loss1 + loss2).mean()
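
# Usage sketch added for illustration (hypothetical tensors): when every
# prediction equals its stop-gradient target, 2 - 2*cos = 0 and the loss
# bottoms out at zero.
#   t = torch.randn(32, 128)
#   out = {"p": (t, t), "z": (t, t)}
#   BYOLLoss()(out)  # tensor(0.) up to floating-point error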