-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathrun_mnist_ae.py
38 lines (23 loc) · 1.08 KB
/
run_mnist_ae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Train a denoising autoencoder on MNIST, then measure how well its latent
# representation separates the digits by fitting a multinomial naive Bayes
# classifier on the encoded features and printing accuracy / per-class scores.
#
# Ported to Python 3 and current scikit-learn: `fetch_mldata` (mldata.org is
# offline) and `sklearn.cross_validation` have both been removed upstream.
import os

from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB

from DenoisingAutoencoder import DenoisingAutoencoder

# Cache downloaded datasets in a "data" directory next to this script.
DATA_HOME = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")


def main():
    """Fit the autoencoder, classify its latent codes, and print the scores."""
    mnist = fetch_openml("mnist_784", version=1, as_frame=False,
                         data_home=DATA_HOME)
    # Scale pixel intensities from [0, 255] to [0, 1]. OpenML stores the
    # labels as strings ("0".."9"), so convert them to ints so they match the
    # integer `labels` passed to classification_report below.
    X, y = mnist.data / 255., mnist.target.astype(int)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=42)

    # Learn the encoding on the training split only, then encode both splits.
    da = DenoisingAutoencoder(n_hidden=400, verbose=True, training_epochs=5)
    da.fit(X_train)
    X_train_latent = da.transform_latent_representation(X_train)
    X_test_latent = da.transform_latent_representation(X_test)

    # NOTE(review): MultinomialNB rejects negative feature values, so this
    # assumes the latent representation is non-negative (e.g. sigmoid
    # activations) -- confirm against DenoisingAutoencoder.
    clf = MultinomialNB()
    clf.fit(X_train_latent, y_train)
    y_predicted = clf.predict(X_test_latent)

    print("Accuracy = {} %".format(accuracy_score(y_test, y_predicted) * 100))
    print("Classification Report \n {}".format(
        classification_report(y_test, y_predicted, labels=list(range(10)))))


if __name__ == "__main__":
    main()