learner.py
import torch
import torch.nn as nn
from torch.nn import functional as F

class Learner(nn.Module):
    def __init__(self, input_dim=2048, drop_p=0.6):
        super(Learner, self).__init__()
        # Three-layer scoring head: input_dim -> 512 -> 32 -> 1, with a
        # sigmoid output. drop_p is used for every dropout layer.
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(512, 32),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
        self.drop_p = drop_p
        self.weight_init()
        # Keep the weights in a ParameterList so forward() can also run
        # with an externally supplied set of parameters.
        self.vars = nn.ParameterList()
        for param in self.classifier.parameters():
            self.vars.append(param)

    def weight_init(self):
        # Xavier-initialize every linear layer in the classifier.
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)

    def forward(self, x, vars=None):
        # Functional forward pass; defaults to the module's own weights but
        # accepts an external list so that updated copies of the parameters
        # can be evaluated without mutating the module.
        if vars is None:
            vars = self.vars
        x = F.linear(x, vars[0], vars[1])
        x = F.relu(x)
        x = F.dropout(x, self.drop_p, training=self.training)
        x = F.linear(x, vars[2], vars[3])
        x = F.relu(x)  # mirrors the ReLU declared in self.classifier
        x = F.dropout(x, self.drop_p, training=self.training)
        x = F.linear(x, vars[4], vars[5])
        return torch.sigmoid(x)

    def parameters(self):
        """
        Override nn.Module.parameters(), which returns a generator, so that
        the weights come back as an indexable ParameterList instead.
        :return: the model's parameters as an nn.ParameterList
        """
        return self.vars
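

# A minimal usage sketch (illustrative, not part of the original file): the
# batch size below is an assumption, and fast_weights is a hypothetical name
# for a cloned weight list.
if __name__ == "__main__":
    model = Learner(input_dim=2048, drop_p=0.6)
    feats = torch.randn(4, 2048)   # batch of 4 feature vectors
    scores = model(feats)          # shape (4, 1), values in (0, 1)
    print(scores.shape)

    # Because forward() accepts an explicit weight list, an updated copy of
    # the parameters can be evaluated without touching the module itself:
    fast_weights = [p.clone() for p in model.parameters()]
    scores_fast = model(feats, vars=fast_weights)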