# plot_utils.py
# Matplotlib helpers that save training-history curves and example
# prediction grids as image files.
import os

import numpy as np
from matplotlib import pyplot as plt
def _save_history_curve(metrics, key, title, ylabel, path):
    """Plot the series ``metrics[key]`` (plus ``val_<key>`` when present) and save it to ``path``."""
    fig, ax = plt.subplots()
    try:
        ax.plot(metrics[key])
        labels = ['train']
        val_key = 'val_' + key
        # The history may lack a validation series (e.g. training ran
        # without validation data); plot only what was recorded.
        if val_key in metrics:
            ax.plot(metrics[val_key])
            labels.append('test')
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('epoch')
        ax.legend(labels, loc='upper left')
        fig.savefig(path)
    finally:
        # Always release the figure; pyplot keeps every open figure alive
        # until it is explicitly closed.
        plt.close(fig)


def plot_model_history(history, timestamp, output_dir="results-cnn"):
    """
    Save the accuracy and loss curves of a training run as PNG files.

    Writes ``<output_dir>/<timestamp>_model_accuracy.png`` and
    ``<output_dir>/<timestamp>_model_loss.png``, creating ``output_dir``
    if it does not exist yet. Also prints the keys found in the history.

    :param history: object whose ``history`` attribute maps metric names
        to per-epoch value sequences (presumably a Keras ``History``).
    :param timestamp: prefix used for the output file names.
    :param output_dir: directory the images are written to.
    :raises KeyError: if the history has no accuracy or no ``'loss'`` entry.
    """
    metrics = history.history
    # list all data in history
    print(metrics.keys())

    # The accuracy metric is reported as 'acc' or 'accuracy' depending on the
    # Keras version and on the name given at compile time; accept both.
    accuracy_key = next((k for k in ('acc', 'accuracy') if k in metrics), None)
    if accuracy_key is None:
        raise KeyError(
            "history has neither 'acc' nor 'accuracy'; available keys: {}".format(list(metrics))
        )

    os.makedirs(output_dir, exist_ok=True)

    # summarize history for accuracy
    _save_history_curve(
        metrics, accuracy_key, 'model accuracy', 'accuracy',
        os.path.join(output_dir, "{}_model_accuracy.png".format(timestamp)),
    )
    # summarize history for loss
    _save_history_curve(
        metrics, 'loss', 'model loss', 'loss',
        os.path.join(output_dir, "{}_model_loss.png".format(timestamp)),
    )
def _save_example_grid(X_test, y_test, predicted_classes, indices, img_rows, img_cols, path):
    """Draw up to nine of the samples selected by ``indices`` in a 3x3 grid and save it to ``path``."""
    fig = plt.figure()
    try:
        for slot, sample in enumerate(indices[:9]):
            ax = fig.add_subplot(3, 3, slot + 1)
            ax.imshow(X_test[sample].reshape(img_rows, img_cols), cmap='gray', interpolation='none')
            ax.set_title("Predicted {}, Class {}".format(predicted_classes[sample], y_test[sample]))
        fig.savefig(path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def plot_result_examples(model, X_test, y_test, img_rows, img_cols, timestamp,
                         output_dir="results-cnn"):
    """
    Save example grids of correctly and incorrectly classified test samples.

    Predicts a class for every sample in ``X_test``, compares it with
    ``y_test`` and writes two PDF files, each showing up to nine samples:
    ``<output_dir>/<timestamp>_predicted_class_correct.pdf`` and
    ``<output_dir>/<timestamp>_predicted_class_incorrect.pdf``.
    ``output_dir`` is created if it does not exist yet.

    :param model: trained classifier exposing ``predict_classes`` or ``predict``.
    :param X_test: test samples; each one is reshaped to
        ``(img_rows, img_cols)`` for display.
    :param y_test: true labels, compared element-wise with the predictions
        (assumes class indices, not one-hot vectors -- TODO confirm with caller).
    :param img_rows: image height used for the reshape.
    :param img_cols: image width used for the reshape.
    :param timestamp: prefix used for the output file names.
    :param output_dir: directory the PDFs are written to.
    """
    if hasattr(model, 'predict_classes'):
        predicted_classes = model.predict_classes(X_test)
    else:
        # predict_classes is gone from recent Keras releases; the documented
        # replacement for multi-class models is the argmax over the predicted
        # class scores (assumes a multi-class output layer).
        predicted_classes = np.argmax(model.predict(X_test), axis=-1)

    # Check which items we got right / wrong
    correct_indices = np.nonzero(predicted_classes == y_test)[0]
    incorrect_indices = np.nonzero(predicted_classes != y_test)[0]

    os.makedirs(output_dir, exist_ok=True)

    _save_example_grid(
        X_test, y_test, predicted_classes, correct_indices, img_rows, img_cols,
        os.path.join(output_dir, "{}_predicted_class_correct.pdf".format(timestamp)),
    )
    _save_example_grid(
        X_test, y_test, predicted_classes, incorrect_indices, img_rows, img_cols,
        os.path.join(output_dir, "{}_predicted_class_incorrect.pdf".format(timestamp)),
    )