-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_evolution.py
23 lines (16 loc) · 1.25 KB
/
mnist_evolution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from evolutionary_module import evolve_node_count
import mnist_io
import utils
import os
# Datasets live in <parent>/Python/datasets, where <parent> is the parent of the
# directory containing this script. Each path component is passed separately to
# os.path.join so the path is valid on every OS (hard-coded "\\" separators only
# work on Windows).
dataset = os.path.join(
    os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], "Python", "datasets"
)

_MNIST_PIXELS = 28 * 28  # each image is flattened to a 784-element row
_MNIST_CLASSES = 10  # number of label columns passed to jit_to_categorical


def _load_mnist_split(image_path, label_path, count):
    """Load one MNIST split and return ``(images, labels)``.

    Images are read with ``mnist_io.images_from_file``, reshaped to
    ``(count, 784)``, cast to float32 and divided by 255 (assumes 8-bit pixel
    values, scaling them into [0, 1]). Labels are read with
    ``mnist_io.labels_from_file`` and passed through
    ``utils.jit_to_categorical`` with ``_MNIST_CLASSES`` classes (presumably
    one-hot encoding; confirm in utils).

    :param image_path: path to the idx3-ubyte image file.
    :param label_path: path to the idx1-ubyte label file.
    :param count: number of samples to read from each file.
    """
    images = mnist_io.images_from_file(image_path, count)
    images = images.reshape(count, _MNIST_PIXELS).astype('float32')
    images /= 255
    labels = utils.jit_to_categorical(
        mnist_io.labels_from_file(label_path, count), _MNIST_CLASSES
    )
    return images, labels


train_images, train_labels = _load_mnist_split(
    os.path.join(dataset, "train-images-idx3-ubyte", "train-images.idx3-ubyte"),
    os.path.join(dataset, "train-labels-idx1-ubyte", "train-labels.idx1-ubyte"),
    60000,
)
test_images, test_labels = _load_mnist_split(
    os.path.join(dataset, "t10k-images-idx3-ubyte", "t10k-images.idx3-ubyte"),
    os.path.join(dataset, "t10k-labels-idx1-ubyte", "t10k-labels.idx1-ubyte"),
    10000,
)

# NOTE(review): utils.OutSplit presumably redirects/tees output under the
# 'mnist_evolution' name; confirm in utils.
with utils.OutSplit('mnist_evolution'):
    evolved_network = evolve_node_count(
        train_images, train_labels, test_images, test_labels,
        utils.jit_categorical_compare,
        population_count=10, population_size=15, node_cap=1500,
        generations=50, target_accuracy=0.95, r=4,
    )
    # NOTE(review): the pasted source lost its indentation. This summary print
    # is kept inside the OutSplit context so it is captured with the run's
    # output; confirm against the original file.
    print(evolved_network.connections, "\n", evolved_network.weights, "\n", evolved_network.learning_rate)