# nn.py
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F

class Conv2d(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None, bias=True):
        super(Conv2d, self).__init__()
        self.activation = activation
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation, bias=bias)

    def forward(self, x):
        h = self.conv(x)
        if self.activation is None:
            out = h
        else:
            out = self.activation(h)
        return out
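
# Usage sketch (not part of the original file; channel counts and shapes are illustrative only):
#   conv = Conv2d(3, 16, kernel_size=3, stride=1, padding=1, activation=F.relu)
#   y = conv(torch.randn(4, 3, 32, 32))   # -> shape (4, 16, 32, 32)
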
# =======================================================================================================================
class Conv2dBN(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None):
        super(Conv2dBN, self).__init__()
        self.activation = activation
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation)
        self.bn = nn.BatchNorm2d(output_channels)

    def forward(self, x):
        h = self.conv(x)
        h = self.bn(h)
        if self.activation is None:
            out = h
        else:
            out = self.activation(h)
        return out
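
# Usage sketch (not part of the original file; hyperparameters are illustrative only):
#   conv_bn = Conv2dBN(3, 16, kernel_size=3, stride=2, padding=1, activation=F.relu)
#   y = conv_bn(torch.randn(4, 3, 32, 32))   # -> shape (4, 16, 16, 16)
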
class GatedConv2d(nn.Module):
    """Gated convolution: the output of one convolution is modulated element-wise by a sigmoid gate
    computed by a second convolution with the same hyperparameters."""
    def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None):
        super(GatedConv2d, self).__init__()
        self.activation = activation
        self.sigmoid = nn.Sigmoid()
        self.h = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation)
        self.g = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation)

    def forward(self, x):
        if self.activation is None:
            h = self.h(x)
        else:
            h = self.activation(self.h(x))
        g = self.sigmoid(self.g(x))
        return h * g
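
# Usage sketch (not part of the original file): the gate has the same shape as the feature map,
# so the output shape matches a plain Conv2d with the same hyperparameters.
#   gated = GatedConv2d(3, 32, kernel_size=3, stride=1, padding=1, activation=F.relu)
#   y = gated(torch.randn(8, 3, 28, 28))   # -> shape (8, 32, 28, 28)
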
class GatedDense(nn.Module):
    """Gated linear layer: a linear transform modulated element-wise by a sigmoid gate."""
    def __init__(self, input_size, output_size, activation=None):
        super(GatedDense, self).__init__()
        self.activation = activation
        self.sigmoid = nn.Sigmoid()
        self.h = nn.Linear(input_size, output_size)
        self.g = nn.Linear(input_size, output_size)

    def forward(self, x):
        h = self.h(x)
        if self.activation is not None:
            h = self.activation(h)  # reuse the projection computed above instead of running self.h(x) a second time
        g = self.sigmoid(self.g(x))
        return h * g
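
# Usage sketch (not part of the original file; sizes are illustrative only):
#   dense = GatedDense(784, 300, activation=F.relu)
#   y = dense(torch.randn(8, 784))   # -> shape (8, 300)
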
class FullyConnected(nn.Sequential):
    """
    Fully connected multi-layer network with ReLU activations.
    """
    def __init__(self, sizes, layer_activation=nn.ReLU(), final_activation=None, batch_normalization=False):
        layers = []
        for in_size, out_size in zip(sizes, sizes[1:]):
            layers.append(nn.Linear(in_size, out_size))
            if batch_normalization:
                layers.append(nn.BatchNorm1d(out_size))
            layers.append(layer_activation)
        # drop the activation after the last linear layer; final_activation (if any) is appended instead
        layers.pop(-1)
        if final_activation is not None:
            layers.append(final_activation)
        super().__init__(*layers)

    def append(self, layer):
        assert isinstance(layer, nn.Module)
        self.add_module(str(len(self)), layer)
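
# ----------------------------------------------------------------------------------------------------------------------
# Minimal smoke test, run only when the module is executed directly (not part of the original file;
# all layer sizes below are illustrative assumptions, not values used by the repository).
if __name__ == "__main__":
    x = torch.randn(4, 1, 28, 28)

    # gated convolution followed by a gated dense layer on the flattened features
    feat = GatedConv2d(1, 8, kernel_size=3, stride=2, padding=1, activation=F.relu)(x)  # (4, 8, 14, 14)
    flat = feat.view(feat.size(0), -1)                                                  # (4, 8 * 14 * 14)
    hidden = GatedDense(8 * 14 * 14, 64)(flat)                                          # (4, 64)

    # MLP head: `sizes` lists layer widths; the activation after the last Linear is removed in __init__
    mlp = FullyConnected([64, 32, 10], final_activation=nn.LogSoftmax(dim=1), batch_normalization=True)
    out = mlp(hidden)                                                                   # (4, 10)
    print(feat.shape, hidden.shape, out.shape)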