-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
31 lines (23 loc) · 1.11 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import unittest
from numpy.random import seed
import neural_net
from examples.mnist.mnist_loader import MnistLoader
# Configuration for the end-to-end MNIST training test below.

# Folder handed to MnistLoader. NOTE(review): this is a relative path; if the loader
# opens it directly it is resolved against the current working directory, not this
# file's location — confirm the directory tests are expected to run from.
MNIST_DATA_FOLDER = "../examples/mnist/data"

# Passed to MnistLoader.load_data; presumably the number of examples split off
# as the validation set — verify against the loader.
VALIDATION_SET_SIZE = 10_000

# Network shape and training hyperparameters used by the test.
HIDDEN_LAYER_SIZE = 30   # width of the single hidden layer
EPOCHS = 3
BATCH_SIZE = 10
LEARNING_RATE = 3.0
class TestNeuralNet(unittest.TestCase):
    """End-to-end regression test: train a small network on MNIST and check its test-set accuracy."""

    # Accuracy (percent correct, rounded to one decimal place) the seeded training run
    # is expected to reach on the test set.
    EXPECTED_PERCENT_CORRECT = 92.8

    def test_percentage_correct(self) -> None:
        """Tests that the network achieves the expected percentage correct after training on the MNIST dataset.

        Loads MNIST from ``MNIST_DATA_FOLDER`` via ``MnistLoader.load_data(VALIDATION_SET_SIZE)``,
        builds a network with layer sizes ``[input, HIDDEN_LAYER_SIZE, output]``, trains it with
        ``EPOCHS``/``BATCH_SIZE``/``LEARNING_RATE``, and checks that the test-set accuracy,
        rounded to one decimal place, equals ``EXPECTED_PERCENT_CORRECT``.
        """
        # Seed NumPy's global RNG; presumably neural_net draws its random weight
        # initialisation / shuffling from it, which is what makes the expected value stable.
        seed(0)
        mnist_loader = MnistLoader(MNIST_DATA_FOLDER)
        # The validation split is returned by the loader but not used by this test.
        (training_inputs, training_outputs,
         _validation_inputs, _validation_outputs,
         test_inputs, test_outputs) = mnist_loader.load_data(VALIDATION_SET_SIZE)

        # Input/output layer widths come from the first dimension of a single sample.
        input_neurons = training_inputs[0].shape[0]
        output_neurons = training_outputs[0].shape[0]
        net = neural_net.Network([input_neurons, HIDDEN_LAYER_SIZE, output_neurons])

        # NOTE(review): the test set is also passed to train() as its evaluation data;
        # presumably that is only for per-epoch progress reporting — confirm it does not
        # influence training, otherwise the final accuracy is not a held-out measurement.
        net.train(training_inputs, training_outputs, EPOCHS, BATCH_SIZE, LEARNING_RATE,
                  test_inputs, test_outputs)
        percent_correct = net.percentage_correct(test_inputs, test_outputs)

        # Use the TestCase assertion instead of a bare ``assert``: bare asserts are
        # removed under ``python -O`` (the test would then always pass) and report no
        # values on failure.
        self.assertEqual(
            round(percent_correct, 1),
            self.EXPECTED_PERCENT_CORRECT,
            f"unexpected test-set accuracy: {percent_correct}%",
        )