-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathloss_comparison_style_loss.py
118 lines (82 loc) · 2.99 KB
/
loss_comparison_style_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.optim as optim
import imageio
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
from mdfloss import MDFLoss
from utils import psnr
# NOTE(review): star import; presumably provides VGG (and calculate_loss) — confirm.
from style_loss import *
from super_loss import SuperLoss

# Feature extractor used as the perceptual (content) loss network, in eval mode.
# Its weights are never optimised in this script (only the image tensor is
# handed to the optimizer), so freeze them: backward() then propagates to the
# image without also computing and accumulating gradients in every VGG weight.
model = VGG().to('cuda').eval().requires_grad_(False)
loss_function = SuperLoss('l1')
# Set parameters
cuda_available = True
epochs = 500
application = 'SISR'
image_path = './misc/i10.png'

# Discriminator weights per supported application. `path_disc` is consumed by
# the MDFLoss criterion (commented out in this script). It is resolved up front
# so that a typo in `application` fails immediately instead of leaving
# `path_disc` undefined.
_DISC_WEIGHTS = {
    'SISR': "./weights/Ds_SISR.pth",
    'Denoising': "./weights/Ds_Denoising.pth",
    'JPEG': "./weights/Ds_JPEG.pth",
}
if application not in _DISC_WEIGHTS:
    raise ValueError(
        f"Unknown application {application!r}; expected one of {sorted(_DISC_WEIGHTS)}"
    )
path_disc = _DISC_WEIGHTS[application]
#%% Read reference images
# Load the reference image, scale it by 1/255 (assumes 8-bit pixels), and
# reorder an H x W x C array into a one-image float32 batch of shape (1, C, H, W).
imgr = np.asarray(imageio.imread(image_path)) / 255.0
imgr = torch.from_numpy(imgr).permute(2, 0, 1).unsqueeze(0).float()
# Create a noisy image
# The starting image was generated once with the two lines below and saved.
# NOTE(review): it is saved as 'noisy_test_img.pt' but loaded from
# 'noisy_img.pt' — confirm which file is intended.
#imgd = torch.rand(imgr.size())
#torch.save(imgd, 'noisy_test_img.pt')
# map_location='cpu' lets the file load even if it was saved from a GPU tensor;
# it is moved to the GPU below when cuda_available is set.
imgd = torch.load('noisy_img.pt', map_location='cpu')
# Save the original state (a CPU copy, used for plotting at the end)
imgdo = imgd.detach().clone()
if cuda_available:
    imgr = imgr.cuda()
    imgd = imgd.cuda()
# The reference image is a fixed target. The noisy image is a leaf tensor whose
# pixels are the parameters being optimised. torch.autograd.Variable is
# deprecated; plain tensors carry requires_grad directly.
imgrb = imgr.detach()
imgdb = imgd.detach().requires_grad_(True)
optimizer = optim.Adam([imgdb], lr=0.1)
# Initialise the loss
#criterion = MDFLoss(path_disc, cuda_available=cuda_available)
criterion = loss_function.cuda()
PSNRs = []

# The reference image never changes, so its features are a constant target.
# Compute them once, outside the autograd graph, rather than re-running the
# network on it (and recording a graph) every epoch.
with torch.no_grad():
    orig_features = model(imgrb)

# Iterate over the epochs optimizing for the noisy image
for ii in range(epochs):
    optimizer.zero_grad()
    gen_features = model(imgdb)
    # Content loss: mean squared difference between generated and reference features.
    content_l = torch.mean((gen_features - orig_features) ** 2)
    # Alternative objective (disabled); the style target is the reference image:
    # total_loss = calculate_loss(gen_features, orig_features, orig_features) + criterion(imgrb, imgdb)
    # total_loss.backward()
    content_l.backward()
    optimizer.step()
    # PSNR is measured after the step; the printed loss is the one that drove it.
    with torch.no_grad():
        eval_psnr = psnr(torch.clamp(imgrb, 0., 1.), torch.clamp(imgdb, 0., 1.)).item()
    print("Epoch: ", ii, " loss: ", content_l.item(), eval_psnr)
    PSNRs.append(eval_psnr)

import os  # only needed to make sure the output folders exist before saving

os.makedirs('Output_Images', exist_ok=True)
os.makedirs('PSNR_Values', exist_ok=True)
imgdnp = imgdb.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
plt.imsave('Output_Images/VGG_style_loss.png', np.clip(imgdnp, 0.0, 1.0))
np.save('PSNR_Values/PSNRs_' + 'VGG_style_loss' + '.npy', PSNRs)
# Convert images to numpy (H, W, C) arrays for display
imgrnp = imgr.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
imgdnp = imgdb.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
imgdonp = imgdo.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()

# Plot optimization results side by side: starting point, result, target.
# Images are clipped to [0, 1] for display, matching what is written to disk.
# The recovered image was already saved to 'Output_Images/VGG_style_loss.png'
# after the optimisation loop, so it is not written a second time here.
# NOTE(review): the figure is neither shown nor saved; it only appears in an
# interactive (#%% cell) session — add plt.show()/fig.savefig(...) if needed.
fig, axs = plt.subplots(1, 3, figsize=(45, 15))
panels = [
    (imgdonp, 'Noisy image'),
    (imgdnp, 'Recovered image'),
    (imgrnp, 'Reference image'),
]
for ax, (image, title) in zip(axs, panels):
    ax.imshow(np.clip(image, 0.0, 1.0))
    ax.set_title(title, fontsize=48)
    # Remove the ticks from the axis
    ax.set_xticks([])
    ax.set_yticks([])