1
+ program cnn_mnist
2
+
3
+ use nf, only: network, sgd, &
4
+ input, conv2d, maxpool1d, maxpool2d, flatten, dense, reshape, reshape2d, locally_connected_1d, &
5
+ load_mnist, label_digits, softmax, relu
6
+
7
+ implicit none
8
+
9
+ type (network) :: net
10
+
11
+ real , allocatable :: training_images(:,:), training_labels(:)
12
+ real , allocatable :: validation_images(:,:), validation_labels(:)
13
+ real , allocatable :: testing_images(:,:), testing_labels(:)
14
+ integer :: n
15
+ integer , parameter :: num_epochs = 10
16
+
17
+ call load_mnist(training_images, training_labels, &
18
+ validation_images, validation_labels, &
19
+ testing_images, testing_labels)
20
+
21
+ net = network([ &
22
+ input(784 ), &
23
+ reshape2d([28 ,28 ]), &
24
+ locally_connected_1d(filters= 8 , kernel_size= 3 , activation= relu()), &
25
+ maxpool1d(pool_size= 2 ), &
26
+ locally_connected_1d(filters= 16 , kernel_size= 3 , activation= relu()), &
27
+ maxpool1d(pool_size= 2 ), &
28
+ dense(10 , activation= softmax()) &
29
+ ])
30
+
31
+ call net % print_info()
32
+
33
+ epochs: do n = 1 , num_epochs
34
+
35
+ call net % train( &
36
+ training_images, &
37
+ label_digits(training_labels), &
38
+ batch_size= 16 , &
39
+ epochs= 1 , &
40
+ optimizer= sgd(learning_rate= 0.003 ) &
41
+ )
42
+
43
+ print ' (a,i2,a,f5.2,a)' , ' Epoch ' , n, ' done, Accuracy: ' , accuracy( &
44
+ net, validation_images, label_digits(validation_labels)) * 100 , ' %'
45
+
46
+ end do epochs
47
+
48
+ print ' (a,f5.2,a)' , ' Testing accuracy: ' , &
49
+ accuracy(net, testing_images, label_digits(testing_labels)) * 100 , ' %'
50
+
51
+ contains
52
+
53
+ real function accuracy (net , x , y )
54
+ type (network), intent (in out ) :: net
55
+ real , intent (in ) :: x(:,:), y(:,:)
56
+ integer :: i, good
57
+ good = 0
58
+ do i = 1 , size (x, dim= 2 )
59
+ if (all (maxloc (net % predict(x(:,i))) == maxloc (y(:,i)))) then
60
+ good = good + 1
61
+ end if
62
+ end do
63
+ accuracy = real (good) / size (x, dim= 2 )
64
+ end function accuracy
65
+
66
+ end program cnn_mnist
67
+
0 commit comments