-
Notifications
You must be signed in to change notification settings - Fork 16
/
deepvonet.py
76 lines (70 loc) · 2.99 KB
/
deepvonet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd import Variable
class DeepVONet(nn.Module):
    """Recurrent CNN: a convolutional encoder followed by two stacked LSTM cells.

    Nine convolutions (conv1 .. conv6, ReLU after all but conv6) encode the
    input. The flattened feature map feeds ``lstm1`` (hidden size 100), whose
    hidden state feeds ``lstm2`` (hidden size 100). A linear layer maps
    ``lstm2``'s hidden state to 6 outputs.

    The recurrent state (``hx1``, ``cx1``, ``hx2``, ``cx2``) is kept on the
    module between ``forward`` calls and must be managed with
    ``reset_hidden_states``.

    NOTE(review): the 6 input channels are presumably two RGB frames stacked
    along the channel axis, and the 6 outputs presumably a relative 6-DoF pose.
    Confirm against the data loader / loss.

    Args:
        image_height: Input height in pixels. The default (384), together with
            the default width, yields the original fixed LSTM input size of
            ``20 * 6 * 1024``.
        image_width: Input width in pixels (default 1280).
    """

    def __init__(self, image_height=384, image_width=1280):
        super().__init__()
        # Registration order and attribute names are kept as before so that
        # existing state_dict checkpoints still load.
        self.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
        self.relu3 = nn.ReLU(inplace=True)
        self.conv3_1 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu4 = nn.ReLU(inplace=True)
        self.conv4_1 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu5 = nn.ReLU(inplace=True)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv6 = nn.Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        # Size of the flattened conv6 output; replaces the hard-coded
        # 20 * 6 * 1024, which only matched 384x1280 inputs.
        flat_features = (
            self.conv6.out_channels
            * self._conv_output_size(image_height, 0)
            * self._conv_output_size(image_width, 1)
        )
        self.lstm1 = nn.LSTMCell(flat_features, 100)
        self.lstm2 = nn.LSTMCell(100, 100)
        self.fc = nn.Linear(in_features=100, out_features=6)
        self.reset_hidden_states()

    def _conv_stages(self):
        """Return the (conv, relu) pairs applied before the final conv6."""
        return (
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv3, self.relu3),
            (self.conv3_1, self.relu3_1),
            (self.conv4, self.relu4),
            (self.conv4_1, self.relu4_1),
            (self.conv5, self.relu5),
            (self.conv5_1, self.relu5_1),
        )

    def _conv_output_size(self, size, dim):
        """Return the spatial extent after conv1 .. conv6 along axis ``dim``.

        ``dim`` is 0 for height and 1 for width. Each step uses the standard
        Conv2d output-size formula with the layer's own padding, dilation,
        kernel size and stride.
        """
        convs = [conv for conv, _ in self._conv_stages()] + [self.conv6]
        for conv in convs:
            size = (
                size
                + 2 * conv.padding[dim]
                - conv.dilation[dim] * (conv.kernel_size[dim] - 1)
                - 1
            ) // conv.stride[dim] + 1
        return size

    def reset_hidden_states(self, size=1, zero=True):
        """(Re)initialise the recurrent state of both LSTM cells.

        Args:
            size: Batch size of the fresh zero state. Only used when ``zero``
                is true.
            zero: If true, replace the state with zeros of shape
                ``(size, hidden_size)``. Otherwise keep the current values but
                detach them from the autograd graph, so gradients stop here
                (e.g. for truncated back-propagation through time).

        The state is always placed on the same device and dtype as the model
        parameters. Call this again after ``.to()``, ``.cuda()`` or
        ``.double()`` on the model.
        """
        reference = next(self.parameters())
        placement = {"device": reference.device, "dtype": reference.dtype}
        if zero:
            self.hx1 = torch.zeros(size, self.lstm1.hidden_size, **placement)
            self.cx1 = torch.zeros(size, self.lstm1.hidden_size, **placement)
            self.hx2 = torch.zeros(size, self.lstm2.hidden_size, **placement)
            self.cx2 = torch.zeros(size, self.lstm2.hidden_size, **placement)
        else:
            self.hx1 = self.hx1.detach().to(**placement)
            self.cx1 = self.cx1.detach().to(**placement)
            self.hx2 = self.hx2.detach().to(**placement)
            self.cx2 = self.cx2.detach().to(**placement)

    def forward(self, x):
        """Run one recurrent step and update the stored LSTM state.

        Args:
            x: Input batch with 6 channels. Its spatial size must match the
                ``image_height``/``image_width`` given at construction, and its
                batch size must match the batch size of the current hidden
                state (see ``reset_hidden_states``).

        Returns:
            Tensor of shape ``(batch, 6)``.
        """
        for conv, relu in self._conv_stages():
            x = relu(conv(x))
        x = self.conv6(x)  # conv6 has no activation, as in the original
        x = x.view(x.size(0), self.lstm1.input_size)
        self.hx1, self.cx1 = self.lstm1(x, (self.hx1, self.cx1))
        self.hx2, self.cx2 = self.lstm2(self.hx1, (self.hx2, self.cx2))
        return self.fc(self.hx2)