-
Notifications
You must be signed in to change notification settings - Fork 10
/
loss.py
35 lines (30 loc) · 1.29 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------
from dis import dis
import torch
from torch import nn
import torch.distributed as dist
from torch.functional import Tensor
import torch.nn.functional as F
def compound_loss(coe, output_feature, image: Tensor, output_label, targets, criterion_bce, criterion_ce, epoch):
    """Combine weighted reconstruction and classification losses.

    For every entry ``i`` of ``output_feature`` one combined term is built:

        criterion_bce(output_feature[i], clamped_image) * (1 - i / len(output_feature)) * f_coe
      + criterion_ce(output_label[i], targets) * ((i + 1) / len(output_label)) * c_coe

    An unweighted ``criterion_ce(output_label[-1], targets)`` term is then
    appended, and all terms are summed.

    Args:
        coe: pair ``(f_coe, c_coe)`` scaling the reconstruction and the
            classification terms respectively.
        output_feature: sequence of predictions compared against the clamped
            ``image`` via ``criterion_bce``.
        image: reconstruction target. A clamped copy (range [0.01, 0.99]) is
            used; the caller's tensor is not modified.
        output_label: sequence of classifier outputs. Entry ``i`` is paired
            with ``output_feature[i]``, and the last entry receives the extra
            unweighted term. It must therefore be non-empty and hold at least
            ``len(output_feature)`` entries.
        targets: classification targets passed to ``criterion_ce``.
        criterion_bce: callable ``(prediction, target) -> loss`` for the
            reconstruction term.
        criterion_ce: callable ``(logits, targets) -> loss`` for the
            classification terms.
        epoch: unused; kept so existing call sites keep working.

    Returns:
        ``(loss, multi_loss)``: ``multi_loss`` is the list of per-entry terms
        followed by the final classification term, and ``loss`` is their
        stacked sum over dim 0.
    """
    f_coe, c_coe = coe
    # Clamp out of place. The previous in-place ``clamp_`` rewrote the caller's
    # tensor and bumped its autograd version counter.
    # Presumably the clamp keeps BCE-style targets away from exactly 0/1 — confirm.
    recon_target = image.clamp(0.01, 0.99)

    num_features = len(output_feature)
    num_labels = len(output_label)
    multi_loss = []
    for i, feature in enumerate(output_feature):
        # Reconstruction weight shrinks and classification weight grows with i,
        # so later entries are driven mostly by the classification loss.
        ratio_f = 1 - i / num_features
        ratio_c = (i + 1) / num_labels
        ihx = criterion_bce(feature, recon_target) * ratio_f * f_coe
        ihy = criterion_ce(output_label[i], targets) * ratio_c * c_coe
        multi_loss.append(ihx + ihy)

    # The last classifier output gets an additional, unscaled classification term.
    multi_loss.append(criterion_ce(output_label[-1], targets))
    loss = torch.sum(torch.stack(multi_loss), dim=0)
    return loss, multi_loss