forked from buaabai/Ternary-Weights-Network
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
105 lines (92 loc) · 3.39 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 28 13:34:27 2018
@author: bai
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def Ternarize(tensor):
output = torch.zeros(tensor.size())
delta = Delta(tensor)
alpha = Alpha(tensor,delta)
for i in range(tensor.size()[0]):
for w in tensor[i].view(1,-1):
pos_one = (w > delta[i]).type(torch.FloatTensor)
neg_one = -1 * (w < -delta[i]).type(torch.FloatTensor)
out = torch.add(pos_one,neg_one).view(tensor.size()[1:])
output[i] = torch.add(output[i],torch.mul(out,alpha[i]))
return output
def Alpha(tensor,delta):
Alpha = []
for i in range(tensor.size()[0]):
count = 0
abssum = 0
absvalue = tensor[i].view(1,-1).abs()
for w in absvalue:
truth_value = w > delta[i] #print to see
count = truth_value.sum()
abssum = torch.matmul(absvalue,truth_value.type(torch.FloatTensor).view(-1,1))
Alpha.append(abssum/count)
alpha = Alpha[0]
for i in range(len(Alpha) - 1):
alpha = torch.cat((alpha,Alpha[i+1]))
return alpha
def Delta(tensor):
n = tensor[0].nelement()
if(len(tensor.size()) == 4): #convolution layer
delta = 0.7 * tensor.norm(1,3).sum(2).sum(1).div(n)
elif(len(tensor.size()) == 2): #fc layer
delta = 0.7 * tensor.norm(1,1).div(n)
return delta
class TernaryLinear(nn.Linear):
def __init__(self,*args,**kwargs):
super(TernaryLinear,self).__init__(*args,**kwargs)
def forward(self,input):
self.weight.data = Ternarize(self.weight.data)
out = F.linear(input,self.weight,self.bias)
return out
class TernaryConv2d(nn.Conv2d):
def __init__(self,*args,**kwargs):
super(TernaryConv2d,self).__init__(*args,**kwargs)
def forward(self,input):
self.weight.data = Ternarize(self.weight.data)
out = F.conv2d(input, self.weight, self.bias, self.stride,self.padding, self.dilation, self.groups)
return out
class LeNet5_T(nn.Module):
def __init__(self):
super(LeNet5_T,self).__init__()
self.conv1 = TernaryConv2d(1,32,kernel_size = 5)
self.bn_conv1 = nn.BatchNorm2d(32)
self.conv2 = TernaryConv2d(32,64,kernel_size = 5)
self.bn_conv2 = nn.BatchNorm2d(64)
self.fc1 = TernaryLinear(1024,512)
self.fc2 = TernaryLinear(512,10)
def forward(self,x):
x = self.conv1(x)
x = F.relu(F.max_pool2d(self.bn_conv1(x),2))
x = self.conv2(x)
x = F.relu(F.max_pool2d(self.bn_conv2(x),2))
x = x.view(-1,1024)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
class LeNet5(nn.Module):
def __init__(self):
super(LeNet5,self).__init__()
self.conv1 = nn.Conv2d(1,32,kernel_size = 5)
self.bn_conv1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32,64,kernel_size = 5)
self.bn_conv2 = nn.BatchNorm2d(64)
self.fc1 = nn.Linear(1024,512)
self.fc2 = nn.Linear(512,10)
def forward(self,x):
x = self.conv1(x)
x = F.relu(F.max_pool2d(self.bn_conv1(x),2))
x = self.conv2(x)
x = F.relu(F.max_pool2d(self.bn_conv2(x),2))
x = x.view(-1,1024)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x