-
Notifications
You must be signed in to change notification settings - Fork 0
/
qjointvae.py
316 lines (251 loc) · 13.1 KB
/
qjointvae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import os
import ad.constants
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras.layers import *
from qkeras import *
from qkeras.quantizers import *
from tensorflow.keras.regularizers import *
from tensorflow.keras.layers import GlobalAveragePooling2D
from typing import List
from ad import utils
from ad import metrics
from ad import layers
class QJointVAE(keras.Model):
    """Quantized joint VAE with a continuous and a categorical latent block.

    The model is assembled from three parts created in ``__init__``:

    * ``self.encoder``  - built by :meth:`build_encoder`; returns the
      categorical logits ``q`` and two continuous-latent tensors
      (``z_mean`` and ``z_var``).
    * ``self.sampling`` - ``layers.joint_sampling``; called with the list
      ``[mean, var, q]`` and expected to return the latent sample ``z``.
    * ``self.decoder``  - built by :meth:`build_decoder`; maps ``z`` to a
      reconstruction with a sigmoid output.

    The training objective is
    ``alpha * BCE + beta * KL_continuous + beta * KL_categorical``.
    The second continuous encoder output is treated as a *log-variance* by
    the KL term, even though the layer is named ``z_var``.
    """

    def __init__(self,
                 continous_latent: int = 32,
                 discrete_latent: int = 16,
                 temperature: float = 50.,
                 alpha: float = 1.,
                 beta: float = 3e3,
                 eps_kl: float = 1e-7,
                 name=None,
                 **kwargs):
        """Create the encoder, sampling layer, decoder and metric trackers.

        Args:
            continous_latent: Number of continuous latent units.
            discrete_latent: Number of categorical latent units (classes).
            temperature: Forwarded as ``temp`` to ``layers.joint_sampling``.
            alpha: Weight applied to the reconstruction loss.
            beta: Weight applied to both KL terms.
            eps_kl: Small constant added inside the logarithms of the
                categorical KL term to avoid ``log(0)``.
            name: Optional Keras model name.
            **kwargs: May contain ``encoder`` and ``decoder`` dicts, which are
                forwarded to :meth:`build_encoder` / :meth:`build_decoder`.
                Everything else is forwarded to ``layers.joint_sampling``.
        """
        super().__init__(name=name)

        # Hyper-parameters.
        self.continous_latent = continous_latent
        self.temp = temperature  # gumbel-softmax temperature
        self.discrete_latent = discrete_latent
        self.alpha = alpha
        self.beta = beta
        self.eps_kl = eps_kl

        # Extract BOTH sub-network configs before forwarding the remaining
        # kwargs: otherwise the 'decoder' dict would still be present and be
        # passed to the sampling layer as an unexpected keyword argument.
        encoder_config = kwargs.pop('encoder', {})
        decoder_config = kwargs.pop('decoder', {})

        # Networks.
        self.encoder = self.build_encoder(**encoder_config)
        self.sampling = layers.joint_sampling(temp=temperature,
                                              name='joint_sampling', **kwargs)
        self.decoder = self.build_decoder(
            latent_shape=(self.continous_latent + self.discrete_latent,),
            **decoder_config)

        # Running means of the loss terms plus image-quality metrics.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_cont_loss_tracker = keras.metrics.Mean(name="kl_cont_loss")
        self.kl_disc_loss_tracker = keras.metrics.Mean(name="kl_disc_loss")
        self.mse_tracker = ad.metrics.MSE(name="mse")
        self.psnr_tracker = ad.metrics.PSNR(name='psnr')
        self.ssim_tracker = ad.metrics.SSIM(name='ssim')

    @property
    def metrics(self) -> List[tf.keras.metrics.Metric]:
        """Return every tracker owned by the model, in a fixed order."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_cont_loss_tracker,
            self.kl_disc_loss_tracker,
            self.mse_tracker,
            self.psnr_tracker,
            self.ssim_tracker,
        ]

    def call(self, x, **kwargs):
        """Encode ``x``, sample a latent code and return the reconstruction.

        ``**kwargs`` are forwarded to both the encoder and the decoder call.
        """
        q, mean, var = self.encoder(x, **kwargs)
        z = self.sampling([mean, var, q])
        return self.decoder(z, **kwargs)

    def build_encoder(self, input_shape: tuple, depths: List[int], filters: List[int],
                      kernel=3, groups=None,
                      qconv=quantized_bits(16, 6, alpha=1),
                      activation=quantized_relu(16, 6, negative_slope=0.25),
                      qdense=quantized_bits(16, 6, alpha=1),
                      **kwargs) -> tf.keras.Model:
        """Build the quantized residual encoder.

        For every entry ``j`` of ``depths`` a stride-2 convolution is followed
        by ``depths[j]`` residual blocks (two stride-1 convolutions each).
        A final 2-filter convolution is flattened and fed to three dense
        heads.

        Args:
            input_shape: Shape of one input image (without batch axis).
            depths: Number of residual blocks per stage.
            filters: Number of convolution filters per stage; must have the
                same length as ``depths``.
            kernel: Convolution kernel size.
            groups: ``groups`` argument of every convolution except the first
                down-sampling one (fixed to 1) and the final one (fixed to 2).
            qconv: Quantizer for convolution kernels and biases.
            activation: Quantized activation used after each normalization.
            qdense: Quantizer for dense kernels and biases.
            **kwargs: Extra keyword arguments forwarded to every ``QConv2D``.

        Returns:
            A Keras model named ``'Res-QEncoder'`` with outputs
            ``[z_categorical, z_mean, z_var]``.
        """
        assert len(depths) == len(filters), \
            "`depths` and `filters` must have the same length"

        def conv_norm(tensor, n_filters, strides, n_groups,
                      conv_name, norm_name, act_name=None):
            # Quantized conv -> instance norm -> (optional) quantized activation.
            tensor = QConv2D(filters=n_filters, kernel_size=kernel, strides=strides,
                             groups=n_groups,
                             padding='same',
                             kernel_quantizer=qconv,
                             bias_quantizer=qconv,
                             **kwargs, name=conv_name)(tensor)
            tensor = tfa.layers.InstanceNormalization(name=norm_name)(tensor)
            if act_name is not None:
                tensor = QActivation(activation, name=act_name)(tensor)
            return tensor

        images = Input(shape=input_shape, name='image')
        x = images

        for j, depth in enumerate(depths):
            # Down-sampling convolution; the very first one is never grouped.
            x = conv_norm(x, filters[j], 2, groups if j > 0 else 1,
                          f'dconv_b{j}', f'in_b{j}', f'activ_b{j}')

            # Residual blocks.
            for i in range(depth):
                r = x  # residual
                x = conv_norm(x, filters[j], 1, groups,
                              f'conv1_b{j}_{i}', f'in1_b{j}_{i}', f'activ1_b{j}_{i}')
                x = conv_norm(x, filters[j], 1, groups,
                              f'conv2_b{j}_{i}', f'in2_b{j}_{i}', f'activ2_b{j}_{i}')
                x = Add(name=f'add_b{j}_{i}')([x, r])

        # Channel reduction before the dense heads (no activation here).
        x = conv_norm(x, 2, 1, 2, 'conv_fin', 'in_fin')
        z = Flatten()(x)

        q = QDense(units=self.discrete_latent, kernel_quantizer=qdense,
                   bias_quantizer=qdense, name='z_categorical')(z)
        encoded_mean = QDense(units=self.continous_latent, kernel_quantizer=qdense,
                              bias_quantizer=qdense, use_bias=True, name='z_mean')(z)
        encoded_var = QDense(units=self.continous_latent, kernel_quantizer=qdense,
                             bias_quantizer=qdense, use_bias=True, name='z_var')(z)

        return tf.keras.Model(inputs=images, outputs=[q, encoded_mean, encoded_var],
                              name='Res-QEncoder')

    def build_decoder(self, latent_shape: tuple, depths: List[int], filters: List[int],
                      crop: tuple, activation=tf.nn.relu6, kernel=3, size=(5, 4, 256),
                      out_channels=1, groups=None, **kwargs) -> tf.keras.Model:
        """Build the residual decoder.

        Args:
            latent_shape: Shape of one latent sample (without batch axis).
            depths: Number of residual blocks per up-sampling stage.
            filters: Number of filters per stage; same length as ``depths``.
            crop: Positional arguments of the final ``CenterCrop`` layer.
            activation: Activation used by all hidden convolution layers.
            kernel: Convolution kernel size.
            size: ``(height, width, channels)`` of the first feature map.
            out_channels: Number of channels of the reconstruction.
            groups: ``groups`` argument of the residual-block convolutions.
            **kwargs: Extra keyword arguments forwarded to the ``ad.layers``
                convolution layers.

        Returns:
            A Keras model named ``'Res-Decoder'`` mapping ``z`` to the
            reconstruction (sigmoid output).
        """
        assert len(depths) == len(filters), \
            "`depths` and `filters` must have the same length"

        latents = Input(shape=latent_shape, name='z')

        if len(latent_shape) == 1:
            x = ad.layers.SpatialBroadcast(width=size[1], height=size[0],
                                           name='spatial-broadcast')(latents)
            # Restore the static shape. The "+ 2" presumably accounts for two
            # coordinate channels appended by SpatialBroadcast - TODO confirm.
            x.set_shape((None, size[0], size[1], latent_shape[0] + 2))
            x = ad.layers.ConvLayer(filters=size[-1], kernel=kernel, name='conv-expand',
                                    activation=activation, **kwargs)(x)
        else:
            # `Reshape` needs a tuple; the former `(5 * 4 * 256)` was a plain
            # int and raised a TypeError. Use `size` so the feature map is
            # spatial, as the convolution layers below require.
            x = Reshape(tuple(size))(latents)

        for j, depth in enumerate(depths):
            x = ad.layers.UpConvLayer(filters=filters[j], kernel=kernel, **kwargs,
                                      activation=activation, name=f'up_conv-b{j}')(x)

            # Residual blocks.
            for i in range(depth):
                r = x  # residual
                x = ad.layers.ConvLayer(filters=filters[j], kernel=kernel, **kwargs,
                                        groups=groups, activation=activation,
                                        name=f'conv1-b{j}_{i}')(x)
                x = ad.layers.ConvLayer(filters=filters[j], kernel=kernel, **kwargs,
                                        groups=groups, activation=activation,
                                        name=f'conv2-b{j}_{i}')(x)
                x = Add(name=f'add-b{j}_{i}')([x, r])

        # Reconstruction head.
        reco = CenterCrop(*crop, name='crop')(x)
        reco = Conv2D(filters=int(out_channels), kernel_size=kernel, padding='same',
                      activation=tf.nn.sigmoid, name='conv-reco')(reco)

        return tf.keras.Model(inputs=latents, outputs=reco, name='Res-Decoder')

    def _forward_and_losses(self, data, training: bool):
        """Run encoder -> sampling -> decoder and compute every loss term.

        Returns:
            ``(reconstruction, total_loss, reconstruction_loss,
            kl_cont_loss, kl_disc_loss)``.
        """
        q, z_mean, z_log_var = self.encoder(data, training=training)
        z = self.sampling([z_mean, z_log_var, q])
        reconstruction = self.decoder(z, training=training)

        # Reconstruction: binary cross-entropy summed over axes 1 and 2
        # (assumes image-shaped batches, i.e. spatial axes - TODO confirm),
        # then averaged over the batch.
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            ) * self.alpha
        )

        # Continuous KL: Gaussian posterior against a standard normal prior.
        kl_cont_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_cont_loss = tf.reduce_sum(kl_cont_loss, axis=1)
        kl_cont_loss = tf.reduce_mean(self.beta * kl_cont_loss, axis=0)

        # Categorical KL against a uniform prior over `discrete_latent` classes:
        # sum_k p_k * log(p_k) - sum_k p_k * log(1 / K).
        q_p = tf.nn.softmax(q, axis=-1)  # logits -> probabilities
        neg_entropy = q_p * tf.math.log(q_p + self.eps_kl)
        cross_entropy = q_p * tf.math.log(1. / self.discrete_latent + self.eps_kl)
        kl_disc_loss = tf.reduce_mean(
            tf.reduce_sum(neg_entropy - cross_entropy, axis=1) * self.beta, axis=0)

        total_loss = reconstruction_loss + kl_cont_loss + kl_disc_loss
        return reconstruction, total_loss, reconstruction_loss, kl_cont_loss, kl_disc_loss

    def _update_trackers(self, data, reconstruction, total_loss,
                         reconstruction_loss, kl_cont_loss, kl_disc_loss) -> None:
        """Feed the current batch results into all metric trackers."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_cont_loss_tracker.update_state(kl_cont_loss)
        self.kl_disc_loss_tracker.update_state(kl_disc_loss)
        self.mse_tracker.update_state(reconstruction, data)
        self.psnr_tracker.update_state(reconstruction, data)
        self.ssim_tracker.update_state(reconstruction, data)

    def _tracker_results(self) -> dict:
        """Return the current value of every tracker, keyed for Keras logs."""
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_cont_loss": self.kl_cont_loss_tracker.result(),
            "kl_disc_loss": self.kl_disc_loss_tracker.result(),
            "mse": self.mse_tracker.result(),
            "psnr": self.psnr_tracker.result(),
            "ssim": self.ssim_tracker.result(),
        }

    @tf.function
    def train_step(self, data):
        """Run one optimization step on ``data`` and return the metric logs."""
        # The losses must be computed inside the tape so they are recorded.
        with tf.GradientTape() as tape:
            (reconstruction, total_loss, reconstruction_loss,
             kl_cont_loss, kl_disc_loss) = self._forward_and_losses(data, training=True)

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self._update_trackers(data, reconstruction, total_loss,
                              reconstruction_loss, kl_cont_loss, kl_disc_loss)
        return self._tracker_results()

    @tf.function
    def test_step(self, data):
        """Evaluate the losses on ``data`` (no weight update) and return the logs."""
        (reconstruction, total_loss, reconstruction_loss,
         kl_cont_loss, kl_disc_loss) = self._forward_and_losses(data, training=False)

        self._update_trackers(data, reconstruction, total_loss,
                              reconstruction_loss, kl_cont_loss, kl_disc_loss)
        return self._tracker_results()

    def summary(self, *args, **kwargs):
        """Print the encoder summary followed by the decoder summary.

        Any arguments accepted by ``keras.Model.summary`` (e.g. ``print_fn``)
        are forwarded unchanged to both sub-models.
        """
        self.encoder.summary(*args, **kwargs)
        self.decoder.summary(*args, **kwargs)