-
Notifications
You must be signed in to change notification settings - Fork 2
/
SuperModel.py
102 lines (77 loc) · 3.3 KB
/
SuperModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
##############################################
# (c) Copyright 2018-2019 Kenza Tazi and Thomas Zhu
# This software is distributed under the terms of the GNU General Public
# Licence version 3 (GPLv3)
##############################################
import os
import numpy as np
import tensorflow as tf
import DataPreparation as dp
from CNN import CNN
from FFN import FFN
class SuperModel():
    """Two-stage classifier: an FFN for every input, a CNN for uncertain ones.

    On disk a SuperModel lives under ``Models/<name>/`` and consists of the
    saved FFN (``FFN_<ffn name>``), the saved CNN (``CNN_<cnn name>``) and an
    ``Info.txt`` file recording both sub-model names.
    """

    def __init__(self, name, FFN=None, CNN=None):
        """Create a SuperModel.

        Args:
            name: Directory name used under ``Models/`` by Save and Load.
            FFN: Feed-forward sub-model; may be left as None and filled by Load.
            CNN: Convolutional sub-model; may be left as None and filled by Load.
        """
        self.name = name
        self.FFN = FFN
        self.CNN = CNN
        self._isLoaded = False

    @property
    def isLoaded(self):
        """True only when both sub-models exist and report being loaded.

        The flag is recomputed from the sub-models on every access, so a value
        assigned through the setter is overwritten by the next read.
        """
        # Sub-models default to None in __init__; treat that as "not loaded"
        # instead of raising AttributeError.
        if self.FFN is None or self.CNN is None:
            self._isLoaded = False
        else:
            self._isLoaded = bool(self.FFN.isLoaded and self.CNN.isLoaded)
        return self._isLoaded

    @isLoaded.setter
    def isLoaded(self, value):
        self._isLoaded = value

    def _model_path(self, *parts):
        """Return ``Models/<name>[/part...]`` joined with forward slashes."""
        return '/'.join(('Models', self.name) + parts)

    @staticmethod
    def _parse_info(lines):
        """Return ``(ffn_name, cnn_name)`` parsed from the lines of Info.txt.

        Each meaningful line has the form ``<KEY>: <model name>`` as written by
        Save. Blank or unrecognised lines are ignored, and names may contain
        spaces.

        Raises:
            ValueError: if the FFN or CNN entry is missing.
        """
        names = {}
        for line in lines:
            key, sep, value = line.partition(':')
            if sep:
                names[key.strip()] = value.strip()
        try:
            return names['FFN'], names['CNN']
        except KeyError as err:
            raise ValueError(
                'Malformed model info file, missing entry: ' + str(err)) from err

    def predict_file(self, Sreference, shape=(2400, 3000), chunk_size=10000,
                     uncertainty=0.25):
        """Classify one scene, using the CNN wherever the FFN is uncertain.

        Every input is first scored by the FFN. Inputs whose FFN output lies
        within ``uncertainty`` of 0.5 are re-scored by the CNN, which is fed in
        chunks of ``chunk_size`` because feeding everything at once can cause
        a memory error.

        Args:
            Sreference: Scene reference passed to ``DataPreparation``
                (presumably a path to the input product; confirm).
            shape: (rows, cols) of the output grid; defaults to 2400x3000.
            chunk_size: Maximum number of CNN inputs per predict call.
            uncertainty: Half-width of the band around 0.5 within which FFN
                outputs are re-scored by the CNN.

        Returns:
            (labels, predictions): two float arrays reshaped to ``shape``.
        """
        ffninputs = dp.getinputsFFN(
            Sreference, input_type=22)  # include indices
        predictions1 = self.FFN.Predict(ffninputs)[:, 0]
        labels1 = self.FFN.model.predict_label(ffninputs)[:, 0]

        # Boolean mask of FFN outputs too close to 0.5 to be trusted.
        bad = np.abs(predictions1 - 0.5) < uncertainty
        goodindices = np.flatnonzero(~bad)
        badindices = np.flatnonzero(bad)

        cnninputs = dp.getinputsCNN(Sreference, badindices)
        cnninputs = dp.star_padding(cnninputs)

        predictions2 = []
        labels2 = []
        for start in range(0, len(cnninputs), chunk_size):
            chunk = cnninputs[start:start + chunk_size]
            predictions2.extend(self.CNN.model.predict(chunk)[:, 0])
            labels2.extend(self.CNN.model.predict_label(chunk)[:, 0])

        # NOTE(review): assumes the FFN yields one output per grid cell, in
        # the row-major order of `shape`; confirm against dp.getinputsFFN.
        npixels = int(np.prod(shape))
        finallabels = np.zeros(npixels)
        finallabels[goodindices] = labels1[goodindices]
        finallabels[badindices] = labels2

        finalpredictions = np.zeros(npixels)
        finalpredictions[goodindices] = predictions1[goodindices]
        finalpredictions[badindices] = predictions2

        return finallabels.reshape(shape), finalpredictions.reshape(shape)

    def Save(self):
        """Save both sub-models and an Info.txt index under ``Models/<name>/``.

        Raises:
            FileExistsError: if ``Models/<name>`` already exists, so an
                existing model is never silently overwritten.
        """
        # makedirs also creates the parent 'Models' directory when missing;
        # exist_ok stays False to keep the refuse-to-overwrite behaviour.
        os.makedirs(self._model_path())
        self.FFN.Save(self._model_path('FFN_' + self.FFN.name))
        self.CNN.Save(self._model_path('CNN_' + self.CNN.name))
        with open(self._model_path('Info.txt'), 'w') as file:
            file.write('FFN: ' + self.FFN.name + '\n')
            file.write('CNN: ' + self.CNN.name)

    def Load(self):
        """Rebuild and load both sub-models from ``Models/<name>/``.

        Sets ``FFNname``/``CNNname`` from Info.txt, then constructs and loads
        the FFN and CNN.

        Raises:
            Exception: 'File does not exist' if Info.txt is missing.
            ValueError: if Info.txt lacks the FFN or CNN entry.
        """
        try:
            with open(self._model_path('Info.txt'), 'r') as file:
                settings = file.readlines()
        except FileNotFoundError as err:
            raise Exception('File does not exist') from err
        self.FFNname, self.CNNname = self._parse_info(settings)

        self.FFN = FFN(self.FFNname)
        self.FFN.Load(self._model_path('FFN_' + self.FFN.name))
        # Reset the TensorFlow default graph before building the CNN
        # (presumably so it does not collide with the FFN's graph).
        tf.reset_default_graph()
        self.CNN = CNN(self.CNNname)
        self.CNN.Load(self._model_path('CNN_' + self.CNN.name))