-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassifier.py
130 lines (106 loc) · 4.63 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
from datetime import datetime
from sklearn.datasets import load_files
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import svm
from sklearn.externals import joblib
from sklearn.metrics import accuracy_score
# Module-level configuration shared by train_svm and test_svm.
# all_models_path: directory where trained model bundles are written/read.
#   NOTE(review): hard-coded absolute Windows path — consider making configurable.
# ext: arbitrary extension given to saved model bundles. It can be changed to
#   anything, but the consumer API (Spring Boot) must be updated to match,
#   since test_svm filters the models directory by this extension.
all_models_path = r"D:\kaam\all_models"
ext = ".tupperware"
#Called by the training endpoint once it has received all data from the user;
#starts training the machine learning algorithm.
#'params' is a dictionary containing the required values (see docstring);
#it can be extended as needed.
def train_svm(params):
    """Train an SVM text classifier from labeled files on disk and save it.

    params keys used:
        'dirPath'         -- directory of labeled text files (one subdir per class)
        'train_ratio'     -- fraction of samples used for training (0..1)
        'C_provided'      -- SVC regularization parameter C
        'kernel_provided' -- SVC kernel name
        'gamma_provided'  -- SVC gamma value

    Returns the file name of the saved model bundle, or "" if anything failed.
    """
    try:
        all_of_it = load_files(params['dirPath'], shuffle=True, random_state=None)
        total = len(all_of_it.target)
        num = int(params['train_ratio'] * total)
        train_data = all_of_it.data[:num]
        # Held-out remainder, kept for a future validation/accuracy step
        # (accuracy_score over clf.predict(vect.transform(validation_data))
        # against all_of_it.target[num:]); not wired up yet.
        validation_data = all_of_it.data[num:]
        vect = TfidfVectorizer()
        X_train_tf = vect.fit_transform(train_data)
        clf = svm.SVC(
            decision_function_shape="ovo",
            C=params['C_provided'],
            kernel=params['kernel_provided'],
            gamma=params['gamma_provided'],
        )
        clf.fit(X_train_tf, all_of_it.target[:num])
        # Timestamp-based file name with characters ':', ' ', '.' avoided,
        # since they are unsafe/awkward in Windows paths.
        modelFileName = datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f') + ext
        modelFilePath = os.path.join(all_models_path, modelFileName)
        # BUG FIX: the original dict literal used '=' instead of ':' for the
        # 'dirPath' entry, which is a SyntaxError and made the module unimportable.
        bundle = {"modelFile": clf, "vectFile": vect, "dirPath": params['dirPath']}
        joblib.dump(bundle, modelFilePath)
    except Exception:
        # Shouldn't happen — the Spring Boot caller validates input — but if it
        # does, log the real traceback instead of swallowing it silently, and
        # signal failure with an empty file name (the contract callers expect).
        traceback.print_exc()
        modelFileName = ""
    return modelFileName
#This function is called when the user hits the test endpoint; this is where the actual classification of user files happens.
def test_svm(saved_model, TEST_DIR):
    """Classify every file under TEST_DIR with a previously trained model.

    saved_model -- path to a joblib bundle produced by train_svm; if falsy,
                   the most recently modified model in all_models_path is used.
    TEST_DIR    -- directory tree of text files to classify.

    Returns a list of {"file": <name>, "cat": <predicted category>} dicts.
    """
    # BUG FIX: resolve the newest-model fallback BEFORE loading. The original
    # called joblib.load(saved_model) first, so an empty saved_model crashed
    # before the fallback code could ever run.
    if not saved_model:
        # Sort on (mtime, path) — path, not DirEntry, so ties stay comparable.
        candidates = [
            (entry.stat().st_mtime, entry.path)
            for entry in os.scandir(all_models_path)
            if entry.name.endswith(ext)
        ]
        candidates.sort(reverse=True)
        saved_model = candidates[0][1]
    loadedModelDict = joblib.load(saved_model)
    # Reload the training corpus only to recover target_names (index -> label).
    all_of_it = load_files(loadedModelDict['dirPath'], shuffle=True, random_state=None)
    loadedModel = loadedModelDict['modelFile']
    loadedVect = loadedModelDict['vectFile']
    res = dict()
    for home, subdirs, files in os.walk(TEST_DIR):
        for file_ in files:
            # BUG FIX: join with the directory currently being walked ('home'),
            # not TEST_DIR — the original failed for files in subdirectories.
            with open(os.path.join(home, file_)) as f:
                pred = loadedModel.predict(loadedVect.transform([f.read()]))
                res[file_] = all_of_it.target_names[int(pred)]
    return [{"file": name, "cat": "{}".format(cat)} for name, cat in res.items()]