-
Notifications
You must be signed in to change notification settings - Fork 33
/
mxresnet.py
116 lines (93 loc) · 4.07 KB
/
mxresnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# MXResNet: fastai's XResNet modified to use the Mish activation function.
# Based on https://github.com/fastai/fastai/blob/master/fastai/vision/models/xresnet.py
# Modified by lessw2020 - GitHub: https://github.com/lessw2020/mish
import torch.nn as nn
import torch,math,sys
import torch.utils.model_zoo as model_zoo
from functools import partial
#from ...torch_core import Module
from fastai.torch_core import Module
import torch.nn.functional as F #(uncomment if needed,but you likely already have it)
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``, applied element-wise."""

    def __init__(self):
        super().__init__()
        # Announces every instantiation on stdout.
        print("Mish activation loaded...")

    def forward(self, x):
        """Return ``x`` scaled by ``tanh(softplus(x))``."""
        gate = torch.tanh(F.softplus(x))
        return x * gate
# Single activation instance used by conv_layer and ResBlock in this module.
# Alternatives noted by the original author: nn.ReLU(inplace=True),
# or ELU + init (a=0.54; gain=1.55).
# The assignment used to appear twice (copy-paste), which built a second
# Mish instance at import time for no benefit; one assignment is enough.
act_fn = Mish()

__all__ = ['MXResNet', 'mxresnet18', 'mxresnet34', 'mxresnet50', 'mxresnet101', 'mxresnet152']
class Flatten(Module):
    """Reshape the input to 2-D, keeping dimension 0 and merging the rest."""

    def forward(self, x):
        leading = x.size(0)
        return x.view(leading, -1)
def init_cnn(m):
    """Recursively initialise ``m`` and all of its descendants in place.

    Any module exposing a non-None ``bias`` gets that bias set to zero, and
    every ``nn.Conv2d`` / ``nn.Linear`` weight is redrawn with
    ``nn.init.kaiming_normal_``. Other weights are left untouched.
    """
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    # Depth-first walk: each child is handled the same way as its parent.
    for child in m.children():
        init_cnn(child)
def conv(ni, nf, ks=3, stride=1, bias=False):
    """Build an ``nn.Conv2d`` from ``ni`` to ``nf`` channels.

    Padding is fixed at ``ks // 2``; the bias term is off unless requested.
    """
    half_kernel = ks // 2
    return nn.Conv2d(
        in_channels=ni,
        out_channels=nf,
        kernel_size=ks,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )
def noop(x):
    """Identity placeholder: return the argument unchanged."""
    return x
def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):
    """Return ``nn.Sequential(conv, batch norm[, act_fn])``.

    The batch-norm weight starts at 0 when ``zero_bn`` is set and at 1
    otherwise. The module-level ``act_fn`` is appended only when ``act``
    is true.
    """
    norm = nn.BatchNorm2d(nf)
    initial_scale = 0. if zero_bn else 1.
    nn.init.constant_(norm.weight, initial_scale)
    convolution = conv(ni, nf, ks, stride=stride)
    tail = (act_fn,) if act else ()
    return nn.Sequential(convolution, norm, *tail)
class ResBlock(Module):
    """Residual block: conv path plus shortcut, followed by ``act_fn``.

    With ``expansion == 1`` the conv path is two 3x3 conv layers; otherwise
    it is a 1x1 -> 3x3 -> 1x1 bottleneck. In both cases the last layer has no
    activation and its batch-norm weight is initialised to zero.
    """
    # NOTE(review): no explicit super().__init__() call here; presumably the
    # fastai ``Module`` base takes care of it -- confirm against fastai.

    def __init__(self, expansion, ni, nh, stride=1):
        # Input and output widths are both scaled by the expansion factor.
        nf = nh * expansion
        ni = ni * expansion
        if expansion == 1:
            path = [
                conv_layer(ni, nh, 3, stride=stride),
                conv_layer(nh, nf, 3, zero_bn=True, act=False),
            ]
        else:
            path = [
                conv_layer(ni, nh, 1),
                conv_layer(nh, nh, 3, stride=stride),
                conv_layer(nh, nf, 1, zero_bn=True, act=False),
            ]
        self.convs = nn.Sequential(*path)
        # TODO: check whether act=True works better
        # Shortcut needs a 1x1 projection only when the channel counts differ.
        if ni == nf:
            self.idconv = noop
        else:
            self.idconv = conv_layer(ni, nf, 1, act=False)
        # Shortcut is average-pooled only when the conv path is strided.
        if stride == 1:
            self.pool = noop
        else:
            self.pool = nn.AvgPool2d(2, ceil_mode=True)

    def forward(self, x):
        """Add the conv path to the shortcut and apply the activation."""
        residual = self.convs(x)
        shortcut = self.idconv(self.pool(x))
        return act_fn(residual + shortcut)
def filt_sz(recep):
    """Return the largest power of two not above ``0.75 * recep``, capped at 64."""
    exponent = math.floor(math.log2(recep * 0.75))
    width = 2 ** exponent
    return width if width < 64 else 64
class MXResNet(nn.Sequential):
    """XResNet-style network assembled as a flat ``nn.Sequential``.

    Order of sub-modules: three stem conv layers (the first with stride 2),
    a max-pool, one stage of ``ResBlock``s per entry of ``layers``, adaptive
    average pooling, ``Flatten`` and a final ``nn.Linear`` to ``c_out``.
    ``init_cnn`` is run over the whole model once it is built.
    """

    def __init__(self, expansion, layers, c_in=3, c_out=1000):
        stem_widths = [c_in, 32, 64, 64]  # modified per Grankin
        stem = [
            conv_layer(n_in, n_out, stride=2 if idx == 0 else 1)
            for idx, (n_in, n_out) in enumerate(zip(stem_widths, stem_widths[1:]))
        ]

        block_szs = [64 // expansion, 64, 128, 256, 512]
        stages = []
        for idx, n_blocks in enumerate(layers):
            # Only the first stage keeps stride 1; the others use stride 2.
            stage_stride = 1 if idx == 0 else 2
            stages.append(
                self._make_layer(expansion, block_szs[idx], block_szs[idx + 1],
                                 n_blocks, stage_stride)
            )

        head_in = block_szs[-1] * expansion
        super().__init__(
            *stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *stages,
            nn.AdaptiveAvgPool2d(1), Flatten(),
            nn.Linear(head_in, c_out),
        )
        init_cnn(self)

    def _make_layer(self, expansion, ni, nf, blocks, stride):
        """Stack ``blocks`` ResBlocks; only the first takes ``ni`` and ``stride``."""
        stage = []
        for idx in range(blocks):
            is_first = idx == 0
            stage.append(
                ResBlock(expansion, ni if is_first else nf, nf, stride if is_first else 1)
            )
        return nn.Sequential(*stage)
def mxresnet(expansion, n_layers, name, pretrained=False, **kwargs):
    """Construct an ``MXResNet``; extra keyword arguments go to its constructor.

    ``name`` is accepted but currently unused. Requesting ``pretrained``
    only prints a notice, because no weights are loaded here.
    """
    net = MXResNet(expansion, n_layers, **kwargs)
    if not pretrained:
        return net
    # Weight loading is disabled:
    # model.load_state_dict(model_zoo.load_url(model_urls[name]))
    print("No pretrained yet for MXResNet")
    return net
# Publish mxresnet18 ... mxresnet152 as module attributes, each a partial of
# ``mxresnet`` with its expansion factor and per-stage block counts bound.
me = sys.modules[__name__]
_VARIANTS = {
    # depth: (expansion, blocks per stage)
    18: (1, [2, 2, 2, 2]),
    34: (1, [3, 4, 6, 3]),
    50: (4, [3, 4, 6, 3]),
    101: (4, [3, 4, 23, 3]),
    152: (4, [3, 8, 36, 3]),
}
for n, (e, l) in _VARIANTS.items():
    name = f'mxresnet{n}'
    setattr(me, name, partial(mxresnet, expansion=e, n_layers=l, name=name))