net.py (forked from jiaqi-jiang/GLOnet_for_thin_film)
import torch
import torch.nn as nn
import torch.nn.functional as F


class Generator(nn.Module):
    """Fully connected generator: maps a noise vector to per-layer thicknesses
    and a relaxed (softmax) material choice for each layer."""
    def __init__(self, params):
        super().__init__()
        self.noise_dim = params.noise_dim
        self.thickness_sup = params.thickness_sup
        self.N_layers = params.N_layers
        self.M_materials = params.M_materials
        # 1 x 1 x number of materials x number of frequencies
        self.n_database = params.n_database.view(1, 1, params.M_materials, -1).cuda()

        self.FC = nn.Sequential(
            nn.Linear(self.noise_dim, self.N_layers * (self.M_materials + 1)),
            nn.BatchNorm1d(self.N_layers * (self.M_materials + 1))
        )

    def forward(self, noise, alpha):
        net = self.FC(noise)
        net = net.view(-1, self.N_layers, self.M_materials + 1)
        # First channel -> thickness in (0, thickness_sup); remaining channels -> material logits.
        thicknesses = torch.sigmoid(net[:, :, 0]) * self.thickness_sup
        X = net[:, :, 1:]
        # Temperature-scaled softmax over materials: batch size x number of layers x number of materials x 1
        P = F.softmax(X * alpha, dim=2).unsqueeze(-1)
        # Weighted average over the refractive-index database: batch size x number of layers x number of frequencies
        refractive_indices = torch.sum(P * self.n_database, dim=2)
        # Squeeze only the trailing singleton dim so a batch of size 1 keeps its batch dimension.
        return (thicknesses, refractive_indices, P.squeeze(-1))


class ResBlock(nn.Module):
    """Residual block: bottleneck MLP (dim -> 2*dim -> dim) added back to the input."""
    def __init__(self, dim=16):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim * 2, bias=False),
            nn.BatchNorm1d(dim * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(dim * 2, dim, bias=False),
            nn.BatchNorm1d(dim))

    def forward(self, x):
        return F.leaky_relu(self.block(x) + x, 0.2)


'''
class ResBlock(nn.Module):
    """docstring for ResBlock"""
    def __init__(self, dim=64):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.LeakyReLU(0.2))

    def forward(self, x):
        return x + self.block(x)
'''


class ResGenerator(nn.Module):
    """Generator with a residual MLP trunk and a separate head that refines
    the per-layer thickness predictions."""
    def __init__(self, params):
        super().__init__()
        self.noise_dim = params.noise_dim
        self.res_layers = params.res_layers
        self.res_dim = params.res_dim
        self.thickness_sup = params.thickness_sup
        self.N_layers = params.N_layers
        self.M_materials = params.M_materials
        # 1 x 1 x number of materials x number of frequencies
        self.n_database = params.n_database.view(1, 1, params.M_materials, -1).cuda()

        self.initBLOCK = nn.Sequential(
            nn.Linear(self.noise_dim, self.res_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(p=0.2)
        )
        self.endBLOCK = nn.Sequential(
            nn.Linear(self.res_dim, self.N_layers * (self.M_materials + 1), bias=False),
            nn.BatchNorm1d(self.N_layers * (self.M_materials + 1)),
        )
        self.ResBLOCK = nn.ModuleList()
        for i in range(params.res_layers):
            self.ResBLOCK.append(ResBlock(self.res_dim))

        self.FC_thickness = nn.Sequential(
            nn.Linear(self.N_layers, 16),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(16),
            nn.Linear(16, self.N_layers),
        )

    def forward(self, noise, alpha):
        net = self.initBLOCK(noise)
        for i in range(self.res_layers):
            # Feed the output of each residual block into the next one
            # (the original line discarded the block's output).
            net = self.ResBLOCK[i](net)
        net = self.endBLOCK(net)
        net = net.view(-1, self.N_layers, self.M_materials + 1)
        # Thickness head refines the first channel before squashing into (0, thickness_sup).
        thicknesses = torch.sigmoid(self.FC_thickness(net[:, :, 0])) * self.thickness_sup
        X = net[:, :, 1:]
        # Temperature-scaled softmax over materials: batch size x number of layers x number of materials x 1
        P = F.softmax(X * alpha, dim=2).unsqueeze(-1)
        # Weighted average over the refractive-index database: batch size x number of layers x number of frequencies
        refractive_indices = torch.sum(P * self.n_database, dim=2)
        return (thicknesses, refractive_indices, P.squeeze(-1))
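

if __name__ == "__main__":
    # Hedged usage sketch (not part of the original file): exercises both generators
    # with made-up hyperparameters to show the expected input/output shapes. The real
    # project builds `params` from its own config; `SimpleNamespace` and every value
    # below are assumptions for illustration only, and the demo runs only on a CUDA
    # machine because the constructors call .cuda() on the refractive-index database.
    from types import SimpleNamespace

    if torch.cuda.is_available():
        params = SimpleNamespace(
            noise_dim=32,          # assumed latent size
            N_layers=5,            # assumed number of thin-film layers
            M_materials=3,         # assumed size of the material library
            thickness_sup=400.0,   # assumed maximum thickness (same units as training data)
            res_layers=2,          # assumed depth of the residual trunk (ResGenerator only)
            res_dim=64,            # assumed width of the residual trunk (ResGenerator only)
            # dummy refractive-index database: M_materials x number of frequencies
            n_database=torch.rand(3, 10),
        )

        for Model in (Generator, ResGenerator):
            model = Model(params).cuda()
            noise = torch.randn(4, params.noise_dim).cuda()  # batch of 4 latent vectors
            thicknesses, refractive_indices, P = model(noise, alpha=1.0)
            # Expected shapes: (4, 5), (4, 5, 10), (4, 5, 3)
            print(Model.__name__, thicknesses.shape, refractive_indices.shape, P.shape)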