-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcloss.py
47 lines (42 loc) · 1.77 KB
/
closs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
class CovLoss(nn.Module):
    """Decorrelation loss penalizing dependence between feature dimensions.

    The input features are squashed through a sigmoid, then the absolute
    Pearson correlation between every pair of *distinct* feature dimensions
    is averaged; correlations up to ``margin`` incur no penalty.

    Args:
        batchsize: Expected batch size (stored for interface compatibility;
            not used by the computation).
        num_instance: Instances per identity (stored for interface
            compatibility; not used by the computation).
        margin: Correlation magnitude below which no penalty is applied.
    """

    def __init__(self, batchsize, num_instance, margin=0.00):
        super(CovLoss, self).__init__()
        self.margin = margin
        self.batchsize = batchsize
        self.num_instance = num_instance
        self.sigmoid = nn.Sigmoid()

    def forward(self, feat):
        """Return the mean off-diagonal |correlation| of ``feat``.

        Args:
            feat: Tensor of shape (batch, feat_dim).

        Returns:
            Scalar loss tensor (non-negative).
        """
        feat = self.sigmoid(feat)
        return self.compute_indep(feat)

    def compute_indep(self, feat):
        """Mean margin-clipped |correlation| over distinct dimension pairs."""
        feat_dim = feat.size(1)
        # Build the identity mask on the input's device instead of the
        # original hard-coded .cuda(), so the loss also runs on CPU-only
        # machines and on tensors living on any device.
        mask = torch.eye(feat_dim, device=feat.device, dtype=feat.dtype)
        feat_mean = torch.mean(feat, dim=0, keepdim=True)
        feat_centerless = feat - feat_mean
        # (feat_dim, feat_dim) scatter matrix of the centered features.
        feat_covar = torch.matmul(feat_centerless.t(), feat_centerless)
        feat_var = feat_covar.diag().unsqueeze(1)
        # Outer product of per-dimension std-devs; clamp guards against
        # division by ~0 for (near-)constant dimensions.
        feat_cvar = torch.matmul(
            feat_var, feat_var.t()).clamp(min=1e-12).sqrt()
        feat_dep = torch.div(feat_covar, feat_cvar).abs()
        # Keep only off-diagonal entries, i.e. distinct dimension pairs.
        masked_dep = torch.masked_select(feat_dep, mask == 0)
        loss = torch.max(torch.zeros_like(masked_dep),
                         masked_dep - self.margin).mean()
        return loss

    def compute_dep(self, feat1, feat2):
        """Mean (1 - correlation) between matching dimensions of two
        feature sets; approaches 0 when each dimension pair is perfectly
        positively correlated.
        """
        feat1_mean = torch.mean(feat1, dim=0, keepdim=True)
        feat2_mean = torch.mean(feat2, dim=0, keepdim=True)
        feat1_cl = feat1 - feat1_mean
        feat2_cl = feat2 - feat2_mean
        # Per-dimension cross-covariance between the two feature sets.
        feat_covar = torch.matmul(feat1_cl.t(), feat2_cl).diag()
        feat1_var = torch.matmul(feat1_cl.t(), feat1_cl).diag()
        feat2_var = torch.matmul(feat2_cl.t(), feat2_cl).diag()
        # clamp guards against division by ~0 for constant dimensions.
        feat_var = torch.mul(feat1_var, feat2_var).clamp(min=1e-12).sqrt()
        feat_dep = torch.div(feat_covar, feat_var)
        mask = torch.ones_like(feat_dep)
        loss = (mask - feat_dep).mean()
        return loss