forked from MayankSingal/PyTorch-Image-Dehazing
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnet.py
50 lines (25 loc) · 877 Bytes
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn
import math
class dehaze_net(nn.Module):
    """Small dehazing CNN that estimates a per-pixel map K(x) and reconstructs
    the clean image as ``J(x) = relu(K(x) * I(x) - K(x) + b)``.

    Five stride-1 convolutions with "same" padding are chained, and their
    3-channel outputs are concatenated along the channel axis at several
    points before the final K(x) estimate. The layout appears to follow the
    AOD-Net formulation; confirm against the upstream project if exact parity
    matters.

    Args:
        b: Constant bias term in the reconstruction formula. The default
            ``1.0`` reproduces the original hard-coded ``+ 1``. It is stored as
            a plain attribute (not a parameter or buffer), so it does not
            appear in ``state_dict()`` and existing checkpoints load unchanged.
    """

    def __init__(self, b: float = 1.0):
        super().__init__()
        self.b = b
        # Shared activation. inplace=True is safe because every call below
        # receives a freshly computed intermediate that is not reused after.
        self.relu = nn.ReLU(inplace=True)
        # Conv2d(in_channels, out_channels, kernel_size, stride, padding).
        # Stride 1 with padding (kernel_size - 1) // 2 preserves H and W, which
        # is what allows the channel-axis concatenations in forward().
        self.e_conv1 = nn.Conv2d(3, 3, 1, 1, 0, bias=True)
        self.e_conv2 = nn.Conv2d(3, 3, 3, 1, 1, bias=True)
        self.e_conv3 = nn.Conv2d(6, 3, 5, 1, 2, bias=True)   # input: x1 ++ x2
        self.e_conv4 = nn.Conv2d(6, 3, 7, 1, 3, bias=True)   # input: x2 ++ x3
        self.e_conv5 = nn.Conv2d(12, 3, 3, 1, 1, bias=True)  # input: x1..x4

    def forward(self, x):
        """Dehaze a batch of 3-channel images.

        Args:
            x: Hazy input with 3 channels. Presumably shaped (N, 3, H, W) with
                values in [0, 1]; TODO confirm against the data loader.

        Returns:
            Tensor with the same shape as ``x`` equal to
            ``relu(K * x - K + self.b)``. Values are clamped below at 0 but
            not above.
        """
        x1 = self.relu(self.e_conv1(x))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(torch.cat((x1, x2), 1)))
        x4 = self.relu(self.e_conv4(torch.cat((x2, x3), 1)))
        # K(x): fuse all four intermediate 3-channel maps (12 channels in).
        k = self.relu(self.e_conv5(torch.cat((x1, x2, x3, x4), 1)))
        return self.relu(k * x - k + self.b)