Skip to content

Commit b5e7f74

Browse files
committed
changing reshape layer
1 parent a28a9be commit b5e7f74

16 files changed

+1173
-78
lines changed

CMakeLists.txt

+6
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,12 @@ add_library(neural-fortran
3838
src/nf/nf_layer_constructors_submodule.f90
3939
src/nf/nf_layer.f90
4040
src/nf/nf_layer_submodule.f90
41+
src/nf/nf_locally_connected_1d_submodule.f90
42+
src/nf/nf_locally_connected_1d.f90
4143
src/nf/nf_loss.f90
4244
src/nf/nf_loss_submodule.f90
45+
src/nf/nf_maxpool1d_layer.f90
46+
src/nf/nf_maxpool1d_layer_submodule.f90
4347
src/nf/nf_maxpool2d_layer.f90
4448
src/nf/nf_maxpool2d_layer_submodule.f90
4549
src/nf/nf_metrics.f90
@@ -51,6 +55,8 @@ add_library(neural-fortran
5155
src/nf/nf_random.f90
5256
src/nf/nf_reshape_layer.f90
5357
src/nf/nf_reshape_layer_submodule.f90
58+
src/nf/nf_reshape2d_layer.f90
59+
src/nf/nf_reshape2d_layer_submodule.f90
5460
src/nf/io/nf_io_binary.f90
5561
src/nf/io/nf_io_binary_submodule.f90
5662
)

example/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
foreach(execid
22
cnn_mnist
3+
cnn_mnist_1d
34
dense_mnist
45
get_set_network_params
56
network_parameters

example/cnn_mnist_1d.f90

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
program cnn_mnist
2+
3+
use nf, only: network, sgd, &
4+
input, conv2d, maxpool1d, maxpool2d, flatten, dense, reshape, reshape2d, locally_connected_1d, &
5+
load_mnist, label_digits, softmax, relu
6+
7+
implicit none
8+
9+
type(network) :: net
10+
11+
real, allocatable :: training_images(:,:), training_labels(:)
12+
real, allocatable :: validation_images(:,:), validation_labels(:)
13+
real, allocatable :: testing_images(:,:), testing_labels(:)
14+
integer :: n
15+
integer, parameter :: num_epochs = 10
16+
17+
call load_mnist(training_images, training_labels, &
18+
validation_images, validation_labels, &
19+
testing_images, testing_labels)
20+
21+
net = network([ &
22+
input(784), &
23+
reshape2d([28,28]), &
24+
locally_connected_1d(filters=8, kernel_size=3, activation=relu()), &
25+
maxpool1d(pool_size=2), &
26+
locally_connected_1d(filters=16, kernel_size=3, activation=relu()), &
27+
maxpool1d(pool_size=2), &
28+
dense(10, activation=softmax()) &
29+
])
30+
31+
call net % print_info()
32+
33+
epochs: do n = 1, num_epochs
34+
35+
call net % train( &
36+
training_images, &
37+
label_digits(training_labels), &
38+
batch_size=16, &
39+
epochs=1, &
40+
optimizer=sgd(learning_rate=0.003) &
41+
)
42+
43+
print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( &
44+
net, validation_images, label_digits(validation_labels)) * 100, ' %'
45+
46+
end do epochs
47+
48+
print '(a,f5.2,a)', 'Testing accuracy: ', &
49+
accuracy(net, testing_images, label_digits(testing_labels)) * 100, '%'
50+
51+
contains
52+
53+
real function accuracy(net, x, y)
54+
type(network), intent(in out) :: net
55+
real, intent(in) :: x(:,:), y(:,:)
56+
integer :: i, good
57+
good = 0
58+
do i = 1, size(x, dim=2)
59+
if (all(maxloc(net % predict(x(:,i))) == maxloc(y(:,i)))) then
60+
good = good + 1
61+
end if
62+
end do
63+
accuracy = real(good) / size(x, dim=2)
64+
end function accuracy
65+
66+
end program cnn_mnist
67+

src/nf.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module nf
33
use nf_datasets_mnist, only: label_digits, load_mnist
44
use nf_layer, only: layer
55
use nf_layer_constructors, only: &
6-
conv2d, dense, flatten, input, maxpool2d, reshape
6+
conv2d, dense, flatten, input, maxpool1d, maxpool2d, reshape, reshape2d, locally_connected_1d
77
use nf_loss, only: mse, quadratic
88
use nf_metrics, only: corr, maxabs
99
use nf_network, only: network

0 commit comments

Comments
 (0)