models.py
import tensorflow as tf

import dynamic_fixed_point as dfxp


class Model:
    """Base class for models assembled from dfxp (dynamic fixed-point) layers."""

    def __init__(self, bits, dropout, weight_decay, stochastic, training):
        self.bits = bits
        self.dropout = dropout
        self.weight_decay = weight_decay
        self.stochastic = stochastic
        self.training = training
        self.layers = self.get_layers()

    def get_layers(self):
        # Subclasses override this to return the model's layer stack.
        return []

    def forward(self, X):
        # Run the input through each layer in order; the final activations
        # are the logits, and the prediction is their argmax.
        for layer in self.layers:
            X = layer.forward(X)
        logit = X
        pred = tf.argmax(logit, axis=1, output_type=tf.int32)
        return logit, pred

    def grads_and_vars(self):
        # Collect the (gradient, variable) pairs from every layer.
        res = []
        for layer in self.layers:
            res += layer.grads_and_vars()
        return res

    def backward(self, grad):
        # Propagate the gradient backwards through the layer stack;
        # self.stochastic selects the layers' rounding mode.
        for layer in reversed(self.layers):
            grad = layer.backward(grad, self.stochastic)
        return grad

    def info(self):
        # One summary line per layer.
        return '\n'.join([layer.info() for layer in self.layers])


class MNIST_Model(Model):
    """A small convolutional network for MNIST built from dfxp layers.

    The `dropout` argument is stored by the base class, but no dropout
    layer is used in this stack.
    """

    def __init__(self, bits, dropout=0.5, weight_decay=0, stochastic=False, training=False):
        super().__init__(bits, dropout, weight_decay, stochastic, training)

    def get_layers(self):
        return [
            dfxp.Conv2d_q(
                name='conv',
                bits=self.bits,
                training=self.training,
                ksize=[5, 5, 1, 20],
                strides=[1, 1, 1, 1],
                padding='VALID',
                weight_decay=self.weight_decay,
            ),
            dfxp.BatchNorm_q(
                name='batch_normalization',
                bits=self.bits,
                num_features=20,
                training=self.training,
                weight_decay=self.weight_decay,
            ),
            dfxp.ReLU_q(),
            dfxp.MaxPool_q(
                ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1],
                padding='VALID'
            ),
            # 28x28 input -> 24x24 after the 5x5 VALID conv -> 12x12 after pooling.
            dfxp.Flatten_q(12*12*20),
            dfxp.Dense_q(
                name='dense1',
                bits=self.bits,
                training=self.training,
                in_units=12*12*20,
                units=100,
                weight_decay=self.weight_decay,
            ),
            dfxp.Dense_q(
                name='dense2',
                bits=self.bits,
                training=self.training,
                in_units=100,
                units=10,
                weight_decay=self.weight_decay,
            )
        ]
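

# Usage sketch (an illustrative addition, not part of the original file):
# a minimal example of the intended call pattern, assuming the dfxp layers
# build TensorFlow ops as the methods above suggest. `images` and
# `loss_grad` are hypothetical tensors created elsewhere (the input batch
# and the gradient of the loss with respect to the logits, respectively).
#
#     model = MNIST_Model(bits=8, weight_decay=1e-4, training=True)
#     logit, pred = model.forward(images)      # images: [N, 28, 28, 1]
#     grad_in = model.backward(loss_grad)      # loss_grad has logit's shape
#     pairs = model.grads_and_vars()           # pass to an optimizer
#     print(model.info())                      # per-layer summary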