From b5e7f74a2bec11e6c1737f6b4390a16f73369be2 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Sun, 16 Feb 2025 19:32:29 +0100 Subject: [PATCH 01/26] changing reshape layer --- CMakeLists.txt | 6 + example/CMakeLists.txt | 1 + example/cnn_mnist_1d.f90 | 67 ++++ src/nf.f90 | 2 +- src/nf/nf_activation.f90 | 306 ++++++++++++++----- src/nf/nf_datasets_mnist_submodule.f90 | 6 +- src/nf/nf_layer_constructors.f90 | 63 +++- src/nf/nf_layer_constructors_submodule.f90 | 70 +++++ src/nf/nf_layer_submodule.f90 | 96 +++++- src/nf/nf_locally_connected_1d.f90 | 119 ++++++++ src/nf/nf_locally_connected_1d_submodule.f90 | 211 +++++++++++++ src/nf/nf_maxpool1d_layer.f90 | 69 +++++ src/nf/nf_maxpool1d_layer_submodule.f90 | 93 ++++++ src/nf/nf_network_submodule.f90 | 15 +- src/nf/nf_reshape2d_layer.f90 | 77 +++++ src/nf/nf_reshape2d_layer_submodule.f90 | 50 +++ 16 files changed, 1173 insertions(+), 78 deletions(-) create mode 100644 example/cnn_mnist_1d.f90 create mode 100644 src/nf/nf_locally_connected_1d.f90 create mode 100644 src/nf/nf_locally_connected_1d_submodule.f90 create mode 100644 src/nf/nf_maxpool1d_layer.f90 create mode 100644 src/nf/nf_maxpool1d_layer_submodule.f90 create mode 100644 src/nf/nf_reshape2d_layer.f90 create mode 100644 src/nf/nf_reshape2d_layer_submodule.f90 diff --git a/CMakeLists.txt b/CMakeLists.txt index 1a0a1be4..64128d7f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,8 +38,12 @@ add_library(neural-fortran src/nf/nf_layer_constructors_submodule.f90 src/nf/nf_layer.f90 src/nf/nf_layer_submodule.f90 + src/nf/nf_locally_connected_1d_submodule.f90 + src/nf/nf_locally_connected_1d.f90 src/nf/nf_loss.f90 src/nf/nf_loss_submodule.f90 + src/nf/nf_maxpool1d_layer.f90 + src/nf/nf_maxpool1d_layer_submodule.f90 src/nf/nf_maxpool2d_layer.f90 src/nf/nf_maxpool2d_layer_submodule.f90 src/nf/nf_metrics.f90 @@ -51,6 +55,8 @@ add_library(neural-fortran src/nf/nf_random.f90 src/nf/nf_reshape_layer.f90 src/nf/nf_reshape_layer_submodule.f90 + src/nf/nf_reshape2d_layer.f90 + src/nf/nf_reshape2d_layer_submodule.f90 src/nf/io/nf_io_binary.f90 src/nf/io/nf_io_binary_submodule.f90 ) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 28cf71a7..7632909e 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -1,5 +1,6 @@ foreach(execid cnn_mnist + cnn_mnist_1d dense_mnist get_set_network_params network_parameters diff --git a/example/cnn_mnist_1d.f90 b/example/cnn_mnist_1d.f90 new file mode 100644 index 00000000..f8b50ae5 --- /dev/null +++ b/example/cnn_mnist_1d.f90 @@ -0,0 +1,67 @@ +program cnn_mnist + + use nf, only: network, sgd, & + input, conv2d, maxpool1d, maxpool2d, flatten, dense, reshape, reshape2d, locally_connected_1d, & + load_mnist, label_digits, softmax, relu + + implicit none + + type(network) :: net + + real, allocatable :: training_images(:,:), training_labels(:) + real, allocatable :: validation_images(:,:), validation_labels(:) + real, allocatable :: testing_images(:,:), testing_labels(:) + integer :: n + integer, parameter :: num_epochs = 10 + + call load_mnist(training_images, training_labels, & + validation_images, validation_labels, & + testing_images, testing_labels) + + net = network([ & + input(784), & + reshape2d([28,28]), & + locally_connected_1d(filters=8, kernel_size=3, activation=relu()), & + maxpool1d(pool_size=2), & + locally_connected_1d(filters=16, kernel_size=3, activation=relu()), & + maxpool1d(pool_size=2), & + dense(10, activation=softmax()) & + ]) + + call net % print_info() + + epochs: do n = 1, 
num_epochs + + call net % train( & + training_images, & + label_digits(training_labels), & + batch_size=16, & + epochs=1, & + optimizer=sgd(learning_rate=0.003) & + ) + + print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( & + net, validation_images, label_digits(validation_labels)) * 100, ' %' + + end do epochs + + print '(a,f5.2,a)', 'Testing accuracy: ', & + accuracy(net, testing_images, label_digits(testing_labels)) * 100, '%' + + contains + + real function accuracy(net, x, y) + type(network), intent(in out) :: net + real, intent(in) :: x(:,:), y(:,:) + integer :: i, good + good = 0 + do i = 1, size(x, dim=2) + if (all(maxloc(net % predict(x(:,i))) == maxloc(y(:,i)))) then + good = good + 1 + end if + end do + accuracy = real(good) / size(x, dim=2) + end function accuracy + + end program cnn_mnist + \ No newline at end of file diff --git a/src/nf.f90 b/src/nf.f90 index b97d9e62..12e79736 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,7 +3,7 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - conv2d, dense, flatten, input, maxpool2d, reshape + conv2d, dense, flatten, input, maxpool1d, maxpool2d, reshape, reshape2d, locally_connected_1d use nf_loss, only: mse, quadratic use nf_metrics, only: corr, maxabs use nf_network, only: network diff --git a/src/nf/nf_activation.f90 b/src/nf/nf_activation.f90 index 309b43d2..caeab138 100644 --- a/src/nf/nf_activation.f90 +++ b/src/nf/nf_activation.f90 @@ -25,12 +25,14 @@ module nf_activation contains procedure(eval_1d_i), deferred :: eval_1d procedure(eval_1d_i), deferred :: eval_1d_prime + procedure(eval_2d_i), deferred :: eval_2d + procedure(eval_2d_i), deferred :: eval_2d_prime procedure(eval_3d_i), deferred :: eval_3d procedure(eval_3d_i), deferred :: eval_3d_prime procedure :: get_name - generic :: eval => eval_1d, eval_3d - generic :: eval_prime => eval_1d_prime, eval_3d_prime + generic :: eval => eval_1d, eval_2d, eval_3d + generic :: eval_prime => eval_1d_prime, eval_2d_prime, eval_3d_prime end type activation_function @@ -43,6 +45,13 @@ pure function eval_1d_i(self, x) result(res) real :: res(size(x)) end function eval_1d_i + pure function eval_2d_i(self, x) result(res) + import :: activation_function + class(activation_function), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + end function eval_2d_i + pure function eval_3d_i(self, x) result(res) import :: activation_function class(activation_function), intent(in) :: self @@ -57,6 +66,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_elu procedure :: eval_1d_prime => eval_1d_elu_prime + procedure :: eval_2d => eval_2d_elu + procedure :: eval_2d_prime => eval_2d_elu_prime procedure :: eval_3d => eval_3d_elu procedure :: eval_3d_prime => eval_3d_elu_prime end type elu @@ -65,6 +76,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_exponential procedure :: eval_1d_prime => eval_1d_exponential + procedure :: eval_2d => eval_2d_exponential + procedure :: eval_2d_prime => eval_2d_exponential procedure :: eval_3d => eval_3d_exponential procedure :: eval_3d_prime => eval_3d_exponential end type exponential @@ -73,6 +86,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_gaussian procedure :: eval_1d_prime => eval_1d_gaussian_prime + procedure :: eval_2d => eval_2d_gaussian + procedure :: eval_2d_prime => eval_2d_gaussian_prime procedure :: eval_3d => eval_3d_gaussian procedure :: eval_3d_prime => 
eval_3d_gaussian_prime end type gaussian @@ -81,6 +96,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_linear procedure :: eval_1d_prime => eval_1d_linear_prime + procedure :: eval_2d => eval_2d_linear + procedure :: eval_2d_prime => eval_2d_linear_prime procedure :: eval_3d => eval_3d_linear procedure :: eval_3d_prime => eval_3d_linear_prime end type linear @@ -89,6 +106,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_relu procedure :: eval_1d_prime => eval_1d_relu_prime + procedure :: eval_2d => eval_2d_relu + procedure :: eval_2d_prime => eval_2d_relu_prime procedure :: eval_3d => eval_3d_relu procedure :: eval_3d_prime => eval_3d_relu_prime end type relu @@ -98,6 +117,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_leaky_relu procedure :: eval_1d_prime => eval_1d_leaky_relu_prime + procedure :: eval_2d => eval_2d_leaky_relu + procedure :: eval_2d_prime => eval_2d_leaky_relu_prime procedure :: eval_3d => eval_3d_leaky_relu procedure :: eval_3d_prime => eval_3d_leaky_relu_prime end type leaky_relu @@ -106,6 +127,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_sigmoid procedure :: eval_1d_prime => eval_1d_sigmoid_prime + procedure :: eval_2d => eval_2d_sigmoid + procedure :: eval_2d_prime => eval_2d_sigmoid_prime procedure :: eval_3d => eval_3d_sigmoid procedure :: eval_3d_prime => eval_3d_sigmoid_prime end type sigmoid @@ -114,6 +137,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_softmax procedure :: eval_1d_prime => eval_1d_softmax_prime + procedure :: eval_2d => eval_2d_softmax + procedure :: eval_2d_prime => eval_2d_softmax_prime procedure :: eval_3d => eval_3d_softmax procedure :: eval_3d_prime => eval_3d_softmax_prime end type softmax @@ -122,6 +147,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_softplus procedure :: eval_1d_prime => eval_1d_softplus_prime + procedure :: eval_2d => eval_2d_softplus + procedure :: eval_2d_prime => eval_2d_softplus_prime procedure :: eval_3d => eval_3d_softplus procedure :: eval_3d_prime => eval_3d_softplus_prime end type softplus @@ -130,6 +157,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_step procedure :: eval_1d_prime => eval_1d_step_prime + procedure :: eval_2d => eval_2d_step + procedure :: eval_2d_prime => eval_2d_step_prime procedure :: eval_3d => eval_3d_step procedure :: eval_3d_prime => eval_3d_step_prime end type step @@ -138,6 +167,8 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_tanh procedure :: eval_1d_prime => eval_1d_tanh_prime + procedure :: eval_2d => eval_2d_tanh + procedure :: eval_2d_prime => eval_2d_tanh_prime procedure :: eval_3d => eval_3d_tanh procedure :: eval_3d_prime => eval_3d_tanh_prime end type tanhf @@ -147,14 +178,16 @@ end function eval_3d_i contains procedure :: eval_1d => eval_1d_celu procedure :: eval_1d_prime => eval_1d_celu_prime + procedure :: eval_2d => eval_2d_celu + procedure :: eval_2d_prime => eval_2d_celu_prime procedure :: eval_3d => eval_3d_celu procedure :: eval_3d_prime => eval_3d_celu_prime end type celu contains + ! ELU Activation Functions pure function eval_1d_elu(self, x) result(res) - ! Exponential Linear Unit (ELU) activation function. class(elu), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -166,8 +199,6 @@ pure function eval_1d_elu(self, x) result(res) end function eval_1d_elu pure function eval_1d_elu_prime(self, x) result(res) - ! First derivative of the Exponential Linear Unit (ELU) - ! 
activation function. class(elu), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -178,8 +209,29 @@ pure function eval_1d_elu_prime(self, x) result(res) end where end function eval_1d_elu_prime + pure function eval_2d_elu(self, x) result(res) + class(elu), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + where (x >= 0) + res = x + elsewhere + res = self % alpha * (exp(x) - 1) + end where + end function eval_2d_elu + + pure function eval_2d_elu_prime(self, x) result(res) + class(elu), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + where (x >= 0) + res = 1 + elsewhere + res = self % alpha * exp(x) + end where + end function eval_2d_elu_prime + pure function eval_3d_elu(self, x) result(res) - ! Exponential Linear Unit (ELU) activation function. class(elu), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -191,8 +243,6 @@ pure function eval_3d_elu(self, x) result(res) end function eval_3d_elu pure function eval_3d_elu_prime(self, x) result(res) - ! First derivative of the Exponential Linear Unit (ELU) - ! activation function. class(elu), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -203,24 +253,30 @@ pure function eval_3d_elu_prime(self, x) result(res) end where end function eval_3d_elu_prime + ! Exponential Activation Functions pure function eval_1d_exponential(self, x) result(res) - ! Exponential activation function. class(exponential), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = exp(x) end function eval_1d_exponential + pure function eval_2d_exponential(self, x) result(res) + class(exponential), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = exp(x) + end function eval_2d_exponential + pure function eval_3d_exponential(self, x) result(res) - ! Exponential activation function. class(exponential), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = exp(x) end function eval_3d_exponential + ! Gaussian Activation Functions pure function eval_1d_gaussian(self, x) result(res) - ! Gaussian activation function. class(gaussian), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -228,15 +284,27 @@ pure function eval_1d_gaussian(self, x) result(res) end function eval_1d_gaussian pure function eval_1d_gaussian_prime(self, x) result(res) - ! First derivative of the Gaussian activation function. class(gaussian), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = -2 * x * self % eval_1d(x) end function eval_1d_gaussian_prime + pure function eval_2d_gaussian(self, x) result(res) + class(gaussian), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = exp(-x**2) + end function eval_2d_gaussian + + pure function eval_2d_gaussian_prime(self, x) result(res) + class(gaussian), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = -2 * x * self % eval_2d(x) + end function eval_2d_gaussian_prime + pure function eval_3d_gaussian(self, x) result(res) - ! Gaussian activation function. class(gaussian), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -244,15 +312,14 @@ pure function eval_3d_gaussian(self, x) result(res) end function eval_3d_gaussian pure function eval_3d_gaussian_prime(self, x) result(res) - ! First derivative of the Gaussian activation function. 
class(gaussian), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = -2 * x * self % eval_3d(x) end function eval_3d_gaussian_prime + ! Linear Activation Functions pure function eval_1d_linear(self, x) result(res) - ! Linear activation function. class(linear), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -260,15 +327,27 @@ pure function eval_1d_linear(self, x) result(res) end function eval_1d_linear pure function eval_1d_linear_prime(self, x) result(res) - ! First derivative of the Linear activation function. class(linear), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = 1 end function eval_1d_linear_prime + pure function eval_2d_linear(self, x) result(res) + class(linear), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = x + end function eval_2d_linear + + pure function eval_2d_linear_prime(self, x) result(res) + class(linear), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = 1 + end function eval_2d_linear_prime + pure function eval_3d_linear(self, x) result(res) - ! Linear activation function. class(linear), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -276,15 +355,14 @@ pure function eval_3d_linear(self, x) result(res) end function eval_3d_linear pure function eval_3d_linear_prime(self, x) result(res) - ! First derivative of the Linear activation function. class(linear), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = 1 end function eval_3d_linear_prime + ! ReLU Activation Functions pure function eval_1d_relu(self, x) result(res) - !! Rectified Linear Unit (ReLU) activation function. class(relu), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -292,15 +370,27 @@ pure function eval_1d_relu(self, x) result(res) end function eval_1d_relu pure function eval_1d_relu_prime(self, x) result(res) - ! First derivative of the Rectified Linear Unit (ReLU) activation function. class(relu), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = merge(1., 0., x > 0) end function eval_1d_relu_prime + pure function eval_2d_relu(self, x) result(res) + class(relu), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = max(0., x) + end function eval_2d_relu + + pure function eval_2d_relu_prime(self, x) result(res) + class(relu), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = merge(1., 0., x > 0) + end function eval_2d_relu_prime + pure function eval_3d_relu(self, x) result(res) - !! Rectified Linear Unit (ReLU) activation function. class(relu), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -308,15 +398,14 @@ pure function eval_3d_relu(self, x) result(res) end function eval_3d_relu pure function eval_3d_relu_prime(self, x) result(res) - ! First derivative of the Rectified Linear Unit (ReLU) activation function. class(relu), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = merge(1., 0., x > 0) end function eval_3d_relu_prime + ! Leaky ReLU Activation Functions pure function eval_1d_leaky_relu(self, x) result(res) - !! Leaky Rectified Linear Unit (Leaky ReLU) activation function. 
class(leaky_relu), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -324,15 +413,27 @@ pure function eval_1d_leaky_relu(self, x) result(res) end function eval_1d_leaky_relu pure function eval_1d_leaky_relu_prime(self, x) result(res) - ! First derivative of the Leaky Rectified Linear Unit (Leaky ReLU) activation function. class(leaky_relu), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = merge(1., self%alpha, x > 0) end function eval_1d_leaky_relu_prime + pure function eval_2d_leaky_relu(self, x) result(res) + class(leaky_relu), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = max(self % alpha * x, x) + end function eval_2d_leaky_relu + + pure function eval_2d_leaky_relu_prime(self, x) result(res) + class(leaky_relu), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = merge(1., self%alpha, x > 0) + end function eval_2d_leaky_relu_prime + pure function eval_3d_leaky_relu(self, x) result(res) - !! Leaky Rectified Linear Unit (Leaky ReLU) activation function. class(leaky_relu), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -340,47 +441,57 @@ pure function eval_3d_leaky_relu(self, x) result(res) end function eval_3d_leaky_relu pure function eval_3d_leaky_relu_prime(self, x) result(res) - ! First derivative of the Leaky Rectified Linear Unit (Leaky ReLU) activation function. class(leaky_relu), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = merge(1., self%alpha, x > 0) end function eval_3d_leaky_relu_prime + ! Sigmoid Activation Functions pure function eval_1d_sigmoid(self, x) result(res) - ! Sigmoid activation function. class(sigmoid), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = 1 / (1 + exp(-x)) - endfunction eval_1d_sigmoid + end function eval_1d_sigmoid pure function eval_1d_sigmoid_prime(self, x) result(res) - ! First derivative of the sigmoid activation function. class(sigmoid), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = self % eval_1d(x) * (1 - self % eval_1d(x)) end function eval_1d_sigmoid_prime + pure function eval_2d_sigmoid(self, x) result(res) + class(sigmoid), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = 1 / (1 + exp(-x)) + end function eval_2d_sigmoid + + pure function eval_2d_sigmoid_prime(self, x) result(res) + class(sigmoid), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = self % eval_2d(x) * (1 - self % eval_2d(x)) + end function eval_2d_sigmoid_prime + pure function eval_3d_sigmoid(self, x) result(res) - ! Sigmoid activation function. class(sigmoid), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = 1 / (1 + exp(-x)) - endfunction eval_3d_sigmoid + end function eval_3d_sigmoid pure function eval_3d_sigmoid_prime(self, x) result(res) - ! First derivative of the sigmoid activation function. class(sigmoid), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = self % eval_3d(x) * (1 - self % eval_3d(x)) end function eval_3d_sigmoid_prime + ! Softmax Activation Functions pure function eval_1d_softmax(self, x) result(res) - !! 
Softmax activation function class(softmax), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -389,15 +500,28 @@ pure function eval_1d_softmax(self, x) result(res) end function eval_1d_softmax pure function eval_1d_softmax_prime(self, x) result(res) - !! Derivative of the softmax activation function. class(softmax), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = self%eval_1d(x) * (1 - self%eval_1d(x)) end function eval_1d_softmax_prime + pure function eval_2d_softmax(self, x) result(res) + class(softmax), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = exp(x - maxval(x)) + res = res / sum(res) + end function eval_2d_softmax + + pure function eval_2d_softmax_prime(self, x) result(res) + class(softmax), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = self % eval_2d(x) * (1 - self % eval_2d(x)) + end function eval_2d_softmax_prime + pure function eval_3d_softmax(self, x) result(res) - !! Softmax activation function class(softmax), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -406,15 +530,14 @@ pure function eval_3d_softmax(self, x) result(res) end function eval_3d_softmax pure function eval_3d_softmax_prime(self, x) result(res) - !! Derivative of the softmax activation function. class(softmax), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = self % eval_3d(x) * (1 - self % eval_3d(x)) end function eval_3d_softmax_prime + ! Softplus Activation Functions pure function eval_1d_softplus(self, x) result(res) - ! Softplus activation function. class(softplus), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -422,15 +545,27 @@ pure function eval_1d_softplus(self, x) result(res) end function eval_1d_softplus pure function eval_1d_softplus_prime(self, x) result(res) - class(softplus), intent(in) :: self - ! First derivative of the softplus activation function. + class(softplus), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = exp(x) / (exp(x) + 1) end function eval_1d_softplus_prime + pure function eval_2d_softplus(self, x) result(res) + class(softplus), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = log(exp(x) + 1) + end function eval_2d_softplus + + pure function eval_2d_softplus_prime(self, x) result(res) + class(softplus), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = exp(x) / (exp(x) + 1) + end function eval_2d_softplus_prime + pure function eval_3d_softplus(self, x) result(res) - ! Softplus activation function. class(softplus), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -438,15 +573,14 @@ pure function eval_3d_softplus(self, x) result(res) end function eval_3d_softplus pure function eval_3d_softplus_prime(self, x) result(res) - class(softplus), intent(in) :: self - ! First derivative of the softplus activation function. + class(softplus), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = exp(x) / (exp(x) + 1) end function eval_3d_softplus_prime + ! Step Activation Functions pure function eval_1d_step(self, x) result(res) - ! Step activation function. 
class(step), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -454,15 +588,27 @@ pure function eval_1d_step(self, x) result(res) end function eval_1d_step pure function eval_1d_step_prime(self, x) result(res) - ! First derivative of the step activation function. class(step), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = 0 end function eval_1d_step_prime + pure function eval_2d_step(self, x) result(res) + class(step), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = merge(1., 0., x > 0) + end function eval_2d_step + + pure function eval_2d_step_prime(self, x) result(res) + class(step), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = 0 + end function eval_2d_step_prime + pure function eval_3d_step(self, x) result(res) - ! Step activation function. class(step), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -470,15 +616,14 @@ pure function eval_3d_step(self, x) result(res) end function eval_3d_step pure function eval_3d_step_prime(self, x) result(res) - ! First derivative of the step activation function. class(step), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = 0 end function eval_3d_step_prime + ! Tanh Activation Functions pure function eval_1d_tanh(self, x) result(res) - ! Tangent hyperbolic activation function. class(tanhf), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -486,15 +631,27 @@ pure function eval_1d_tanh(self, x) result(res) end function eval_1d_tanh pure function eval_1d_tanh_prime(self, x) result(res) - ! First derivative of the tanh activation function. class(tanhf), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) res = 1 - tanh(x)**2 end function eval_1d_tanh_prime + pure function eval_2d_tanh(self, x) result(res) + class(tanhf), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = tanh(x) + end function eval_2d_tanh + + pure function eval_2d_tanh_prime(self, x) result(res) + class(tanhf), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + res = 1 - tanh(x)**2 + end function eval_2d_tanh_prime + pure function eval_3d_tanh(self, x) result(res) - ! Tangent hyperbolic activation function. class(tanhf), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -502,15 +659,14 @@ pure function eval_3d_tanh(self, x) result(res) end function eval_3d_tanh pure function eval_3d_tanh_prime(self, x) result(res) - ! First derivative of the tanh activation function. class(tanhf), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) res = 1 - tanh(x)**2 end function eval_3d_tanh_prime + ! CELU Activation Functions pure function eval_1d_celu(self, x) result(res) - ! Celu activation function. class(celu), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -519,10 +675,9 @@ pure function eval_1d_celu(self, x) result(res) else where res = self % alpha * (exp(x / self % alpha) - 1.0) end where - end function + end function eval_1d_celu pure function eval_1d_celu_prime(self, x) result(res) - ! Celu activation function. 
class(celu), intent(in) :: self real, intent(in) :: x(:) real :: res(size(x)) @@ -531,10 +686,31 @@ pure function eval_1d_celu_prime(self, x) result(res) else where res = exp(x / self % alpha) end where - end function + end function eval_1d_celu_prime + + pure function eval_2d_celu(self, x) result(res) + class(celu), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + where (x >= 0.0) + res = x + else where + res = self % alpha * (exp(x / self % alpha) - 1.0) + end where + end function eval_2d_celu + + pure function eval_2d_celu_prime(self, x) result(res) + class(celu), intent(in) :: self + real, intent(in) :: x(:,:) + real :: res(size(x,1),size(x,2)) + where (x >= 0.0) + res = 1.0 + else where + res = exp(x / self % alpha) + end where + end function eval_2d_celu_prime pure function eval_3d_celu(self, x) result(res) - ! Celu activation function. class(celu), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -543,10 +719,9 @@ pure function eval_3d_celu(self, x) result(res) else where res = self % alpha * (exp(x / self % alpha) - 1.0) end where - end function + end function eval_3d_celu pure function eval_3d_celu_prime(self, x) result(res) - ! Celu activation function. class(celu), intent(in) :: self real, intent(in) :: x(:,:,:) real :: res(size(x,1),size(x,2),size(x,3)) @@ -555,13 +730,10 @@ pure function eval_3d_celu_prime(self, x) result(res) else where res = exp(x / self % alpha) end where - end function + end function eval_3d_celu_prime + ! Utility Functions function get_activation_by_name(activation_name) result(res) - ! Workaround to get activation_function with some - ! hardcoded default parameters by its name. - ! Need this function since we get only activation name - ! from keras files. character(len=*), intent(in) :: activation_name class(activation_function), allocatable :: res @@ -611,16 +783,8 @@ function get_activation_by_name(activation_name) result(res) end function get_activation_by_name pure function get_name(self) result(name) - !! Return the name of the activation function. - !! - !! Normally we would place this in the definition of each type, however - !! accessing the name variable directly from the type would require type - !! guards just like we have here. This at least keeps all the type guards - !! in one place. class(activation_function), intent(in) :: self - !! The activation function instance. character(:), allocatable :: name - !! The name of the activation function. select type (self) class is (elu) name = 'elu' @@ -651,4 +815,4 @@ pure function get_name(self) result(name) end select end function get_name -end module nf_activation +end module nf_activation \ No newline at end of file diff --git a/src/nf/nf_datasets_mnist_submodule.f90 b/src/nf/nf_datasets_mnist_submodule.f90 index 842cafe1..a0bed0a8 100644 --- a/src/nf/nf_datasets_mnist_submodule.f90 +++ b/src/nf/nf_datasets_mnist_submodule.f90 @@ -50,9 +50,9 @@ module subroutine load_mnist(training_images, training_labels, & real, allocatable, intent(in out), optional :: testing_labels(:) integer, parameter :: dtype = 4, image_size = 784 - integer, parameter :: num_training_images = 50000 - integer, parameter :: num_validation_images = 10000 - integer, parameter :: num_testing_images = 10000 + integer, parameter :: num_training_images = 500 + integer, parameter :: num_validation_images = 100 + integer, parameter :: num_testing_images = 100 logical :: file_exists ! Check if MNIST data is present and download it if not. 
diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90
index ea1c08df..994b0a56 100644
--- a/src/nf/nf_layer_constructors.f90
+++ b/src/nf/nf_layer_constructors.f90
@@ -8,7 +8,7 @@ module nf_layer_constructors
   implicit none

   private
-  public :: conv2d, dense, flatten, input, maxpool2d, reshape
+  public :: conv2d, dense, flatten, input, locally_connected_1d, maxpool1d, maxpool2d, reshape, reshape2d

   interface input

@@ -152,6 +152,56 @@ module function conv2d(filters, kernel_size, activation) result(res)
       !! Resulting layer instance
     end function conv2d

+    module function locally_connected_1d(filters, kernel_size, activation) result(res)
+      !! 1-d locally connected layer constructor.
+      !!
+      !! This layer applies a sliding-window (convolution-style) transformation
+      !! to rank-2 input data of shape (channels, width), for example the
+      !! output of a `reshape2d`, `maxpool1d`, or another `locally_connected_1d`
+      !! layer. The output has shape (filters, width - kernel_size + 1).
+      !! A locally_connected_1d layer must not be the first layer in the network.
+      !!
+      !! Example:
+      !!
+      !! ```
+      !! use nf, only :: locally_connected_1d, layer, relu
+      !! type(layer) :: lc1d_layer
+      !! lc1d_layer = locally_connected_1d(filters=32, kernel_size=3)
+      !! lc1d_layer = locally_connected_1d(filters=32, kernel_size=3, activation=relu())
+      !! ```
+      integer, intent(in) :: filters
+        !! Number of filters in the output of the layer
+      integer, intent(in) :: kernel_size
+        !! Width of the convolution window, commonly 3 or 5
+      class(activation_function), intent(in), optional :: activation
+        !! Activation function (default `relu`)
+      type(layer) :: res
+        !! Resulting layer instance
+    end function locally_connected_1d
+
+    module function maxpool1d(pool_size, stride) result(res)
+      !! 1-d maxpooling layer constructor.
+      !!
+      !! This layer is for downscaling other layers, typically `locally_connected_1d`.
+      !!
+      !! Example:
+      !!
+      !! ```
+      !! use nf, only :: maxpool1d, layer
+      !! type(layer) :: maxpool1d_layer
+      !! maxpool1d_layer = maxpool1d(pool_size=2)
+      !! maxpool1d_layer = maxpool1d(pool_size=2, stride=3)
+      !! ```
+      integer, intent(in) :: pool_size
+        !! Width of the pooling window, commonly 2
+      integer, intent(in), optional :: stride
+        !! Stride of the pooling window, commonly equal to `pool_size`;
+        !! Defaults to `pool_size` if omitted.
+      type(layer) :: res
+        !! Resulting layer instance
+    end function maxpool1d
+
     module function maxpool2d(pool_size, stride) result(res)
       !! 2-d maxpooling layer constructor.
       !!
@@ -185,6 +235,17 @@ module function reshape(output_shape) result(res)
       !! Resulting layer instance
     end function reshape

+    module function reshape2d(output_shape) result(res)
+      !! Rank-1 to rank-2 reshape layer constructor.
+      !!
+      !! This layer is for connecting 1-d inputs to `locally_connected_1d`,
+      !! `maxpool1d`, or similar layers that expect rank-2 input.
+      integer, intent(in) :: output_shape(:)
+        !! Shape of the output
+      type(layer) :: res
+        !!
Resulting layer instance + end function reshape2d + end interface end module nf_layer_constructors diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index 4c5994ee..5982ebac 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -7,8 +7,11 @@ use nf_input1d_layer, only: input1d_layer use nf_input2d_layer, only: input2d_layer use nf_input3d_layer, only: input3d_layer + use nf_locally_connected_1d_layer, only: locally_connected_1d_layer + use nf_maxpool1d_layer, only: maxpool1d_layer use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape_layer, only: reshape3d_layer + use nf_reshape2d_layer, only: reshape2d_layer use nf_activation, only: activation_function, relu, sigmoid implicit none @@ -40,6 +43,31 @@ module function conv2d(filters, kernel_size, activation) result(res) end function conv2d + module function locally_connected_1d(filters, kernel_size, activation) result(res) + integer, intent(in) :: filters + integer, intent(in) :: kernel_size + class(activation_function), intent(in), optional :: activation + type(layer) :: res + + class(activation_function), allocatable :: activation_tmp + + res % name = 'locally_connected_1d' + + if (present(activation)) then + allocate(activation_tmp, source=activation) + else + allocate(activation_tmp, source=relu()) + end if + + res % activation = activation_tmp % get_name() + + allocate( & + res % p, & + source=locally_connected_1d_layer(filters, kernel_size, activation_tmp) & + ) + + end function locally_connected_1d + module function dense(layer_size, activation) result(res) integer, intent(in) :: layer_size @@ -103,6 +131,33 @@ module function input3d(dim1, dim2, dim3) result(res) res % initialized = .true. end function input3d + module function maxpool1d(pool_size, stride) result(res) + integer, intent(in) :: pool_size + integer, intent(in), optional :: stride + integer :: stride_ + type(layer) :: res + + if (pool_size < 2) & + error stop 'pool_size must be >= 2 in a maxpool1d layer' + + ! 
Stride defaults to pool_size if not provided + if (present(stride)) then + stride_ = stride + else + stride_ = pool_size + end if + + if (stride_ < 1) & + error stop 'stride must be >= 1 in a maxpool1d layer' + + res % name = 'maxpool1d' + + allocate( & + res % p, & + source=maxpool1d_layer(pool_size, stride_) & + ) + + end function maxpool1d module function maxpool2d(pool_size, stride) result(res) integer, intent(in) :: pool_size @@ -148,4 +203,19 @@ module function reshape(output_shape) result(res) end function reshape + module function reshape2d(output_shape) result(res) + integer, intent(in) :: output_shape(:) + type(layer) :: res + + res % name = 'reshape2d' + res % layer_shape = output_shape + + if (size(output_shape) == 2) then + allocate(res % p, source=reshape2d_layer(output_shape)) + else + error stop 'size(output_shape) of the reshape layer must == 2' + end if + + end function reshape2d + end submodule nf_layer_constructors_submodule diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 41b9a2ce..60478acf 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -7,7 +7,10 @@ use nf_input1d_layer, only: input1d_layer use nf_input2d_layer, only: input2d_layer use nf_input3d_layer, only: input3d_layer + use nf_locally_connected_1d_layer, only: locally_connected_1d_layer + use nf_maxpool1d_layer, only: maxpool1d_layer use nf_maxpool2d_layer, only: maxpool2d_layer + use nf_reshape2d_layer, only: reshape2d_layer use nf_reshape_layer, only: reshape3d_layer use nf_optimizers, only: optimizer_base_type @@ -60,7 +63,33 @@ pure module subroutine backward_2d(self, previous, gradient) ! Backward pass from a 2-d layer downstream currently implemented ! only for dense and flatten layers - ! CURRENTLY NO LAYERS, tbd: pull/197 and pull/199 + + select type(this_layer => self % p) + + type is(locally_connected_1d_layer) + + select type(prev_layer => previous % p) + type is(maxpool1d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(reshape2d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(locally_connected_1d_layer) + call this_layer % backward(prev_layer % output, gradient) + end select + + type is(maxpool1d_layer) + + select type(prev_layer => previous % p) + type is(maxpool1d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(reshape2d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(locally_connected_1d_layer) + call this_layer % backward(prev_layer % output, gradient) + end select + + end select + end subroutine backward_2d @@ -151,6 +180,34 @@ pure module subroutine forward(self, input) type is(reshape3d_layer) call this_layer % forward(prev_layer % output) end select + + type is(locally_connected_1d_layer) + + ! Upstream layers permitted: input2d, locally_connected_1d, maxpool1d, reshape2d + select type(prev_layer => input % p) + type is(input2d_layer) + call this_layer % forward(prev_layer % output) + type is(locally_connected_1d_layer) + call this_layer % forward(prev_layer % output) + type is(maxpool1d_layer) + call this_layer % forward(prev_layer % output) + type is(reshape2d_layer) + call this_layer % forward(prev_layer % output) + end select + + type is(maxpool1d_layer) + + ! 
Upstream layers permitted: input1d, locally_connected_1d, maxpool1d, reshape2d + select type(prev_layer => input % p) + type is(input2d_layer) + call this_layer % forward(prev_layer % output) + type is(locally_connected_1d_layer) + call this_layer % forward(prev_layer % output) + type is(maxpool1d_layer) + call this_layer % forward(prev_layer % output) + type is(reshape2d_layer) + call this_layer % forward(prev_layer % output) + end select type is(maxpool2d_layer) @@ -227,6 +284,12 @@ pure module subroutine get_output_2d(self, output) type is(input2d_layer) allocate(output, source=this_layer % output) + type is(maxpool1d_layer) + allocate(output, source=this_layer % output) + type is(locally_connected_1d_layer) + allocate(output, source=this_layer % output) + type is(reshape2d_layer) + allocate(output, source=this_layer % output) class default error stop '1-d output can only be read from an input1d, dense, or flatten layer.' @@ -318,10 +381,16 @@ elemental module function get_num_params(self) result(num_params) num_params = this_layer % get_num_params() type is (conv2d_layer) num_params = this_layer % get_num_params() + type is (locally_connected_1d_layer) + num_params = this_layer % get_num_params() + type is (maxpool1d_layer) + num_params = 0 type is (maxpool2d_layer) num_params = 0 type is (flatten_layer) num_params = 0 + type is (reshape2d_layer) + num_params = 0 type is (reshape3d_layer) num_params = 0 class default @@ -345,10 +414,16 @@ module function get_params(self) result(params) params = this_layer % get_params() type is (conv2d_layer) params = this_layer % get_params() + type is (locally_connected_1d_layer) + params = this_layer % get_params() + type is (maxpool1d_layer) + ! No parameters to get. type is (maxpool2d_layer) ! No parameters to get. type is (flatten_layer) ! No parameters to get. + type is (reshape2d_layer) + ! No parameters to get. type is (reshape3d_layer) ! No parameters to get. class default @@ -372,10 +447,16 @@ module function get_gradients(self) result(gradients) gradients = this_layer % get_gradients() type is (conv2d_layer) gradients = this_layer % get_gradients() + type is (locally_connected_1d_layer) + gradients = this_layer % get_gradients() + type is (maxpool1d_layer) + ! No gradients to get. type is (maxpool2d_layer) ! No gradients to get. type is (flatten_layer) ! No gradients to get. + type is (reshape2d_layer) + ! No gradients to get. type is (reshape3d_layer) ! No gradients to get. class default @@ -424,6 +505,14 @@ module subroutine set_params(self, params) type is (conv2d_layer) call this_layer % set_params(params) + + type is (locally_connected_1d_layer) + call this_layer % set_params(params) + + type is (maxpool1d_layer) + ! No parameters to set. + write(stderr, '(a)') 'Warning: calling set_params() ' & + // 'on a zero-parameter layer; nothing to do.' type is (maxpool2d_layer) ! No parameters to set. @@ -434,6 +523,11 @@ module subroutine set_params(self, params) ! No parameters to set. write(stderr, '(a)') 'Warning: calling set_params() ' & // 'on a zero-parameter layer; nothing to do.' + + type is (reshape2d_layer) + ! No parameters to set. + write(stderr, '(a)') 'Warning: calling set_params() ' & + // 'on a zero-parameter layer; nothing to do.' type is (reshape3d_layer) ! No parameters to set. diff --git a/src/nf/nf_locally_connected_1d.f90 b/src/nf/nf_locally_connected_1d.f90 new file mode 100644 index 00000000..739df749 --- /dev/null +++ b/src/nf/nf_locally_connected_1d.f90 @@ -0,0 +1,119 @@ +module nf_locally_connected_1d_layer + !! 
This module provides a 1-d convolutional `locally_connected_1d` layer type.
+
+  use nf_activation, only: activation_function
+  use nf_base_layer, only: base_layer
+  implicit none
+
+  private
+  public :: locally_connected_1d_layer
+
+  type, extends(base_layer) :: locally_connected_1d_layer
+
+    integer :: width
+    integer :: height
+    integer :: channels
+    integer :: kernel_size
+    integer :: filters
+
+    real, allocatable :: biases(:) ! size(filters)
+    real, allocatable :: kernel(:,:,:) ! filters x channels x kernel_size
+    real, allocatable :: output(:,:) ! filters x output width
+    real, allocatable :: z(:,:) ! kernel .dot. input + bias
+
+    real, allocatable :: dw(:,:,:) ! weight (kernel) gradients
+    real, allocatable :: db(:) ! bias gradients
+    real, allocatable :: gradient(:,:)
+
+    class(activation_function), allocatable :: activation
+
+  contains
+
+    procedure :: forward
+    procedure :: backward
+    procedure :: get_gradients
+    procedure :: get_num_params
+    procedure :: get_params
+    procedure :: init
+    procedure :: set_params
+
+  end type locally_connected_1d_layer
+
+  interface locally_connected_1d_layer
+    module function locally_connected_1d_layer_cons(filters, kernel_size, activation) &
+      result(res)
+      !! `locally_connected_1d_layer` constructor function
+      integer, intent(in) :: filters
+      integer, intent(in) :: kernel_size
+      class(activation_function), intent(in) :: activation
+      type(locally_connected_1d_layer) :: res
+    end function locally_connected_1d_layer_cons
+  end interface locally_connected_1d_layer
+
+  interface
+
+    module subroutine init(self, input_shape)
+      !! Initialize the layer data structures.
+      !!
+      !! This is a deferred procedure from the `base_layer` abstract type.
+      class(locally_connected_1d_layer), intent(in out) :: self
+        !! A `locally_connected_1d_layer` instance
+      integer, intent(in) :: input_shape(:)
+        !! Input layer dimensions
+    end subroutine init
+
+    pure module subroutine forward(self, input)
+      !! Apply a forward pass on the `locally_connected_1d` layer.
+      class(locally_connected_1d_layer), intent(in out) :: self
+        !! A `locally_connected_1d_layer` instance
+      real, intent(in) :: input(:,:)
+        !! Input data
+    end subroutine forward
+
+    pure module subroutine backward(self, input, gradient)
+      !! Apply a backward pass on the `locally_connected_1d` layer.
+      class(locally_connected_1d_layer), intent(in out) :: self
+        !! A `locally_connected_1d_layer` instance
+      real, intent(in) :: input(:,:)
+        !! Input data (previous layer)
+      real, intent(in) :: gradient(:,:)
+        !! Gradient (next layer)
+    end subroutine backward
+
+    pure module function get_num_params(self) result(num_params)
+      !! Get the number of parameters in the layer.
+      class(locally_connected_1d_layer), intent(in) :: self
+        !! A `locally_connected_1d_layer` instance
+      integer :: num_params
+        !! Number of parameters
+    end function get_num_params
+
+    module function get_params(self) result(params)
+      !! Return the parameters (weights and biases) of this layer.
+      !! The parameters are ordered as weights first, biases second.
+      class(locally_connected_1d_layer), intent(in), target :: self
+        !! A `locally_connected_1d_layer` instance
+      real, allocatable :: params(:)
+        !! Parameters to get
+    end function get_params
+
+    module function get_gradients(self) result(gradients)
+      !! Return the gradients of this layer.
+      !! The gradients are ordered as weights first, biases second.
+      class(locally_connected_1d_layer), intent(in), target :: self
+        !!
A `locally_connected_1d_layer` instance + real, allocatable :: gradients(:) + !! Gradients to get + end function get_gradients + + module subroutine set_params(self, params) + !! Set the parameters of the layer. + class(locally_connected_1d_layer), intent(in out) :: self + !! A `locally_connected_1d_layer` instance + real, intent(in) :: params(:) + !! Parameters to set + end subroutine set_params + + end interface + +end module nf_locally_connected_1d_layer diff --git a/src/nf/nf_locally_connected_1d_submodule.f90 b/src/nf/nf_locally_connected_1d_submodule.f90 new file mode 100644 index 00000000..e3715dd6 --- /dev/null +++ b/src/nf/nf_locally_connected_1d_submodule.f90 @@ -0,0 +1,211 @@ +submodule(nf_locally_connected_1d_layer) nf_locally_connected_1d_layer_submodule + + use nf_activation, only: activation_function + use nf_random, only: random_normal + + implicit none + +contains + + module function locally_connected_1d_layer_cons(filters, kernel_size, activation) result(res) + implicit none + integer, intent(in) :: filters + integer, intent(in) :: kernel_size + class(activation_function), intent(in) :: activation + type(locally_connected_1d_layer) :: res + + res % kernel_size = kernel_size + res % filters = filters + res % activation_name = activation % get_name() + allocate( res % activation, source = activation ) + + end function locally_connected_1d_layer_cons + + module subroutine init(self, input_shape) + implicit none + class(locally_connected_1d_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + self % channels = input_shape(1) + self % width = input_shape(2) - self % kernel_size + 1 + + ! Output of shape filters x width + allocate(self % output(self % filters, self % width)) + self % output = 0 + + ! Kernel of shape filters x channels x kernel_size + allocate(self % kernel(self % filters, self % channels, self % kernel_size)) + + ! Initialize the kernel with random values with a normal distribution + call random_normal(self % kernel) + self % kernel = self % kernel / self % kernel_size ** 2 + + allocate(self % biases(self % filters)) + self % biases = 0 + + allocate(self % z, mold=self % output) + self % z = 0 + + allocate(self % gradient(input_shape(1), input_shape(2))) + self % gradient = 0 + + allocate(self % dw, mold=self % kernel) + self % dw = 0 + + allocate(self % db, mold=self % biases) + self % db = 0 + + end subroutine init + + pure module subroutine forward(self, input) + implicit none + class(locally_connected_1d_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + integer :: input_width, input_channels + integer :: i, n, i_out + integer :: iws, iwe + integer :: half_window + + ! Get input dimensions + input_channels = size(input, dim=1) + input_width = size(input, dim=2) + + ! For a kernel of odd size, half_window = kernel_size / 2 (integer division) + half_window = self % kernel_size / 2 + + ! Loop over output indices rather than input indices. + do i_out = 1, self % width + ! Compute the corresponding center index in the input. + i = i_out + half_window + + ! Define the window in the input corresponding to the filter kernel + iws = i - half_window + iwe = i + half_window + + ! Compute the inner tensor product (sum of element-wise products) + ! for each filter across all channels and positions in the kernel. + do concurrent(n = 1:self % filters) + self % z(n, i_out) = sum(self % kernel(n, :, :) * input(:, iws:iwe)) + end do + + ! Add the bias for each filter. 
+ self % z(:, i_out) = self % z(:, i_out) + self % biases + end do + + ! Apply the activation function to get the final output. + self % output = self % activation % eval(self % z) + end subroutine forward + + + pure module subroutine backward(self, input, gradient) + implicit none + class(locally_connected_1d_layer), intent(in out) :: self + real, intent(in) :: input(:,:) ! shape: (channels, width) + real, intent(in) :: gradient(:,:) ! shape: (filters, width) + + ! Local gradient arrays: + real :: db(self % filters) + real :: dw(self % filters, self % channels, self % kernel_size) + real :: gdz(self % filters, size(input, 2)) + + integer :: i, n, k + integer :: input_channels, input_width + integer :: istart, iend + integer :: iws, iwe + integer :: half_window + + ! Get input dimensions. + input_channels = size(input, dim=1) + input_width = size(input, dim=2) + + ! For an odd-sized kernel, half_window = kernel_size / 2. + half_window = self % kernel_size / 2 + + ! Define the valid output range so that the full input window is available. + istart = half_window + 1 + iend = input_width - half_window + + !--------------------------------------------------------------------- + ! Compute the local gradient: gdz = (dL/dy) * sigma'(z) + ! We assume self%z stores the pre-activation values from the forward pass. + gdz = 0.0 + gdz(:, istart:iend) = gradient(:, istart:iend) * self % activation % eval_prime(self % z(:, istart:iend)) + + !--------------------------------------------------------------------- + ! Compute gradient with respect to biases: + ! dL/db(n) = sum_{i in valid range} gdz(n, i) + do concurrent (n = 1:self % filters) + db(n) = sum(gdz(n, istart:iend)) + end do + + ! Initialize weight gradient and input gradient accumulators. + dw = 0.0 + self % gradient = 0.0 ! This array is assumed preallocated to shape (channels, width) + + !--------------------------------------------------------------------- + ! Accumulate gradients over valid output positions. + ! For each output position i, determine the corresponding input window indices. + do concurrent (n = 1:self % filters, & + k = 1:self % channels, & + i = istart:iend) + ! The input window corresponding to output index i: + iws = i - half_window + iwe = i + half_window + + ! Weight gradient (dL/dw): + ! For each kernel element, the contribution is the product of the input in the window + ! and the local gradient at the output position i. + dw(n, k, :) = dw(n, k, :) + input(k, iws:iwe) * gdz(n, i) + + ! Input gradient (dL/dx): + ! Distribute the effect of the output gradient back onto the input window, + ! weighted by the kernel weights. + self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, k, :) * gdz(n, i) + end do + + !--------------------------------------------------------------------- + ! Accumulate the computed gradients into the layer's stored gradients. 
+ self % dw = self % dw + dw + self % db = self % db + db + + end subroutine backward + + pure module function get_num_params(self) result(num_params) + class(locally_connected_1d_layer), intent(in) :: self + integer :: num_params + num_params = product(shape(self % kernel)) + size(self % biases) + end function get_num_params + + module function get_params(self) result(params) + class(locally_connected_1d_layer), intent(in), target :: self + real, allocatable :: params(:) + real, pointer :: w_(:) => null() + w_(1:size(self % kernel)) => self % kernel + params = [ w_, self % biases ] + end function get_params + + module function get_gradients(self) result(gradients) + class(locally_connected_1d_layer), intent(in), target :: self + real, allocatable :: gradients(:) + real, pointer :: dw_(:) => null() + dw_(1:size(self % dw)) => self % dw + gradients = [ dw_, self % db ] + end function get_gradients + + module subroutine set_params(self, params) + class(locally_connected_1d_layer), intent(in out) :: self + real, intent(in) :: params(:) + + if (size(params) /= self % get_num_params()) then + error stop 'locally_connected_1d % set_params: Number of parameters does not match' + end if + + self % kernel = reshape(params(:product(shape(self % kernel))), shape(self % kernel)) + + associate(n => product(shape(self % kernel))) + self % biases = params(n + 1 : n + self % filters) + end associate + + end subroutine set_params + +end submodule nf_locally_connected_1d_layer_submodule diff --git a/src/nf/nf_maxpool1d_layer.f90 b/src/nf/nf_maxpool1d_layer.f90 new file mode 100644 index 00000000..b9a14d07 --- /dev/null +++ b/src/nf/nf_maxpool1d_layer.f90 @@ -0,0 +1,69 @@ +module nf_maxpool1d_layer + !! This module provides the 1-d maxpooling layer. + + use nf_base_layer, only: base_layer + implicit none + + private + public :: maxpool1d_layer + + type, extends(base_layer) :: maxpool1d_layer + integer :: channels + integer :: width ! Length of the input along the pooling dimension + integer :: pool_size + integer :: stride + + ! Location (as input matrix indices) of the maximum value within each pooling region. + ! Dimensions: (channels, new_width) + integer, allocatable :: maxloc(:,:) + + ! Gradient for the input (same shape as the input). + real, allocatable :: gradient(:,:) + ! Output after pooling (dimensions: (channels, new_width)). + real, allocatable :: output(:,:) + contains + procedure :: init + procedure :: forward + procedure :: backward + end type maxpool1d_layer + + interface maxpool1d_layer + pure module function maxpool1d_layer_cons(pool_size, stride) result(res) + !! `maxpool1d` constructor function. + integer, intent(in) :: pool_size + !! Width of the pooling window. + integer, intent(in) :: stride + !! Stride of the pooling window. + type(maxpool1d_layer) :: res + end function maxpool1d_layer_cons + end interface maxpool1d_layer + + interface + module subroutine init(self, input_shape) + !! Initialize the `maxpool1d` layer instance with an input shape. + class(maxpool1d_layer), intent(in out) :: self + !! `maxpool1d_layer` instance. + integer, intent(in) :: input_shape(:) + !! Array shape of the input layer, expected as (channels, width). + end subroutine init + + pure module subroutine forward(self, input) + !! Run a forward pass of the `maxpool1d` layer. + class(maxpool1d_layer), intent(in out) :: self + !! `maxpool1d_layer` instance. + real, intent(in) :: input(:,:) + !! Input data (output of the previous layer), with shape (channels, width). 
+ end subroutine forward + + pure module subroutine backward(self, input, gradient) + !! Run a backward pass of the `maxpool1d` layer. + class(maxpool1d_layer), intent(in out) :: self + !! `maxpool1d_layer` instance. + real, intent(in) :: input(:,:) + !! Input data (output of the previous layer). + real, intent(in) :: gradient(:,:) + !! Gradient from the downstream layer, with shape (channels, pooled width). + end subroutine backward + end interface + +end module nf_maxpool1d_layer \ No newline at end of file diff --git a/src/nf/nf_maxpool1d_layer_submodule.f90 b/src/nf/nf_maxpool1d_layer_submodule.f90 new file mode 100644 index 00000000..9a0b081d --- /dev/null +++ b/src/nf/nf_maxpool1d_layer_submodule.f90 @@ -0,0 +1,93 @@ +submodule(nf_maxpool1d_layer) nf_maxpool1d_layer_submodule + implicit none + +contains + + pure module function maxpool1d_layer_cons(pool_size, stride) result(res) + implicit none + integer, intent(in) :: pool_size + integer, intent(in) :: stride + type(maxpool1d_layer) :: res + + res % pool_size = pool_size + res % stride = stride + end function maxpool1d_layer_cons + + + module subroutine init(self, input_shape) + implicit none + class(maxpool1d_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + ! input_shape is expected to be (channels, width) + + self % channels = input_shape(1) + ! The new width is the integer division of the input width by the stride. + self % width = input_shape(2) / self % stride + + ! Allocate storage for the index of the maximum element within each pooling region. + allocate(self % maxloc(self % channels, self % width)) + self % maxloc = 0 + + ! Allocate the gradient array corresponding to the input dimensions. + allocate(self % gradient(input_shape(1), input_shape(2))) + self % gradient = 0 + + ! Allocate the output array (after pooling). + allocate(self % output(self % channels, self % width)) + self % output = 0 + end subroutine init + + + pure module subroutine forward(self, input) + implicit none + class(maxpool1d_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + integer :: input_width + integer :: i, n + integer :: ii, iend + integer :: iextent + integer :: max_index ! Temporary variable to hold the local index of the max + integer :: maxloc_temp(1) ! Temporary array to hold the result of maxloc + + input_width = size(input, dim=2) + ! Ensure we only process complete pooling regions. + iextent = input_width - mod(input_width, self % stride) + + ! Loop over the input with a step size equal to the stride and over all channels. + do concurrent (i = 1:iextent: self % stride, n = 1:self % channels) + ! Compute the index in the pooled (output) array. + ii = (i - 1) / self % stride + 1 + ! Determine the ending index of the current pooling region. + iend = min(i + self % pool_size - 1, input_width) + + ! Find the index (within the pooling window) of the maximum value. + maxloc_temp = maxloc(input(n, i:iend)) + max_index = maxloc_temp(1) + i - 1 ! Adjust to the index in the original input + + ! Store the location of the maximum value. + self % maxloc(n, ii) = max_index + ! Set the output as the maximum value from this pooling region. 
+ self % output(n, ii) = input(n, max_index) + end do + end subroutine forward + + + pure module subroutine backward(self, input, gradient) + implicit none + class(maxpool1d_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + real, intent(in) :: gradient(:,:) + integer :: channels, pooled_width + integer :: i, n + + channels = size(gradient, dim=1) + pooled_width = size(gradient, dim=2) + + ! The gradient for max-pooling is nonzero only at the input locations + ! that were the maxima during the forward pass. + do concurrent (n = 1:channels, i = 1:pooled_width) + self % gradient(n, self % maxloc(n, i)) = gradient(n, i) + end do + end subroutine backward + +end submodule nf_maxpool1d_layer_submodule \ No newline at end of file diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 506c3295..0e06287f 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -6,10 +6,13 @@ use nf_input1d_layer, only: input1d_layer use nf_input2d_layer, only: input2d_layer use nf_input3d_layer, only: input3d_layer + use nf_locally_connected_1d_layer, only: locally_connected_1d_layer + use nf_maxpool1d_layer, only: maxpool1d_layer use nf_maxpool2d_layer, only: maxpool2d_layer + use nf_reshape2d_layer, only: reshape2d_layer use nf_reshape_layer, only: reshape3d_layer use nf_layer, only: layer - use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape + use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool1d, maxpool2d, reshape, reshape2d use nf_loss, only: quadratic use nf_optimizers, only: optimizer_base_type, sgd use nf_parallel, only: tile_indices @@ -76,6 +79,12 @@ module function network_from_layers(layers) result(res) type is(reshape3d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 + type is(maxpool1d_layer) + res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] + n = n + 1 + type is(reshape2d_layer) + res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] + n = n + 1 class default n = n + 1 end select @@ -143,6 +152,10 @@ module subroutine backward(self, output, loss) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(reshape3d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(maxpool1d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(reshape2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) end select end if diff --git a/src/nf/nf_reshape2d_layer.f90 b/src/nf/nf_reshape2d_layer.f90 new file mode 100644 index 00000000..6b99729b --- /dev/null +++ b/src/nf/nf_reshape2d_layer.f90 @@ -0,0 +1,77 @@ +module nf_reshape2d_layer + + !! This module provides the concrete reshape layer type. + !! It is used internally by the layer type. + !! It is not intended to be used directly by the user. + + use nf_base_layer, only: base_layer + + implicit none + + private + public :: reshape2d_layer + + type, extends(base_layer) :: reshape2d_layer + + !! Concrete implementation of a reshape layer type + !! It implements only rank-1 to rank-2 reshaping. + + integer :: input_shape(1) + integer :: output_shape(2) + real, allocatable :: gradient(:) + real, allocatable :: output(:,:) + + contains + + procedure :: backward + procedure :: forward + procedure :: init + + end type reshape2d_layer + + interface reshape2d_layer + pure module function reshape2d_layer_cons(output_shape) result(res) + !! 
This function returns the `reshape_layer` instance. + integer, intent(in) :: output_shape(2) + !! The shape of the output + type(reshape2d_layer) :: res + !! reshape_layer instance + end function reshape2d_layer_cons + end interface reshape2d_layer + + interface + + pure module subroutine backward(self, input, gradient) + !! Apply the backward pass for the reshape2d layer. + !! This is just flattening to a rank-1 array. + class(reshape2d_layer), intent(in out) :: self + !! Dense layer instance + real, intent(in) :: input(:) + !! Input from the previous layer + real, intent(in) :: gradient(:,:) + !! Gradient from the next layer + end subroutine backward + + pure module subroutine forward(self, input) + !! Apply the forward pass for the reshape2d layer. + !! This is just a reshape from rank-1 to rank-2 array. + class(reshape2d_layer), intent(in out) :: self + !! Dense layer instance + real, intent(in) :: input(:) + !! Input from the previous layer + end subroutine forward + + module subroutine init(self, input_shape) + !! Initialize the layer data structures. + !! + !! This is a deferred procedure from the `base_layer` abstract type. + class(reshape2d_layer), intent(in out) :: self + !! Dense layer instance + integer, intent(in) :: input_shape(:) + !! Shape of the input layer + end subroutine init + + end interface + + end module nf_reshape2d_layer + \ No newline at end of file diff --git a/src/nf/nf_reshape2d_layer_submodule.f90 b/src/nf/nf_reshape2d_layer_submodule.f90 new file mode 100644 index 00000000..487d5cb8 --- /dev/null +++ b/src/nf/nf_reshape2d_layer_submodule.f90 @@ -0,0 +1,50 @@ +submodule(nf_reshape2d_layer) nf_reshape2d_layer_submodule + + use nf_base_layer, only: base_layer + + implicit none + +contains + + pure module function reshape2d_layer_cons(output_shape) result(res) + integer, intent(in) :: output_shape(2) + type(reshape2d_layer) :: res + res % output_shape = output_shape + end function reshape2d_layer_cons + + + pure module subroutine backward(self, input, gradient) + class(reshape2d_layer), intent(in out) :: self + real, intent(in) :: input(:) + real, intent(in) :: gradient(:,:) + ! The `input` dummy argument is not used but nevertheless declared + ! because the abstract type requires it. + self % gradient = pack(gradient, .true.) 
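+    ! Note: with a `.true.` mask, `pack` flattens `gradient` in column-major
+    ! (array element) order, giving the same result as
+    ! reshape(gradient, [size(gradient)]); this undoes the rank-1 to rank-2
+    ! reshape performed in `forward`.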
+ end subroutine backward + + + pure module subroutine forward(self, input) + class(reshape2d_layer), intent(in out) :: self + real, intent(in) :: input(:) + self % output = reshape(input, self % output_shape) + end subroutine forward + + + module subroutine init(self, input_shape) + class(reshape2d_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + self % input_shape = input_shape + + allocate(self % gradient(input_shape(1))) + self % gradient = 0 + + allocate(self % output( & + self % output_shape(1), & + self % output_shape(2) & + )) + self % output = 0 + + end subroutine init + +end submodule nf_reshape2d_layer_submodule From cf2caf6502484faf993ec873abc79f499a4be535 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Sun, 16 Feb 2025 19:32:29 +0100 Subject: [PATCH 02/26] added tests; to note that they don't work --- src/nf/nf_layer_submodule.f90 | 4 ++ src/nf/nf_network_submodule.f90 | 15 +++-- test/CMakeLists.txt | 2 + test/test_maxpool1d_layer.f90 | 100 ++++++++++++++++++++++++++++++++ test/test_reshape2d_layer.f90 | 54 +++++++++++++++++ 5 files changed, 169 insertions(+), 6 deletions(-) create mode 100644 test/test_maxpool1d_layer.f90 create mode 100644 test/test_reshape2d_layer.f90 diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 60478acf..2dcde84c 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -340,6 +340,10 @@ impure elemental module subroutine init(self, input) self % layer_shape = shape(this_layer % output) type is(maxpool2d_layer) self % layer_shape = shape(this_layer % output) + type is(locally_connected_1d_layer) + self % layer_shape = shape(this_layer % output) + type is(maxpool1d_layer) + self % layer_shape = shape(this_layer % output) type is(flatten_layer) self % layer_shape = shape(this_layer % output) end select diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 0e06287f..fb6142b7 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -73,18 +73,21 @@ module function network_from_layers(layers) result(res) type is(conv2d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 + !type is(locally_connected_1d_layer) + !res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] + !n = n + 1 type is(maxpool2d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 type is(reshape3d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 - type is(maxpool1d_layer) - res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] - n = n + 1 - type is(reshape2d_layer) - res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] - n = n + 1 + !type is(maxpool1d_layer) + ! res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] + ! n = n + 1 + !type is(reshape2d_layer) + ! res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] + ! 
n = n + 1 class default n = n + 1 end select diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 35954894..266fb877 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -5,10 +5,12 @@ foreach(execid parametric_activation dense_layer conv2d_layer + maxpool1d_layer maxpool2d_layer flatten_layer insert_flatten reshape_layer + reshape2d_layer dense_network get_set_network_params conv2d_network diff --git a/test/test_maxpool1d_layer.f90 b/test/test_maxpool1d_layer.f90 new file mode 100644 index 00000000..a5691272 --- /dev/null +++ b/test/test_maxpool1d_layer.f90 @@ -0,0 +1,100 @@ +program test_maxpool1d_layer + + use iso_fortran_env, only: stderr => error_unit + use nf, only: maxpool1d, input, layer + use nf_input2d_layer, only: input2d_layer + use nf_maxpool1d_layer, only: maxpool1d_layer + + implicit none + + type(layer) :: maxpool_layer, input_layer + integer, parameter :: pool_size = 2, stride = 2 + integer, parameter :: channels = 3, length = 32 + integer, parameter :: input_shape(2) = [channels, length] + integer, parameter :: output_shape(2) = [channels, length / 2] + real, allocatable :: sample_input(:,:), output(:,:), gradient(:,:) + integer :: i + logical :: ok = .true., gradient_ok = .true. + + maxpool_layer = maxpool1d(pool_size) + + if (.not. maxpool_layer % name == 'maxpool1d') then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer has its name set correctly.. failed' + end if + + if (maxpool_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer should not be marked as initialized yet.. failed' + end if + + input_layer = input(channels, length) + call maxpool_layer % init(input_layer) + + if (.not. maxpool_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer should now be marked as initialized.. failed' + end if + + if (.not. all(maxpool_layer % input_layer_shape == input_shape)) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer input layer shape should be correct.. failed' + end if + + if (.not. all(maxpool_layer % layer_shape == output_shape)) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer output layer shape should be correct.. failed' + end if + + ! Allocate and initialize sample input data + allocate(sample_input(channels, length)) + do concurrent(i = 1:length) + sample_input(:,i) = i + end do + + select type(this_layer => input_layer % p); type is(input2d_layer) + call this_layer % set(sample_input) + end select + + call maxpool_layer % forward(input_layer) + call maxpool_layer % get_output(output) + + do i = 1, length / 2 + ! Since input is i, maxpool1d output must be stride*i + if (.not. all(output(:,i) == stride * i)) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer forward pass correctly propagates the max value.. failed' + end if + end do + + ! Test the backward pass + ! Allocate and initialize the downstream gradient field + allocate(gradient, source=output) + + ! Make a backward pass + call maxpool_layer % backward(input_layer, gradient) + + select type(this_layer => maxpool_layer % p); type is(maxpool1d_layer) + do i = 1, length + if (mod(i,2) == 0) then + if (.not. all(sample_input(:,i) == this_layer % gradient(:,i))) gradient_ok = .false. + else + if (.not. all(this_layer % gradient(:,i) == 0)) gradient_ok = .false. + end if + end do + end select + + if (.not. gradient_ok) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer backward pass produces the correct dL/dx.. 
failed' + end if + + if (ok) then + print '(a)', 'test_maxpool1d_layer: All tests passed.' + else + write(stderr, '(a)') 'test_maxpool1d_layer: One or more tests failed.' + stop 1 + end if + + end program test_maxpool1d_layer + \ No newline at end of file diff --git a/test/test_reshape2d_layer.f90 b/test/test_reshape2d_layer.f90 new file mode 100644 index 00000000..a77ce23d --- /dev/null +++ b/test/test_reshape2d_layer.f90 @@ -0,0 +1,54 @@ +program test_reshape2d_layer + + use iso_fortran_env, only: stderr => error_unit + use nf, only: input, network, reshape2d_layer => reshape2d + use nf_datasets, only: download_and_unpack, keras_reshape_url + + implicit none + + type(network) :: net + real, allocatable :: sample_input(:), output(:,:) + integer, parameter :: output_shape(2) = [32, 32] + integer, parameter :: input_size = product(output_shape) + character(*), parameter :: keras_reshape_path = 'keras_reshape.h5' + logical :: file_exists + logical :: ok = .true. + + ! Create the network + net = network([ & + input(input_size), & + reshape2d_layer(output_shape) & + ]) + + if (.not. size(net % layers) == 2) then + write(stderr, '(a)') 'the network should have 2 layers.. failed' + ok = .false. + end if + + ! Initialize test data + allocate(sample_input(input_size)) + call random_number(sample_input) + + ! Propagate forward and get the output + call net % forward(sample_input) + call net % layers(2) % get_output(output) + + if (.not. all(shape(output) == output_shape)) then + write(stderr, '(a)') 'the reshape layer produces expected output shape.. failed' + ok = .false. + end if + + if (.not. all(reshape(sample_input, output_shape) == output)) then + write(stderr, '(a)') 'the reshape layer produces expected output values.. failed' + ok = .false. + end if + + if (ok) then + print '(a)', 'test_reshape2d_layer: All tests passed.' + else + write(stderr, '(a)') 'test_reshape2d_layer: One or more tests failed.' 
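+      ! A nonzero stop code makes the test runner (CTest) report a failure.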
+ stop 1 + end if + + end program test_reshape2d_layer + \ No newline at end of file From eb4079dde480a27ac421666cf2ba5ff54f8ab5dc Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Mon, 17 Feb 2025 16:01:34 +0100 Subject: [PATCH 03/26] Now reshape2d works, maxpool still not --- src/nf/nf_layer_submodule.f90 | 12 ++ src/nf/nf_locally_connected_1d_submodule.f90 | 162 ++++++++----------- test/test_reshape2d_layer.f90 | 105 ++++++------ 3 files changed, 134 insertions(+), 145 deletions(-) diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 2dcde84c..27c6314b 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -87,6 +87,12 @@ pure module subroutine backward_2d(self, previous, gradient) type is(locally_connected_1d_layer) call this_layer % backward(prev_layer % output, gradient) end select + + type is(reshape2d_layer) + select type(prev_layer => previous % p) + type is(input1d_layer) + call this_layer % backward(prev_layer % output, gradient) + end select end select @@ -248,6 +254,12 @@ pure module subroutine forward(self, input) type is(flatten_layer) call this_layer % forward(prev_layer % output) end select + + type is(reshape2d_layer) + select type(prev_layer => input % p) + type is(input1d_layer) + call this_layer % forward(prev_layer % output) + end select end select diff --git a/src/nf/nf_locally_connected_1d_submodule.f90 b/src/nf/nf_locally_connected_1d_submodule.f90 index e3715dd6..2a65b7f6 100644 --- a/src/nf/nf_locally_connected_1d_submodule.f90 +++ b/src/nf/nf_locally_connected_1d_submodule.f90 @@ -18,7 +18,6 @@ module function locally_connected_1d_layer_cons(filters, kernel_size, activation res % filters = filters res % activation_name = activation % get_name() allocate( res % activation, source = activation ) - end function locally_connected_1d_layer_cons module subroutine init(self, input_shape) @@ -29,16 +28,14 @@ module subroutine init(self, input_shape) self % channels = input_shape(1) self % width = input_shape(2) - self % kernel_size + 1 - ! Output of shape filters x width + ! Output of shape: filters x width allocate(self % output(self % filters, self % width)) self % output = 0 - ! Kernel of shape filters x channels x kernel_size + ! Kernel of shape: filters x channels x kernel_size allocate(self % kernel(self % filters, self % channels, self % kernel_size)) - - ! Initialize the kernel with random values with a normal distribution call random_normal(self % kernel) - self % kernel = self % kernel / self % kernel_size ** 2 + self % kernel = self % kernel / real(self % kernel_size**2) allocate(self % biases(self % filters)) self % biases = 0 @@ -61,113 +58,93 @@ pure module subroutine forward(self, input) implicit none class(locally_connected_1d_layer), intent(in out) :: self real, intent(in) :: input(:,:) - integer :: input_width, input_channels - integer :: i, n, i_out - integer :: iws, iwe - integer :: half_window + integer :: input_channels, input_width + integer :: j, n + integer :: iws, iwe, half_window - ! Get input dimensions input_channels = size(input, dim=1) input_width = size(input, dim=2) - - ! For a kernel of odd size, half_window = kernel_size / 2 (integer division) half_window = self % kernel_size / 2 - ! Loop over output indices rather than input indices. - do i_out = 1, self % width - ! Compute the corresponding center index in the input. - i = i_out + half_window - - ! 
Define the window in the input corresponding to the filter kernel - iws = i - half_window - iwe = i + half_window + ! Loop over output positions. + do j = 1, self % width + ! Compute the input window corresponding to output index j. + ! In forward: center index = j + half_window, so window = indices j to j+kernel_size-1. + iws = j + iwe = j + self % kernel_size - 1 - ! Compute the inner tensor product (sum of element-wise products) - ! for each filter across all channels and positions in the kernel. - do concurrent(n = 1:self % filters) - self % z(n, i_out) = sum(self % kernel(n, :, :) * input(:, iws:iwe)) + ! For each filter, compute the convolution (inner product over channels and kernel width). + do concurrent (n = 1:self % filters) + self % z(n, j) = sum(self % kernel(n, :, :) * input(:, iws:iwe)) end do ! Add the bias for each filter. - self % z(:, i_out) = self % z(:, i_out) + self % biases + self % z(:, j) = self % z(:, j) + self % biases end do - ! Apply the activation function to get the final output. + ! Apply the activation function. self % output = self % activation % eval(self % z) end subroutine forward - pure module subroutine backward(self, input, gradient) implicit none class(locally_connected_1d_layer), intent(in out) :: self - real, intent(in) :: input(:,:) ! shape: (channels, width) - real, intent(in) :: gradient(:,:) ! shape: (filters, width) - - ! Local gradient arrays: - real :: db(self % filters) - real :: dw(self % filters, self % channels, self % kernel_size) - real :: gdz(self % filters, size(input, 2)) - - integer :: i, n, k - integer :: input_channels, input_width - integer :: istart, iend - integer :: iws, iwe - integer :: half_window - - ! Get input dimensions. + ! 'input' has shape: (channels, input_width) + ! 'gradient' (dL/dy) has shape: (filters, output_width) + real, intent(in) :: input(:,:) + real, intent(in) :: gradient(:,:) + + integer :: input_channels, input_width, output_width + integer :: j, n, k + integer :: iws, iwe, half_window + real :: gdz_val + + ! Local arrays to accumulate gradients. + real :: gdz(self % filters, self % width) ! local gradient (dL/dz) + real :: db_local(self % filters) + real :: dw_local(self % filters, self % channels, self % kernel_size) + + ! Determine dimensions. input_channels = size(input, dim=1) input_width = size(input, dim=2) - - ! For an odd-sized kernel, half_window = kernel_size / 2. + output_width = self % width ! Note: output_width = input_width - kernel_size + 1 + half_window = self % kernel_size / 2 - - ! Define the valid output range so that the full input window is available. - istart = half_window + 1 - iend = input_width - half_window - - !--------------------------------------------------------------------- - ! Compute the local gradient: gdz = (dL/dy) * sigma'(z) - ! We assume self%z stores the pre-activation values from the forward pass. - gdz = 0.0 - gdz(:, istart:iend) = gradient(:, istart:iend) * self % activation % eval_prime(self % z(:, istart:iend)) - - !--------------------------------------------------------------------- - ! Compute gradient with respect to biases: - ! dL/db(n) = sum_{i in valid range} gdz(n, i) - do concurrent (n = 1:self % filters) - db(n) = sum(gdz(n, istart:iend)) + + !--- Compute the local gradient gdz = (dL/dy) * sigma'(z) for each output. + do j = 1, output_width + gdz(:, j) = gradient(:, j) * self % activation % eval_prime(self % z(:, j)) end do - - ! Initialize weight gradient and input gradient accumulators. - dw = 0.0 - self % gradient = 0.0 ! 
This array is assumed preallocated to shape (channels, width) - - !--------------------------------------------------------------------- - ! Accumulate gradients over valid output positions. - ! For each output position i, determine the corresponding input window indices. - do concurrent (n = 1:self % filters, & - k = 1:self % channels, & - i = istart:iend) - ! The input window corresponding to output index i: - iws = i - half_window - iwe = i + half_window - - ! Weight gradient (dL/dw): - ! For each kernel element, the contribution is the product of the input in the window - ! and the local gradient at the output position i. - dw(n, k, :) = dw(n, k, :) + input(k, iws:iwe) * gdz(n, i) - - ! Input gradient (dL/dx): - ! Distribute the effect of the output gradient back onto the input window, - ! weighted by the kernel weights. - self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, k, :) * gdz(n, i) + + !--- Compute bias gradients: db(n) = sum_j gdz(n, j) + do n = 1, self % filters + db_local(n) = sum(gdz(n, :)) end do - - !--------------------------------------------------------------------- - ! Accumulate the computed gradients into the layer's stored gradients. - self % dw = self % dw + dw - self % db = self % db + db - + + !--- Initialize weight gradient and input gradient accumulators. + dw_local = 0.0 + self % gradient = 0.0 + + !--- Accumulate gradients over each output position. + ! In the forward pass the window for output index j was: + ! iws = j, iwe = j + kernel_size - 1. + do n = 1, self % filters + do j = 1, output_width + iws = j + iwe = j + self % kernel_size - 1 + do k = 1, self % channels + ! Weight gradient: accumulate contribution from the input window. + dw_local(n, k, :) = dw_local(n, k, :) + input(k, iws:iwe) * gdz(n, j) + ! Input gradient: propagate gradient back to the input window. + self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, k, :) * gdz(n, j) + end do + end do + end do + + !--- Update stored gradients. + self % dw = self % dw + dw_local + self % db = self % db + db_local + end subroutine backward pure module function get_num_params(self) result(num_params) @@ -197,11 +174,10 @@ module subroutine set_params(self, params) real, intent(in) :: params(:) if (size(params) /= self % get_num_params()) then - error stop 'locally_connected_1d % set_params: Number of parameters does not match' + error stop 'locally_connected_1d_layer % set_params: Number of parameters does not match' end if self % kernel = reshape(params(:product(shape(self % kernel))), shape(self % kernel)) - associate(n => product(shape(self % kernel))) self % biases = params(n + 1 : n + self % filters) end associate diff --git a/test/test_reshape2d_layer.f90 b/test/test_reshape2d_layer.f90 index a77ce23d..52817eac 100644 --- a/test/test_reshape2d_layer.f90 +++ b/test/test_reshape2d_layer.f90 @@ -1,54 +1,55 @@ program test_reshape2d_layer - use iso_fortran_env, only: stderr => error_unit - use nf, only: input, network, reshape2d_layer => reshape2d - use nf_datasets, only: download_and_unpack, keras_reshape_url - - implicit none - - type(network) :: net - real, allocatable :: sample_input(:), output(:,:) - integer, parameter :: output_shape(2) = [32, 32] - integer, parameter :: input_size = product(output_shape) - character(*), parameter :: keras_reshape_path = 'keras_reshape.h5' - logical :: file_exists - logical :: ok = .true. - - ! Create the network - net = network([ & - input(input_size), & - reshape2d_layer(output_shape) & - ]) - - if (.not. 
size(net % layers) == 2) then - write(stderr, '(a)') 'the network should have 2 layers.. failed' - ok = .false. - end if - - ! Initialize test data - allocate(sample_input(input_size)) - call random_number(sample_input) - - ! Propagate forward and get the output - call net % forward(sample_input) - call net % layers(2) % get_output(output) - - if (.not. all(shape(output) == output_shape)) then - write(stderr, '(a)') 'the reshape layer produces expected output shape.. failed' - ok = .false. - end if - - if (.not. all(reshape(sample_input, output_shape) == output)) then - write(stderr, '(a)') 'the reshape layer produces expected output values.. failed' - ok = .false. - end if - - if (ok) then - print '(a)', 'test_reshape2d_layer: All tests passed.' - else - write(stderr, '(a)') 'test_reshape2d_layer: One or more tests failed.' - stop 1 - end if - - end program test_reshape2d_layer - \ No newline at end of file + use iso_fortran_env, only: stderr => error_unit + use nf, only: input, network, reshape2d_layer => reshape2d + use nf_datasets, only: download_and_unpack, keras_reshape_url + + implicit none + + type(network) :: net + real, allocatable :: sample_input(:), output(:,:) + integer, parameter :: output_shape(2) = [4,4] + integer, parameter :: input_size = product(output_shape) + character(*), parameter :: keras_reshape_path = 'keras_reshape.h5' + logical :: file_exists + logical :: ok = .true. + + ! Create the network + net = network([ & + input(input_size), & + reshape2d_layer(output_shape) & + ]) + + if (.not. size(net % layers) == 2) then + write(stderr, '(a)') 'the network should have 2 layers.. failed' + ok = .false. + end if + + ! Initialize test data + allocate(sample_input(input_size)) + call random_number(sample_input) + + ! Propagate forward and get the output + call net % forward(sample_input) + call net % layers(2) % get_output(output) + + ! Check shape of the output + if (.not. all(shape(output) == output_shape)) then + write(stderr, '(a)') 'the reshape layer produces expected output shape.. failed' + ok = .false. + end if + + ! Check if reshaped input matches output + if (.not. all(reshape(sample_input, output_shape) == output)) then + write(stderr, '(a)') 'the reshape layer produces expected output values.. failed' + ok = .false. + end if + + if (ok) then + print '(a)', 'test_reshape2d_layer: All tests passed.' + else + write(stderr, '(a)') 'test_reshape2d_layer: One or more tests failed.' + stop 1 + end if + +end program test_reshape2d_layer From 6842f3a57db18f2e40370ca5aec05efe591e742f Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Sat, 22 Feb 2025 10:45:28 +0100 Subject: [PATCH 04/26] Saving changes before rebasing --- example/cnn_mnist.f90 | 4 ++-- src/nf/nf_network_submodule.f90 | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/example/cnn_mnist.f90 b/example/cnn_mnist.f90 index bec50b80..bf918c8b 100644 --- a/example/cnn_mnist.f90 +++ b/example/cnn_mnist.f90 @@ -35,9 +35,9 @@ program cnn_mnist call net % train( & training_images, & label_digits(training_labels), & - batch_size=128, & + batch_size=16, & epochs=1, & - optimizer=sgd(learning_rate=3.) 
& + optimizer=sgd(learning_rate=0.003) & ) print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( & diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index fb6142b7..0a5e0efa 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -82,9 +82,9 @@ module function network_from_layers(layers) result(res) type is(reshape3d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 - !type is(maxpool1d_layer) - ! res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] - ! n = n + 1 + type is(maxpool1d_layer) + res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] + n = n + 1 !type is(reshape2d_layer) ! res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] ! n = n + 1 From b942637126ed22df144f7c81ab693a28bb0a381f Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Sat, 22 Feb 2025 10:56:40 +0100 Subject: [PATCH 05/26] Resolved merge conflicts --- src/nf/nf_layer_submodule.f90 | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 6f7e3a85..40cc6320 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -367,6 +367,16 @@ pure module subroutine get_output_2d(self, output) type is(input2d_layer) allocate(output, source=this_layer % output) + type is(maxpool1d_layer) + allocate(output, source=this_layer % output) + type is(locally_connected_1d_layer) + allocate(output, source=this_layer % output) + type is(reshape2d_layer) + allocate(output, source=this_layer % output) + type is(linear2d_layer) + allocate(output, source=this_layer % output) + type is(self_attention_layer) + allocate(output, source=this_layer % output) class default error stop '2-d output can only be read from an input2d or linear2d layer.' @@ -557,6 +567,8 @@ module function get_gradients(self) result(gradients) ! No gradients to get. type is (flatten_layer) ! No gradients to get. + type is (reshape2d_layer) + ! No parameters to get. type is (reshape3d_layer) ! No gradients to get. 
type is (linear2d_layer) From 5d62b13129f17477a61d71ae731caa4bebe66683 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Sat, 22 Feb 2025 16:54:24 +0100 Subject: [PATCH 06/26] Bug fixed; Added conv1d; Conv1d and maxpool backward still not working --- CMakeLists.txt | 2 + example/cnn_mnist_1d.f90 | 4 +- src/nf.f90 | 12 +- src/nf/nf_conv1d_layer.f90 | 119 +++++++++++++ src/nf/nf_conv1d_layer_submodule.f90 | 187 +++++++++++++++++++++ src/nf/nf_layer_constructors.f90 | 29 ++++ src/nf/nf_layer_constructors_submodule.f90 | 26 +++ src/nf/nf_layer_submodule.f90 | 69 +++++++- src/nf/nf_maxpool1d_layer_submodule.f90 | 78 ++++----- src/nf/nf_network_submodule.f90 | 2 +- test/CMakeLists.txt | 3 + test/test_conv1d_layer.f90 | 85 ++++++++++ test/test_conv1d_network.f90 | 153 +++++++++++++++++ test/test_locally_connected_1d_layer.f90 | 85 ++++++++++ test/test_maxpool1d_layer.f90 | 183 ++++++++++---------- 15 files changed, 889 insertions(+), 148 deletions(-) create mode 100644 src/nf/nf_conv1d_layer.f90 create mode 100644 src/nf/nf_conv1d_layer_submodule.f90 create mode 100644 test/test_conv1d_layer.f90 create mode 100644 test/test_conv1d_network.f90 create mode 100644 test/test_locally_connected_1d_layer.f90 diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c39eb07..3925c9fe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,6 +18,8 @@ add_library(neural-fortran src/nf.f90 src/nf/nf_activation.f90 src/nf/nf_base_layer.f90 + src/nf/nf_conv1d_layer.f90 + src/nf/nf_conv1d_layer_submodule.f90 src/nf/nf_conv2d_layer.f90 src/nf/nf_conv2d_layer_submodule.f90 src/nf/nf_cross_attention_layer.f90 diff --git a/example/cnn_mnist_1d.f90 b/example/cnn_mnist_1d.f90 index f8b50ae5..4157510a 100644 --- a/example/cnn_mnist_1d.f90 +++ b/example/cnn_mnist_1d.f90 @@ -1,4 +1,4 @@ -program cnn_mnist +program cnn_mnist_1d use nf, only: network, sgd, & input, conv2d, maxpool1d, maxpool2d, flatten, dense, reshape, reshape2d, locally_connected_1d, & @@ -63,5 +63,5 @@ real function accuracy(net, x, y) accuracy = real(good) / size(x, dim=2) end function accuracy - end program cnn_mnist + end program cnn_mnist_1d \ No newline at end of file diff --git a/src/nf.f90 b/src/nf.f90 index 3e0190d1..d7ad010a 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,16 +3,8 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - conv2d, dense, flatten, input, maxpool1d, maxpool2d, reshape, reshape2d, locally_connected_1d - conv2d, & - dense, & - dropout, & - flatten, & - input, & - linear2d, & - maxpool2d, & - reshape, & - self_attention + conv1d, conv2d, dense, dropout, flatten, input, linear2d, locally_connected_1d, & + maxpool1d, maxpool2d, reshape, reshape2d, self_attention use nf_loss, only: mse, quadratic use nf_metrics, only: corr, maxabs use nf_network, only: network diff --git a/src/nf/nf_conv1d_layer.f90 b/src/nf/nf_conv1d_layer.f90 new file mode 100644 index 00000000..794d441e --- /dev/null +++ b/src/nf/nf_conv1d_layer.f90 @@ -0,0 +1,119 @@ +module nf_conv1d_layer + !! This modules provides a 1-d convolutional `conv1d` type. + + use nf_activation, only: activation_function + use nf_base_layer, only: base_layer + implicit none + + private + public :: conv1d_layer + + type, extends(base_layer) :: conv1d_layer + + integer :: width + integer :: height + integer :: channels + integer :: kernel_size + integer :: filters + + real, allocatable :: biases(:) ! size(filters) + real, allocatable :: kernel(:,:,:) ! 
filters x channels x window x window + real, allocatable :: output(:,:) ! filters x output_width * output_height + real, allocatable :: z(:,:) ! kernel .dot. input + bias + + real, allocatable :: dw(:,:,:) ! weight (kernel) gradients + real, allocatable :: db(:) ! bias gradients + real, allocatable :: gradient(:,:) + + class(activation_function), allocatable :: activation + + contains + + procedure :: forward + procedure :: backward + procedure :: get_gradients + procedure :: get_num_params + procedure :: get_params + procedure :: init + procedure :: set_params + + end type conv1d_layer + + interface conv1d_layer + module function conv1d_layer_cons(filters, kernel_size, activation) & + result(res) + !! `conv1d_layer` constructor function + integer, intent(in) :: filters + integer, intent(in) :: kernel_size + class(activation_function), intent(in) :: activation + type(conv1d_layer) :: res + end function conv1d_layer_cons + end interface conv1d_layer + + interface + + module subroutine init(self, input_shape) + !! Initialize the layer data structures. + !! + !! This is a deferred procedure from the `base_layer` abstract type. + class(conv1d_layer), intent(in out) :: self + !! A `conv1d_layer` instance + integer, intent(in) :: input_shape(:) + !! Input layer dimensions + end subroutine init + + pure module subroutine forward(self, input) + !! Apply a forward pass on the `conv1d` layer. + class(conv1d_layer), intent(in out) :: self + !! A `conv1d_layer` instance + real, intent(in) :: input(:,:) + !! Input data + end subroutine forward + + pure module subroutine backward(self, input, gradient) + !! Apply a backward pass on the `conv1d` layer. + class(conv1d_layer), intent(in out) :: self + !! A `conv1d_layer` instance + real, intent(in) :: input(:,:) + !! Input data (previous layer) + real, intent(in) :: gradient(:,:) + !! Gradient (next layer) + end subroutine backward + + pure module function get_num_params(self) result(num_params) + !! Get the number of parameters in the layer. + class(conv1d_layer), intent(in) :: self + !! A `conv1d_layer` instance + integer :: num_params + !! Number of parameters + end function get_num_params + + module function get_params(self) result(params) + !! Return the parameters (weights and biases) of this layer. + !! The parameters are ordered as weights first, biases second. + class(conv1d_layer), intent(in), target :: self + !! A `conv1d_layer` instance + real, allocatable :: params(:) + !! Parameters to get + end function get_params + + module function get_gradients(self) result(gradients) + !! Return the gradients of this layer. + !! The gradients are ordered as weights first, biases second. + class(conv1d_layer), intent(in), target :: self + !! A `conv1d_layer` instance + real, allocatable :: gradients(:) + !! Gradients to get + end function get_gradients + + module subroutine set_params(self, params) + !! Set the parameters of the layer. + class(conv1d_layer), intent(in out) :: self + !! A `conv1d_layer` instance + real, intent(in) :: params(:) + !! 
Parameters to set + end subroutine set_params + + end interface + +end module nf_conv1d_layer diff --git a/src/nf/nf_conv1d_layer_submodule.f90 b/src/nf/nf_conv1d_layer_submodule.f90 new file mode 100644 index 00000000..48ec7901 --- /dev/null +++ b/src/nf/nf_conv1d_layer_submodule.f90 @@ -0,0 +1,187 @@ +submodule(nf_conv1d_layer) nf_conv1d_layer_submodule + + use nf_activation, only: activation_function + use nf_random, only: random_normal + + implicit none + +contains + + module function conv1d_layer_cons(filters, kernel_size, activation) result(res) + implicit none + integer, intent(in) :: filters + integer, intent(in) :: kernel_size + class(activation_function), intent(in) :: activation + type(conv1d_layer) :: res + + res % kernel_size = kernel_size + res % filters = filters + res % activation_name = activation % get_name() + allocate( res % activation, source = activation ) + end function conv1d_layer_cons + + module subroutine init(self, input_shape) + implicit none + class(conv1d_layer), intent(in out) :: self + integer, intent(in) :: input_shape(:) + + self % channels = input_shape(1) + self % width = input_shape(2) - self % kernel_size + 1 + + ! Output of shape: filters x width + allocate(self % output(self % filters, self % width)) + self % output = 0 + + ! Kernel of shape: filters x channels x kernel_size + allocate(self % kernel(self % filters, self % channels, self % kernel_size)) + call random_normal(self % kernel) + self % kernel = self % kernel / real(self % kernel_size**2) + + allocate(self % biases(self % filters)) + self % biases = 0 + + allocate(self % z, mold=self % output) + self % z = 0 + + allocate(self % gradient(input_shape(1), input_shape(2))) + self % gradient = 0 + + allocate(self % dw, mold=self % kernel) + self % dw = 0 + + allocate(self % db, mold=self % biases) + self % db = 0 + + end subroutine init + + pure module subroutine forward(self, input) + implicit none + class(conv1d_layer), intent(in out) :: self + real, intent(in) :: input(:,:) + integer :: input_channels, input_width + integer :: j, n + integer :: iws, iwe, half_window + + input_channels = size(input, dim=1) + input_width = size(input, dim=2) + half_window = self % kernel_size / 2 + + ! Loop over output positions. + do j = 1, self % width + ! Compute the input window corresponding to output index j. + ! In forward: center index = j + half_window, so window = indices j to j+kernel_size-1. + iws = j + iwe = j + self % kernel_size - 1 + + ! For each filter, compute the convolution (inner product over channels and kernel width). + do concurrent (n = 1:self % filters) + self % z(n, j) = sum(self % kernel(n, :, :) * input(:, iws:iwe)) + end do + + ! Add the bias for each filter. + self % z(:, j) = self % z(:, j) + self % biases + end do + + ! Apply the activation function. + self % output = self % activation % eval(self % z) + end subroutine forward + + pure module subroutine backward(self, input, gradient) + implicit none + class(conv1d_layer), intent(in out) :: self + ! 'input' has shape: (channels, input_width) + ! 'gradient' (dL/dy) has shape: (filters, output_width) + real, intent(in) :: input(:,:) + real, intent(in) :: gradient(:,:) + + integer :: input_channels, input_width, output_width + integer :: j, n, k + integer :: iws, iwe, half_window + real :: gdz_val + + ! Local arrays to accumulate gradients. + real :: gdz(self % filters, self % width) ! local gradient (dL/dz) + real :: db_local(self % filters) + real :: dw_local(self % filters, self % channels, self % kernel_size) + + ! 
Determine dimensions. + input_channels = size(input, dim=1) + input_width = size(input, dim=2) + output_width = self % width ! Note: output_width = input_width - kernel_size + 1 + + half_window = self % kernel_size / 2 + + !--- Compute the local gradient gdz = (dL/dy) * sigma'(z) for each output. + do j = 1, output_width + gdz(:, j) = gradient(:, j) * self % activation % eval_prime(self % z(:, j)) + end do + + !--- Compute bias gradients: db(n) = sum_j gdz(n, j) + do n = 1, self % filters + db_local(n) = sum(gdz(n, :)) + end do + + !--- Initialize weight gradient and input gradient accumulators. + dw_local = 0.0 + self % gradient = 0.0 + + !--- Accumulate gradients over each output position. + ! In the forward pass the window for output index j was: + ! iws = j, iwe = j + kernel_size - 1. + do n = 1, self % filters + do j = 1, output_width + iws = j + iwe = j + self % kernel_size - 1 + do k = 1, self % channels + ! Weight gradient: accumulate contribution from the input window. + dw_local(n, k, :) = dw_local(n, k, :) + input(k, iws:iwe) * gdz(n, j) + ! Input gradient: propagate gradient back to the input window. + self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, k, :) * gdz(n, j) + end do + end do + end do + + !--- Update stored gradients. + self % dw = self % dw + dw_local + self % db = self % db + db_local + + end subroutine backward + + pure module function get_num_params(self) result(num_params) + class(conv1d_layer), intent(in) :: self + integer :: num_params + num_params = product(shape(self % kernel)) + size(self % biases) + end function get_num_params + + module function get_params(self) result(params) + class(conv1d_layer), intent(in), target :: self + real, allocatable :: params(:) + real, pointer :: w_(:) => null() + w_(1:size(self % kernel)) => self % kernel + params = [ w_, self % biases ] + end function get_params + + module function get_gradients(self) result(gradients) + class(conv1d_layer), intent(in), target :: self + real, allocatable :: gradients(:) + real, pointer :: dw_(:) => null() + dw_(1:size(self % dw)) => self % dw + gradients = [ dw_, self % db ] + end function get_gradients + + module subroutine set_params(self, params) + class(conv1d_layer), intent(in out) :: self + real, intent(in) :: params(:) + + if (size(params) /= self % get_num_params()) then + error stop 'conv1d_layer % set_params: Number of parameters does not match' + end if + + self % kernel = reshape(params(:product(shape(self % kernel))), shape(self % kernel)) + associate(n => product(shape(self % kernel))) + self % biases = params(n + 1 : n + self % filters) + end associate + + end subroutine set_params + +end submodule nf_conv1d_layer_submodule diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index a6499984..3dc13c50 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -9,6 +9,7 @@ module nf_layer_constructors private public :: & + conv1d, & conv2d, & dense, & dropout, & @@ -152,6 +153,34 @@ module function flatten() result(res) !! Resulting layer instance end function flatten + module function conv1d(filters, kernel_size, activation) result(res) + !! CHANGE THE COMMENTS + !! 2-d convolutional layer constructor. + !! + !! This layer is for building 2-d convolutional network. + !! Although the established convention is to call these layers 2-d, + !! the shape of the data is actuall 3-d: image width, image height, + !! and the number of channels. + !! 
A conv2d layer must not be the first layer in the network. + !! + !! Example: + !! + !! ``` + !! use nf, only :: conv2d, layer + !! type(layer) :: conv2d_layer + !! conv2d_layer = dense(filters=32, kernel_size=3) + !! conv2d_layer = dense(filters=32, kernel_size=3, activation='relu') + !! ``` + integer, intent(in) :: filters + !! Number of filters in the output of the layer + integer, intent(in) :: kernel_size + !! Width of the convolution window, commonly 3 or 5 + class(activation_function), intent(in), optional :: activation + !! Activation function (default sigmoid) + type(layer) :: res + !! Resulting layer instance + end function conv1d + module function conv2d(filters, kernel_size, activation) result(res) !! 2-d convolutional layer constructor. !! diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index 7f6f33a8..0c653988 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -1,6 +1,7 @@ submodule(nf_layer_constructors) nf_layer_constructors_submodule use nf_layer, only: layer + use nf_conv1d_layer, only: conv1d_layer use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer use nf_dropout_layer, only: dropout_layer @@ -21,6 +22,31 @@ contains + module function conv1d(filters, kernel_size, activation) result(res) + integer, intent(in) :: filters + integer, intent(in) :: kernel_size + class(activation_function), intent(in), optional :: activation + type(layer) :: res + + class(activation_function), allocatable :: activation_tmp + + res % name = 'conv1d' + + if (present(activation)) then + allocate(activation_tmp, source=activation) + else + allocate(activation_tmp, source=relu()) + end if + + res % activation = activation_tmp % get_name() + + allocate( & + res % p, & + source=conv1d_layer(filters, kernel_size, activation_tmp) & + ) + + end function conv1d + module function conv2d(filters, kernel_size, activation) result(res) integer, intent(in) :: filters integer, intent(in) :: kernel_size diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index 49f5a050..dbf825a8 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -1,6 +1,7 @@ submodule(nf_layer) nf_layer_submodule use iso_fortran_env, only: stderr => error_unit + use nf_conv1d_layer, only: conv1d_layer use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer use nf_dropout_layer, only: dropout_layer @@ -49,12 +50,18 @@ pure module subroutine backward_1d(self, previous, gradient) type is(flatten_layer) - ! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d + ! 
Upstream layers permitted: input2d, input3d, conv2d, locally_connected_1d, maxpool1d, maxpool2d select type(prev_layer => previous % p) type is(input2d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(locally_connected_1d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(maxpool1d_layer) + call this_layer % backward(prev_layer % output, gradient) type is(input3d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(conv1d_layer) + call this_layer % backward(prev_layer % output, gradient) type is(conv2d_layer) call this_layer % backward(prev_layer % output, gradient) type is(maxpool2d_layer) @@ -107,6 +114,19 @@ pure module subroutine backward_2d(self, previous, gradient) select type(this_layer => self % p) + type is(conv1d_layer) + + select type(prev_layer => previous % p) + type is(maxpool1d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(reshape2d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(locally_connected_1d_layer) + call this_layer % backward(prev_layer % output, gradient) + type is(conv1d_layer) + call this_layer % backward(prev_layer % output, gradient) + end select + type is(locally_connected_1d_layer) select type(prev_layer => previous % p) @@ -116,6 +136,8 @@ pure module subroutine backward_2d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(locally_connected_1d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(conv1d_layer) + call this_layer % backward(prev_layer % output, gradient) end select type is(maxpool1d_layer) @@ -127,6 +149,8 @@ pure module subroutine backward_2d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(locally_connected_1d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(conv1d_layer) + call this_layer % backward(prev_layer % output, gradient) end select type is(reshape2d_layer) @@ -254,6 +278,24 @@ module subroutine forward(self, input) call this_layer % forward(prev_layer % output) type is(reshape2d_layer) call this_layer % forward(prev_layer % output) + type is(conv1d_layer) + call this_layer % forward(prev_layer % output) + end select + + type is(conv1d_layer) + + ! Upstream layers permitted: input2d, locally_connected_1d, maxpool1d, reshape2d + select type(prev_layer => input % p) + type is(input2d_layer) + call this_layer % forward(prev_layer % output) + type is(locally_connected_1d_layer) + call this_layer % forward(prev_layer % output) + type is(maxpool1d_layer) + call this_layer % forward(prev_layer % output) + type is(reshape2d_layer) + call this_layer % forward(prev_layer % output) + type is(conv1d_layer) + call this_layer % forward(prev_layer % output) end select type is(maxpool1d_layer) @@ -268,6 +310,8 @@ module subroutine forward(self, input) call this_layer % forward(prev_layer % output) type is(reshape2d_layer) call this_layer % forward(prev_layer % output) + type is(conv1d_layer) + call this_layer % forward(prev_layer % output) end select type is(maxpool2d_layer) @@ -286,16 +330,24 @@ module subroutine forward(self, input) type is(flatten_layer) - ! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d, reshape3d + ! 
Upstream layers permitted: input2d, input3d, conv2d, maxpool1d, maxpool2d, reshape2d, reshape3d, locally_connected_2d select type(prev_layer => input % p) type is(input2d_layer) call this_layer % forward(prev_layer % output) type is(input3d_layer) call this_layer % forward(prev_layer % output) + type is(conv1d_layer) + call this_layer % forward(prev_layer % output) type is(conv2d_layer) call this_layer % forward(prev_layer % output) + type is(locally_connected_1d_layer) + call this_layer % forward(prev_layer % output) + type is(maxpool1d_layer) + call this_layer % forward(prev_layer % output) type is(maxpool2d_layer) call this_layer % forward(prev_layer % output) + type is(reshape2d_layer) + call this_layer % forward(prev_layer % output) type is(reshape3d_layer) call this_layer % forward(prev_layer % output) type is(linear2d_layer) @@ -383,6 +435,8 @@ pure module subroutine get_output_2d(self, output) allocate(output, source=this_layer % output) type is(locally_connected_1d_layer) allocate(output, source=this_layer % output) + type is(conv1d_layer) + allocate(output, source=this_layer % output) type is(reshape2d_layer) allocate(output, source=this_layer % output) type is(linear2d_layer) @@ -435,6 +489,8 @@ impure elemental module subroutine init(self, input) ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or ! self_attention layers is not known until we receive an input layer. select type(this_layer => self % p) + type is(conv1d_layer) + self % layer_shape = shape(this_layer % output) type is(conv2d_layer) self % layer_shape = shape(this_layer % output) type is(dropout_layer) @@ -495,6 +551,8 @@ elemental module function get_num_params(self) result(num_params) num_params = this_layer % get_num_params() type is (dropout_layer) num_params = 0 + type is (conv1d_layer) + num_params = this_layer % get_num_params() type is (conv2d_layer) num_params = this_layer % get_num_params() type is (locally_connected_1d_layer) @@ -534,6 +592,8 @@ module function get_params(self) result(params) params = this_layer % get_params() type is (dropout_layer) ! No parameters to get. + type is (conv1d_layer) + params = this_layer % get_params() type is (conv2d_layer) params = this_layer % get_params() type is (locally_connected_1d_layer) @@ -573,6 +633,8 @@ module function get_gradients(self) result(gradients) gradients = this_layer % get_gradients() type is (dropout_layer) ! No gradients to get. + type is (conv1d_layer) + gradients = this_layer % get_gradients() type is (conv2d_layer) gradients = this_layer % get_gradients() type is (locally_connected_1d_layer) @@ -639,6 +701,9 @@ module subroutine set_params(self, params) ! No parameters to set. write(stderr, '(a)') 'Warning: calling set_params() ' & // 'on a zero-parameter layer; nothing to do.' 
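+      ! The cases below expect `params` ordered as returned by get_params:
+      ! the flattened kernel/weights first, followed by the biases.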
+ + type is (conv1d_layer) + call this_layer % set_params(params) type is (conv2d_layer) call this_layer % set_params(params) diff --git a/src/nf/nf_maxpool1d_layer_submodule.f90 b/src/nf/nf_maxpool1d_layer_submodule.f90 index 9a0b081d..336264f7 100644 --- a/src/nf/nf_maxpool1d_layer_submodule.f90 +++ b/src/nf/nf_maxpool1d_layer_submodule.f90 @@ -1,4 +1,5 @@ submodule(nf_maxpool1d_layer) nf_maxpool1d_layer_submodule + implicit none contains @@ -8,9 +9,8 @@ pure module function maxpool1d_layer_cons(pool_size, stride) result(res) integer, intent(in) :: pool_size integer, intent(in) :: stride type(maxpool1d_layer) :: res - res % pool_size = pool_size - res % stride = stride + res % stride = stride end function maxpool1d_layer_cons @@ -18,25 +18,20 @@ module subroutine init(self, input_shape) implicit none class(maxpool1d_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) - ! input_shape is expected to be (channels, width) self % channels = input_shape(1) - ! The new width is the integer division of the input width by the stride. - self % width = input_shape(2) / self % stride + self % width = input_shape(2) / self % stride - ! Allocate storage for the index of the maximum element within each pooling region. allocate(self % maxloc(self % channels, self % width)) self % maxloc = 0 - ! Allocate the gradient array corresponding to the input dimensions. - allocate(self % gradient(input_shape(1), input_shape(2))) + allocate(self % gradient(input_shape(1),input_shape(2))) self % gradient = 0 - ! Allocate the output array (after pooling). allocate(self % output(self % channels, self % width)) self % output = 0 - end subroutine init + end subroutine init pure module subroutine forward(self, input) implicit none @@ -44,50 +39,55 @@ pure module subroutine forward(self, input) real, intent(in) :: input(:,:) integer :: input_width integer :: i, n - integer :: ii, iend + integer :: ii + integer :: iend integer :: iextent - integer :: max_index ! Temporary variable to hold the local index of the max - integer :: maxloc_temp(1) ! Temporary array to hold the result of maxloc + integer :: maxloc_x input_width = size(input, dim=2) - ! Ensure we only process complete pooling regions. + iextent = input_width - mod(input_width, self % stride) - ! Loop over the input with a step size equal to the stride and over all channels. - do concurrent (i = 1:iextent: self % stride, n = 1:self % channels) - ! Compute the index in the pooled (output) array. - ii = (i - 1) / self % stride + 1 - ! Determine the ending index of the current pooling region. - iend = min(i + self % pool_size - 1, input_width) - - ! Find the index (within the pooling window) of the maximum value. - maxloc_temp = maxloc(input(n, i:iend)) - max_index = maxloc_temp(1) + i - 1 ! Adjust to the index in the original input - - ! Store the location of the maximum value. - self % maxloc(n, ii) = max_index - ! Set the output as the maximum value from this pooling region. - self % output(n, ii) = input(n, max_index) - end do + ! Stride along the width of the input + stride_over_input: do concurrent(i = 1:iextent:self % stride) + + ! Index of the pooling layer + ii = i / self % stride + 1 + iend = i + self % pool_size - 1 + + maxpool_for_each_channel: do concurrent(n = 1:self % channels) + + ! 
Get and store the location of the maximum value + maxloc_x = maxloc(input(n, i:iend), dim=1) + self % maxloc(n,ii) = maxloc_x + i - 1 + + self % output(n,ii) = input(n, self % maxloc(n,ii)) + + end do maxpool_for_each_channel + + end do stride_over_input + end subroutine forward - pure module subroutine backward(self, input, gradient) implicit none class(maxpool1d_layer), intent(in out) :: self real, intent(in) :: input(:,:) real, intent(in) :: gradient(:,:) - integer :: channels, pooled_width + integer :: gradient_shape(2) + integer :: channels, width integer :: i, n - channels = size(gradient, dim=1) - pooled_width = size(gradient, dim=2) + gradient_shape = shape(gradient) + channels = gradient_shape(1) + width = gradient_shape(2) - ! The gradient for max-pooling is nonzero only at the input locations - ! that were the maxima during the forward pass. - do concurrent (n = 1:channels, i = 1:pooled_width) - self % gradient(n, self % maxloc(n, i)) = gradient(n, i) + ! The gradient of a max-pooling layer is assigned to the stored max locations + do concurrent(n = 1:channels, i = 1:width) + self % gradient(n, self % maxloc(n,i)) = gradient(n,i) end do + end subroutine backward -end submodule nf_maxpool1d_layer_submodule \ No newline at end of file + +end submodule nf_maxpool1d_layer_submodule diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 8e6c53a9..7eeb08c0 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -15,7 +15,7 @@ use nf_linear2d_layer, only: linear2d_layer use nf_self_attention_layer, only: self_attention_layer use nf_layer, only: layer - use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool1d, maxpool2d, reshape, reshape2d + use nf_layer_constructors, only: conv1d, conv2d, dense, flatten, input, maxpool1d, maxpool2d, reshape, reshape2d use nf_loss, only: quadratic use nf_optimizers, only: optimizer_base_type, sgd use nf_parallel, only: tile_indices diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4720628b..6f8e45fc 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -6,7 +6,10 @@ foreach(execid linear2d_layer parametric_activation dense_layer + conv1d_layer + conv1d_network conv2d_layer + locally_connected_1d_layer maxpool1d_layer maxpool2d_layer flatten_layer diff --git a/test/test_conv1d_layer.f90 b/test/test_conv1d_layer.f90 new file mode 100644 index 00000000..a5ec87a7 --- /dev/null +++ b/test/test_conv1d_layer.f90 @@ -0,0 +1,85 @@ +program test_conv1d_layer + + use iso_fortran_env, only: stderr => error_unit + use nf, only: conv1d, input, layer + use nf_input2d_layer, only: input2d_layer + + implicit none + + type(layer) :: conv1d_layer, input_layer + integer, parameter :: filters = 32, kernel_size=3 + real, allocatable :: sample_input(:,:), output(:,:) + real, parameter :: tolerance = 1e-7 + logical :: ok = .true. + + conv1d_layer = conv1d(filters, kernel_size) + + if (.not. conv1d_layer % name == 'conv1d') then + ok = .false. + write(stderr, '(a)') 'conv1d layer has its name set correctly.. failed' + end if + + if (conv1d_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'conv1d layer should not be marked as initialized yet.. failed' + end if + + if (.not. conv1d_layer % activation == 'relu') then + ok = .false. + write(stderr, '(a)') 'conv1d layer defaults to relu activation.. failed' + end if + + input_layer = input(3, 32) + call conv1d_layer % init(input_layer) + + if (.not. conv1d_layer % initialized) then + ok = .false. 
+ write(stderr, '(a)') 'conv1d layer should now be marked as initialized.. failed' + end if + + if (.not. all(conv1d_layer % input_layer_shape == [3, 32])) then + ok = .false. + write(stderr, '(a)') 'conv1d layer input layer shape should be correct.. failed' + end if + + if (.not. all(conv1d_layer % layer_shape == [filters, 30])) then + ok = .false. + write(stderr, '(a)') 'conv1d layer input layer shape should be correct.. failed' + end if + + ! Minimal conv1d layer: 1 channel, 3x3 pixel image; + allocate(sample_input(1, 3)) + sample_input = 0 + + ! Print the sample input array + print *, "Sample Input:" + print *, sample_input + + input_layer = input(1, 3) + conv1d_layer = conv1d(filters, kernel_size) + call conv1d_layer % init(input_layer) + + select type(this_layer => input_layer % p); type is(input2d_layer) + call this_layer % set(sample_input) + end select + + call conv1d_layer % forward(input_layer) + call conv1d_layer % get_output(output) + + ! Print the output array after the forward pass + print *, "Output:" + print *, output + + if (.not. all(abs(output) < tolerance)) then + ok = .false. + write(stderr, '(a)') 'conv1d layer with zero input and sigmoid function must forward to all 0.5.. failed' + end if + + if (ok) then + print '(a)', 'test_conv1d_layer: All tests passed.' + else + write(stderr, '(a)') 'test_conv1d_layer: One or more tests failed.' + stop 1 + end if + +end program test_conv1d_layer diff --git a/test/test_conv1d_network.f90 b/test/test_conv1d_network.f90 new file mode 100644 index 00000000..577a765a --- /dev/null +++ b/test/test_conv1d_network.f90 @@ -0,0 +1,153 @@ +program test_conv1d_network + + use iso_fortran_env, only: stderr => error_unit + use nf, only: conv1d, input, network, dense, sgd, maxpool1d + + implicit none + + type(network) :: net + real, allocatable :: sample_input(:,:), output(:,:) + logical :: ok = .true. + + ! 2-layer convolutional network + net = network([ & + input(3, 32), & + conv1d(filters=16, kernel_size=3), & + conv1d(filters=32, kernel_size=3) & + ]) + + if (.not. size(net % layers) == 3) then + write(stderr, '(a)') 'conv1d network should have 3 layers.. failed' + ok = .false. + end if + + ! Test for output shape + allocate(sample_input(3, 32)) + sample_input = 0 + + call net % forward(sample_input) + call net % layers(2) % get_output(output) + + if (.not. all(shape(output) == [32, 28])) then + write(stderr, '(a)') 'conv1d network output should have correct shape.. failed' + ok = .false. + end if + + deallocate(sample_input, output) + + training1: block + + type(network) :: cnn + real :: y(1) + real :: tolerance = 1e-4 + integer :: n + integer, parameter :: num_iterations = 1000 + + ! Test training of a minimal constant mapping + allocate(sample_input(1, 5)) + call random_number(sample_input) + + cnn = network([ & + input(1, 5), & + conv1d(filters=1, kernel_size=3), & + conv1d(filters=1, kernel_size=3), & + dense(1) & + ]) + + y = [0.1234567] + + do n = 1, num_iterations + call cnn % forward(sample_input) + call cnn % backward(y) + call cnn % update(optimizer=sgd(learning_rate=1.)) + if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit + end do + + if (.not. n <= num_iterations) then + write(stderr, '(a)') & + 'convolutional network 1 should converge in simple training.. failed' + ok = .false. 
+ end if + + end block training1 + + training2: block + + type(network) :: cnn + real :: x(1, 8) + real :: y(1) + real :: tolerance = 1e-4 + integer :: n + integer, parameter :: num_iterations = 1000 + + call random_number(x) + y = [0.1234567] + + cnn = network([ & + input(1, 8), & + conv1d(filters=1, kernel_size=3), & + maxpool1d(pool_size=2), & + conv1d(filters=1, kernel_size=3), & + dense(1) & + ]) + + do n = 1, num_iterations + call cnn % forward(x) + call cnn % backward(y) + call cnn % update(optimizer=sgd(learning_rate=1.)) + if (all(abs(cnn % predict(x) - y) < tolerance)) exit + end do + + if (.not. n <= num_iterations) then + write(stderr, '(a)') & + 'convolutional network 2 should converge in simple training.. failed' + ok = .false. + end if + + end block training2 + + training3: block + + type(network) :: cnn + real :: x(1, 12) + real :: y(9) + real :: tolerance = 1e-4 + integer :: n + integer, parameter :: num_iterations = 5000 + + call random_number(x) + y = [0.12345, 0.23456, 0.34567, 0.45678, 0.56789, 0.67890, 0.78901, 0.89012, 0.90123] + + cnn = network([ & + input(1, 12), & + conv1d(filters=1, kernel_size=3), & ! 1x12 input, 1x10 output + maxpool1d(pool_size=2), & ! 1x10 input, 1x5 output + conv1d(filters=1, kernel_size=3), & ! 1x5 input, 1x3 output + dense(9) & ! 9 outputs + ]) + + do n = 1, num_iterations + call cnn % forward(x) + call cnn % backward(y) + call cnn % update(optimizer=sgd(learning_rate=1.)) + if (all(abs(cnn % predict(x) - y) < tolerance)) exit + end do + + if (.not. n <= num_iterations) then + write(stderr, '(a)') & + 'convolutional network 3 should converge in simple training.. failed' + ok = .false. + end if + + end block training3 + + + if (ok) then + print '(a)', 'test_conv1d_network: All tests passed.' + else + write(stderr, '(a)') 'test_conv1d_network: One or more tests failed.' + stop 1 + end if + + end program test_conv1d_network + \ No newline at end of file diff --git a/test/test_locally_connected_1d_layer.f90 b/test/test_locally_connected_1d_layer.f90 new file mode 100644 index 00000000..5ce48b43 --- /dev/null +++ b/test/test_locally_connected_1d_layer.f90 @@ -0,0 +1,85 @@ +program test_locally_connected_1d_layer + + use iso_fortran_env, only: stderr => error_unit + use nf, only: locally_connected_1d, input, layer + use nf_input2d_layer, only: input2d_layer + + implicit none + + type(layer) :: locally_connected_1d_layer, input_layer + integer, parameter :: filters = 32, kernel_size=3 + real, allocatable :: sample_input(:,:), output(:,:) + real, parameter :: tolerance = 1e-7 + logical :: ok = .true. + + locally_connected_1d_layer = locally_connected_1d(filters, kernel_size) + + if (.not. locally_connected_1d_layer % name == 'locally_connected_1d') then + ok = .false. + write(stderr, '(a)') 'locally_connected_1d layer has its name set correctly.. failed' + end if + + if (locally_connected_1d_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'locally_connected_1d layer should not be marked as initialized yet.. failed' + end if + + if (.not. locally_connected_1d_layer % activation == 'relu') then + ok = .false. + write(stderr, '(a)') 'locally_connected_1d layer defaults to relu activation.. failed' + end if + + input_layer = input(3, 32) + call locally_connected_1d_layer % init(input_layer) + + if (.not. locally_connected_1d_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'locally_connected_1d layer should now be marked as initialized.. failed' + end if + + if (.not. 
all(locally_connected_1d_layer % input_layer_shape == [3, 32])) then + ok = .false. + write(stderr, '(a)') 'locally_connected_1d layer input layer shape should be correct.. failed' + end if + + if (.not. all(locally_connected_1d_layer % layer_shape == [filters, 30])) then + ok = .false. + write(stderr, '(a)') 'locally_connected_1d layer input layer shape should be correct.. failed' + end if + + ! Minimal locally_connected_1d layer: 1 channel, 3x3 pixel image; + allocate(sample_input(1, 3)) + sample_input = 0 + + ! Print the sample input array + print *, "Sample Input:" + print *, sample_input + + input_layer = input(1, 3) + locally_connected_1d_layer = locally_connected_1d(filters, kernel_size) + call locally_connected_1d_layer % init(input_layer) + + select type(this_layer => input_layer % p); type is(input2d_layer) + call this_layer % set(sample_input) + end select + + call locally_connected_1d_layer % forward(input_layer) + call locally_connected_1d_layer % get_output(output) + + ! Print the output array after the forward pass + print *, "Output:" + print *, output + + if (.not. all(abs(output) < tolerance)) then + ok = .false. + write(stderr, '(a)') 'locally_connected_1d layer with zero input and sigmoid function must forward to all 0.5.. failed' + end if + + if (ok) then + print '(a)', 'test_locally_connected_1d_layer: All tests passed.' + else + write(stderr, '(a)') 'test_locally_connected_1d_layer: One or more tests failed.' + stop 1 + end if + +end program test_locally_connected_1d_layer diff --git a/test/test_maxpool1d_layer.f90 b/test/test_maxpool1d_layer.f90 index a5691272..023a2c33 100644 --- a/test/test_maxpool1d_layer.f90 +++ b/test/test_maxpool1d_layer.f90 @@ -1,100 +1,95 @@ program test_maxpool1d_layer - use iso_fortran_env, only: stderr => error_unit - use nf, only: maxpool1d, input, layer - use nf_input2d_layer, only: input2d_layer - use nf_maxpool1d_layer, only: maxpool1d_layer - - implicit none - - type(layer) :: maxpool_layer, input_layer - integer, parameter :: pool_size = 2, stride = 2 - integer, parameter :: channels = 3, length = 32 - integer, parameter :: input_shape(2) = [channels, length] - integer, parameter :: output_shape(2) = [channels, length / 2] - real, allocatable :: sample_input(:,:), output(:,:), gradient(:,:) - integer :: i - logical :: ok = .true., gradient_ok = .true. - - maxpool_layer = maxpool1d(pool_size) - - if (.not. maxpool_layer % name == 'maxpool1d') then - ok = .false. - write(stderr, '(a)') 'maxpool1d layer has its name set correctly.. failed' - end if - - if (maxpool_layer % initialized) then - ok = .false. - write(stderr, '(a)') 'maxpool1d layer should not be marked as initialized yet.. failed' - end if - - input_layer = input(channels, length) - call maxpool_layer % init(input_layer) - - if (.not. maxpool_layer % initialized) then - ok = .false. - write(stderr, '(a)') 'maxpool1d layer should now be marked as initialized.. failed' - end if - - if (.not. all(maxpool_layer % input_layer_shape == input_shape)) then - ok = .false. - write(stderr, '(a)') 'maxpool1d layer input layer shape should be correct.. failed' - end if - - if (.not. 
all(maxpool_layer % layer_shape == output_shape)) then + use iso_fortran_env, only: stderr => error_unit + use nf, only: maxpool1d, input, layer + use nf_input2d_layer, only: input2d_layer + use nf_maxpool1d_layer, only: maxpool1d_layer + + implicit none + + type(layer) :: maxpool_layer, input_layer + integer, parameter :: pool_size = 2, stride = 2 + integer, parameter :: channels = 3, length = 32 + integer, parameter :: input_shape(2) = [channels, length] + integer, parameter :: output_shape(2) = [channels, length / 2] + real, allocatable :: sample_input(:,:), output(:,:), gradient(:,:) + integer :: i + logical :: ok = .true., gradient_ok = .true. + + maxpool_layer = maxpool1d(pool_size) + + if (.not. maxpool_layer % name == 'maxpool1d') then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer has its name set correctly.. failed' + end if + + if (maxpool_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer should not be marked as initialized yet.. failed' + end if + + input_layer = input(channels, length) + call maxpool_layer % init(input_layer) + + if (.not. maxpool_layer % initialized) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer should now be marked as initialized.. failed' + end if + + if (.not. all(maxpool_layer % input_layer_shape == input_shape)) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer input layer shape should be correct.. failed' + end if + + if (.not. all(maxpool_layer % layer_shape == output_shape)) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer output layer shape should be correct.. failed' + end if + + ! Allocate and initialize sample input data + allocate(sample_input(channels, length)) + do concurrent(i = 1:length) + sample_input(:,i) = i + end do + + select type(this_layer => input_layer % p); type is(input2d_layer) + call this_layer % set(sample_input) + end select + + call maxpool_layer % forward(input_layer) + call maxpool_layer % get_output(output) + + do i = 1, length / 2 + if (.not. all(output(:,i) == stride * i)) then ok = .false. - write(stderr, '(a)') 'maxpool1d layer output layer shape should be correct.. failed' + write(stderr, '(a)') 'maxpool1d layer forward pass correctly propagates the max value.. failed' end if - - ! Allocate and initialize sample input data - allocate(sample_input(channels, length)) - do concurrent(i = 1:length) - sample_input(:,i) = i - end do - - select type(this_layer => input_layer % p); type is(input2d_layer) - call this_layer % set(sample_input) - end select - - call maxpool_layer % forward(input_layer) - call maxpool_layer % get_output(output) - - do i = 1, length / 2 - ! Since input is i, maxpool1d output must be stride*i - if (.not. all(output(:,i) == stride * i)) then - ok = .false. - write(stderr, '(a)') 'maxpool1d layer forward pass correctly propagates the max value.. failed' + end do + + ! Test the backward pass + allocate(gradient, source=output) + call maxpool_layer % backward(input_layer, gradient) + + select type(this_layer => maxpool_layer % p); type is(maxpool1d_layer) + do i = 1, length + if (mod(i,2) == 0) then + if (.not. all(sample_input(:,i) == this_layer % gradient(:,i))) gradient_ok = .false. + else + if (.not. all(this_layer % gradient(:,i) == 0)) gradient_ok = .false. end if end do - - ! Test the backward pass - ! Allocate and initialize the downstream gradient field - allocate(gradient, source=output) - - ! 
Make a backward pass - call maxpool_layer % backward(input_layer, gradient) - - select type(this_layer => maxpool_layer % p); type is(maxpool1d_layer) - do i = 1, length - if (mod(i,2) == 0) then - if (.not. all(sample_input(:,i) == this_layer % gradient(:,i))) gradient_ok = .false. - else - if (.not. all(this_layer % gradient(:,i) == 0)) gradient_ok = .false. - end if - end do - end select - - if (.not. gradient_ok) then - ok = .false. - write(stderr, '(a)') 'maxpool1d layer backward pass produces the correct dL/dx.. failed' - end if - - if (ok) then - print '(a)', 'test_maxpool1d_layer: All tests passed.' - else - write(stderr, '(a)') 'test_maxpool1d_layer: One or more tests failed.' - stop 1 - end if - - end program test_maxpool1d_layer - \ No newline at end of file + end select + + if (.not. gradient_ok) then + ok = .false. + write(stderr, '(a)') 'maxpool1d layer backward pass produces the correct dL/dx.. failed' + end if + + if (ok) then + print '(a)', 'test_maxpool1d_layer: All tests passed.' + else + write(stderr, '(a)') 'test_maxpool1d_layer: One or more tests failed.' + stop 1 + end if + +end program test_maxpool1d_layer From a08fba024f86e1f3967f7236e701fa4bc109c14e Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Sun, 23 Feb 2025 14:08:32 +0100 Subject: [PATCH 07/26] Bug fixes; now everything works --- example/cnn_mnist_1d.f90 | 10 +++++----- src/nf/nf_layer_submodule.f90 | 6 ++++++ src/nf/nf_network_submodule.f90 | 14 +++++++++++--- test/test_conv1d_layer.f90 | 8 -------- test/test_conv1d_network.f90 | 5 +++-- test/test_locally_connected_1d_layer.f90 | 7 ------- 6 files changed, 25 insertions(+), 25 deletions(-) diff --git a/example/cnn_mnist_1d.f90 b/example/cnn_mnist_1d.f90 index 4157510a..dba2ab2d 100644 --- a/example/cnn_mnist_1d.f90 +++ b/example/cnn_mnist_1d.f90 @@ -1,7 +1,7 @@ program cnn_mnist_1d use nf, only: network, sgd, & - input, conv2d, maxpool1d, maxpool2d, flatten, dense, reshape, reshape2d, locally_connected_1d, & + input, conv1d, conv2d, maxpool1d, maxpool2d, flatten, dense, reshape, reshape2d, locally_connected_1d, & load_mnist, label_digits, softmax, relu implicit none @@ -12,7 +12,7 @@ program cnn_mnist_1d real, allocatable :: validation_images(:,:), validation_labels(:) real, allocatable :: testing_images(:,:), testing_labels(:) integer :: n - integer, parameter :: num_epochs = 10 + integer, parameter :: num_epochs = 25 call load_mnist(training_images, training_labels, & validation_images, validation_labels, & @@ -21,9 +21,9 @@ program cnn_mnist_1d net = network([ & input(784), & reshape2d([28,28]), & - locally_connected_1d(filters=8, kernel_size=3, activation=relu()), & + conv1d(filters=8, kernel_size=3, activation=relu()), & maxpool1d(pool_size=2), & - locally_connected_1d(filters=16, kernel_size=3, activation=relu()), & + conv1d(filters=16, kernel_size=3, activation=relu()), & maxpool1d(pool_size=2), & dense(10, activation=softmax()) & ]) @@ -37,7 +37,7 @@ program cnn_mnist_1d label_digits(training_labels), & batch_size=16, & epochs=1, & - optimizer=sgd(learning_rate=0.003) & + optimizer=sgd(learning_rate=0.005) & ) print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( & diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index dbf825a8..ab7c1c9e 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -121,6 +121,8 @@ pure module subroutine backward_2d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(reshape2d_layer) call 
this_layer % backward(prev_layer % output, gradient) + type is(input2d_layer) + call this_layer % backward(prev_layer % output, gradient) type is(locally_connected_1d_layer) call this_layer % backward(prev_layer % output, gradient) type is(conv1d_layer) @@ -134,6 +136,8 @@ pure module subroutine backward_2d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(reshape2d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(input2d_layer) + call this_layer % backward(prev_layer % output, gradient) type is(locally_connected_1d_layer) call this_layer % backward(prev_layer % output, gradient) type is(conv1d_layer) @@ -149,6 +153,8 @@ pure module subroutine backward_2d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(locally_connected_1d_layer) call this_layer % backward(prev_layer % output, gradient) + type is(input2d_layer) + call this_layer % backward(prev_layer % output, gradient) type is(conv1d_layer) call this_layer % backward(prev_layer % output, gradient) end select diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 7eeb08c0..3837217c 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -1,5 +1,6 @@ submodule(nf_network) nf_network_submodule + use nf_conv1d_layer, only: conv1d_layer use nf_conv2d_layer, only: conv2d_layer use nf_dense_layer, only: dense_layer use nf_dropout_layer, only: dropout_layer @@ -76,9 +77,9 @@ module function network_from_layers(layers) result(res) type is(conv2d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 - !type is(locally_connected_1d_layer) - !res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] - !n = n + 1 + type is(locally_connected_1d_layer) + res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] + n = n + 1 type is(maxpool2d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 @@ -88,6 +89,9 @@ module function network_from_layers(layers) result(res) type is(maxpool1d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 + type is(conv1d_layer) + res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] + n = n + 1 !type is(reshape2d_layer) ! res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] ! n = n + 1 @@ -179,6 +183,10 @@ module subroutine backward(self, output, loss) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(reshape2d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(conv1d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(locally_connected_1d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) end select end if diff --git a/test/test_conv1d_layer.f90 b/test/test_conv1d_layer.f90 index a5ec87a7..81d03c1f 100644 --- a/test/test_conv1d_layer.f90 +++ b/test/test_conv1d_layer.f90 @@ -51,10 +51,6 @@ program test_conv1d_layer allocate(sample_input(1, 3)) sample_input = 0 - ! Print the sample input array - print *, "Sample Input:" - print *, sample_input - input_layer = input(1, 3) conv1d_layer = conv1d(filters, kernel_size) call conv1d_layer % init(input_layer) @@ -66,10 +62,6 @@ program test_conv1d_layer call conv1d_layer % forward(input_layer) call conv1d_layer % get_output(output) - ! Print the output array after the forward pass - print *, "Output:" - print *, output - if (.not. 
all(abs(output) < tolerance)) then ok = .false. write(stderr, '(a)') 'conv1d layer with zero input and sigmoid function must forward to all 0.5.. failed' diff --git a/test/test_conv1d_network.f90 b/test/test_conv1d_network.f90 index 577a765a..df0d52d0 100644 --- a/test/test_conv1d_network.f90 +++ b/test/test_conv1d_network.f90 @@ -9,7 +9,7 @@ program test_conv1d_network real, allocatable :: sample_input(:,:), output(:,:) logical :: ok = .true. - ! 2-layer convolutional network + ! 3-layer convolutional network net = network([ & input(3, 32), & conv1d(filters=16, kernel_size=3), & @@ -26,7 +26,7 @@ program test_conv1d_network sample_input = 0 call net % forward(sample_input) - call net % layers(2) % get_output(output) + call net % layers(3) % get_output(output) if (.not. all(shape(output) == [32, 28])) then write(stderr, '(a)') 'conv1d network output should have correct shape.. failed' @@ -64,6 +64,7 @@ program test_conv1d_network end do if (.not. n <= num_iterations) then + write(stderr, '(a)') & 'convolutional network 1 should converge in simple training.. failed' ok = .false. diff --git a/test/test_locally_connected_1d_layer.f90 b/test/test_locally_connected_1d_layer.f90 index 5ce48b43..50489128 100644 --- a/test/test_locally_connected_1d_layer.f90 +++ b/test/test_locally_connected_1d_layer.f90 @@ -51,10 +51,6 @@ program test_locally_connected_1d_layer allocate(sample_input(1, 3)) sample_input = 0 - ! Print the sample input array - print *, "Sample Input:" - print *, sample_input - input_layer = input(1, 3) locally_connected_1d_layer = locally_connected_1d(filters, kernel_size) call locally_connected_1d_layer % init(input_layer) @@ -66,9 +62,6 @@ program test_locally_connected_1d_layer call locally_connected_1d_layer % forward(input_layer) call locally_connected_1d_layer % get_output(output) - ! Print the output array after the forward pass - print *, "Output:" - print *, output if (.not. all(abs(output) < tolerance)) then ok = .false. From 52f958f6ee4870fdaa1431aab1471659160ea3c3 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Sun, 23 Feb 2025 14:20:02 +0100 Subject: [PATCH 08/26] Updated the comments --- src/nf/nf_conv1d_layer.f90 | 4 ++-- src/nf/nf_layer_constructors.f90 | 19 +++++++++---------- src/nf/nf_layer_submodule.f90 | 4 ++-- test/test_conv1d_network.f90 | 4 ++-- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/nf/nf_conv1d_layer.f90 b/src/nf/nf_conv1d_layer.f90 index 794d441e..c39b11fc 100644 --- a/src/nf/nf_conv1d_layer.f90 +++ b/src/nf/nf_conv1d_layer.f90 @@ -17,8 +17,8 @@ module nf_conv1d_layer integer :: filters real, allocatable :: biases(:) ! size(filters) - real, allocatable :: kernel(:,:,:) ! filters x channels x window x window - real, allocatable :: output(:,:) ! filters x output_width * output_height + real, allocatable :: kernel(:,:,:) ! filters x channels x window + real, allocatable :: output(:,:) ! filters x output_width real, allocatable :: z(:,:) ! kernel .dot. input + bias real, allocatable :: dw(:,:,:) ! weight (kernel) gradients diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index 3dc13c50..acd35d8c 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -154,22 +154,21 @@ module function flatten() result(res) end function flatten module function conv1d(filters, kernel_size, activation) result(res) - !! CHANGE THE COMMENTS - !! 2-d convolutional layer constructor. + !! 1-d convolutional layer constructor. !! - !! 
This layer is for building 2-d convolutional network. - !! Although the established convention is to call these layers 2-d, - !! the shape of the data is actuall 3-d: image width, image height, + !! This layer is for building 1-d convolutional network. + !! Although the established convention is to call these layers 1-d, + !! the shape of the data is actually 2-d: image width + !! and the number of channels. - !! A conv2d layer must not be the first layer in the network. + !! A conv1d layer must not be the first layer in the network. !! !! Example: !! !! ``` - !! use nf, only :: conv2d, layer - !! type(layer) :: conv2d_layer - !! conv2d_layer = dense(filters=32, kernel_size=3) - !! conv2d_layer = dense(filters=32, kernel_size=3, activation='relu') + !! use nf, only :: conv1d, layer + !! type(layer) :: conv1d_layer + !! conv1d_layer = conv1d(filters=32, kernel_size=3) + !! conv1d_layer = conv1d(filters=32, kernel_size=3, activation='relu') !! ``` integer, intent(in) :: filters !! Number of filters in the output of the layer diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index ab7c1c9e..23f487e5 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -27,7 +27,7 @@ pure module subroutine backward_1d(self, previous, gradient) real, intent(in) :: gradient(:) ! Backward pass from a 1-d layer downstream currently implemented - ! only for dense and flatten layers + ! only for dense, dropout and flatten layers select type(this_layer => self % p) type is(dense_layer) @@ -50,7 +50,7 @@ pure module subroutine backward_1d(self, previous, gradient) type is(flatten_layer) - ! Upstream layers permitted: input2d, input3d, conv2d, locally_connected_1d, maxpool1d, maxpool2d + ! Upstream layers permitted: input2d, input3d, conv1d, conv2d, locally_connected_1d, maxpool1d, maxpool2d select type(prev_layer => previous % p) type is(input2d_layer) call this_layer % backward(prev_layer % output, gradient) diff --git a/test/test_conv1d_network.f90 b/test/test_conv1d_network.f90 index df0d52d0..cefcf327 100644 --- a/test/test_conv1d_network.f90 +++ b/test/test_conv1d_network.f90 @@ -41,7 +41,7 @@ program test_conv1d_network real :: y(1) real :: tolerance = 1e-4 integer :: n - integer, parameter :: num_iterations = 1000 + integer, parameter :: num_iterations = 1500 ! 
Test training of a minimal constant mapping allocate(sample_input(1, 5)) @@ -79,7 +79,7 @@ program test_conv1d_network real :: y(1) real :: tolerance = 1e-4 integer :: n - integer, parameter :: num_iterations = 1000 + integer, parameter :: num_iterations = 1500 call random_number(x) y = [0.1234567] From b64038a4bcf0259196261a72280df3c1552a8174 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Sun, 23 Feb 2025 17:43:38 +0100 Subject: [PATCH 09/26] Implemented locally connected 1d --- example/cnn_mnist_1d.f90 | 4 +- src/nf/nf_locally_connected_1d.f90 | 8 +-- src/nf/nf_locally_connected_1d_submodule.f90 | 73 +++++--------------- 3 files changed, 25 insertions(+), 60 deletions(-) diff --git a/example/cnn_mnist_1d.f90 b/example/cnn_mnist_1d.f90 index dba2ab2d..3a80a81a 100644 --- a/example/cnn_mnist_1d.f90 +++ b/example/cnn_mnist_1d.f90 @@ -21,9 +21,9 @@ program cnn_mnist_1d net = network([ & input(784), & reshape2d([28,28]), & - conv1d(filters=8, kernel_size=3, activation=relu()), & + locally_connected_1d(filters=8, kernel_size=3, activation=relu()), & maxpool1d(pool_size=2), & - conv1d(filters=16, kernel_size=3, activation=relu()), & + locally_connected_1d(filters=16, kernel_size=3, activation=relu()), & maxpool1d(pool_size=2), & dense(10, activation=softmax()) & ]) diff --git a/src/nf/nf_locally_connected_1d.f90 b/src/nf/nf_locally_connected_1d.f90 index 739df749..1dc3b4a1 100644 --- a/src/nf/nf_locally_connected_1d.f90 +++ b/src/nf/nf_locally_connected_1d.f90 @@ -16,13 +16,13 @@ module nf_locally_connected_1d_layer integer :: kernel_size integer :: filters - real, allocatable :: biases(:) ! size(filters) - real, allocatable :: kernel(:,:,:) ! filters x channels x window x window + real, allocatable :: biases(:,:) ! filters x output width + real, allocatable :: kernel(:,:,:,:) ! filters x output width x channels x window real, allocatable :: output(:,:) ! filters x output_width * output_height real, allocatable :: z(:,:) ! kernel .dot. input + bias - real, allocatable :: dw(:,:,:) ! weight (kernel) gradients - real, allocatable :: db(:) ! bias gradients + real, allocatable :: dw(:,:,:,:) ! weight (kernel) gradients + real, allocatable :: db(:,:) ! bias gradients real, allocatable :: gradient(:,:) class(activation_function), allocatable :: activation diff --git a/src/nf/nf_locally_connected_1d_submodule.f90 b/src/nf/nf_locally_connected_1d_submodule.f90 index 2a65b7f6..e3903b54 100644 --- a/src/nf/nf_locally_connected_1d_submodule.f90 +++ b/src/nf/nf_locally_connected_1d_submodule.f90 @@ -17,7 +17,7 @@ module function locally_connected_1d_layer_cons(filters, kernel_size, activation res % kernel_size = kernel_size res % filters = filters res % activation_name = activation % get_name() - allocate( res % activation, source = activation ) + allocate(res % activation, source = activation) end function locally_connected_1d_layer_cons module subroutine init(self, input_shape) @@ -28,16 +28,14 @@ module subroutine init(self, input_shape) self % channels = input_shape(1) self % width = input_shape(2) - self % kernel_size + 1 - ! Output of shape: filters x width allocate(self % output(self % filters, self % width)) self % output = 0 - ! 
Kernel of shape: filters x channels x kernel_size - allocate(self % kernel(self % filters, self % channels, self % kernel_size)) + allocate(self % kernel(self % filters, self % width, self % channels, self % kernel_size)) call random_normal(self % kernel) self % kernel = self % kernel / real(self % kernel_size**2) - allocate(self % biases(self % filters)) + allocate(self % biases(self % filters, self % width)) self % biases = 0 allocate(self % z, mold=self % output) @@ -51,7 +49,6 @@ module subroutine init(self, input_shape) allocate(self % db, mold=self % biases) self % db = 0 - end subroutine init pure module subroutine forward(self, input) @@ -60,113 +57,81 @@ pure module subroutine forward(self, input) real, intent(in) :: input(:,:) integer :: input_channels, input_width integer :: j, n - integer :: iws, iwe, half_window + integer :: iws, iwe input_channels = size(input, dim=1) input_width = size(input, dim=2) - half_window = self % kernel_size / 2 - ! Loop over output positions. do j = 1, self % width - ! Compute the input window corresponding to output index j. - ! In forward: center index = j + half_window, so window = indices j to j+kernel_size-1. iws = j iwe = j + self % kernel_size - 1 - - ! For each filter, compute the convolution (inner product over channels and kernel width). do concurrent (n = 1:self % filters) - self % z(n, j) = sum(self % kernel(n, :, :) * input(:, iws:iwe)) + self % z(n, j) = sum(self % kernel(n, j, :, :) * input(:, iws:iwe)) + self % biases(n, j) end do - - ! Add the bias for each filter. - self % z(:, j) = self % z(:, j) + self % biases end do - - ! Apply the activation function. self % output = self % activation % eval(self % z) end subroutine forward pure module subroutine backward(self, input, gradient) implicit none class(locally_connected_1d_layer), intent(in out) :: self - ! 'input' has shape: (channels, input_width) - ! 'gradient' (dL/dy) has shape: (filters, output_width) real, intent(in) :: input(:,:) real, intent(in) :: gradient(:,:) - integer :: input_channels, input_width, output_width integer :: j, n, k - integer :: iws, iwe, half_window - real :: gdz_val + integer :: iws, iwe + real :: gdz(self % filters, self % width) + real :: db_local(self % filters, self % width) + real :: dw_local(self % filters, self % width, self % channels, self % kernel_size) - ! Local arrays to accumulate gradients. - real :: gdz(self % filters, self % width) ! local gradient (dL/dz) - real :: db_local(self % filters) - real :: dw_local(self % filters, self % channels, self % kernel_size) - - ! Determine dimensions. input_channels = size(input, dim=1) input_width = size(input, dim=2) - output_width = self % width ! Note: output_width = input_width - kernel_size + 1 - - half_window = self % kernel_size / 2 + output_width = self % width - !--- Compute the local gradient gdz = (dL/dy) * sigma'(z) for each output. do j = 1, output_width gdz(:, j) = gradient(:, j) * self % activation % eval_prime(self % z(:, j)) end do - !--- Compute bias gradients: db(n) = sum_j gdz(n, j) do n = 1, self % filters - db_local(n) = sum(gdz(n, :)) + do j = 1, output_width + db_local(n, j) = gdz(n, j) + end do end do - !--- Initialize weight gradient and input gradient accumulators. dw_local = 0.0 self % gradient = 0.0 - !--- Accumulate gradients over each output position. - ! In the forward pass the window for output index j was: - ! iws = j, iwe = j + kernel_size - 1. do n = 1, self % filters do j = 1, output_width iws = j iwe = j + self % kernel_size - 1 do k = 1, self % channels - ! 
Weight gradient: accumulate contribution from the input window. - dw_local(n, k, :) = dw_local(n, k, :) + input(k, iws:iwe) * gdz(n, j) - ! Input gradient: propagate gradient back to the input window. - self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, k, :) * gdz(n, j) + dw_local(n, j, k, :) = dw_local(n, j, k, :) + input(k, iws:iwe) * gdz(n, j) + self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, j, k, :) * gdz(n, j) end do end do end do - !--- Update stored gradients. self % dw = self % dw + dw_local self % db = self % db + db_local - end subroutine backward pure module function get_num_params(self) result(num_params) class(locally_connected_1d_layer), intent(in) :: self integer :: num_params - num_params = product(shape(self % kernel)) + size(self % biases) + num_params = product(shape(self % kernel)) + product(shape(self % biases)) end function get_num_params module function get_params(self) result(params) class(locally_connected_1d_layer), intent(in), target :: self real, allocatable :: params(:) - real, pointer :: w_(:) => null() - w_(1:size(self % kernel)) => self % kernel - params = [ w_, self % biases ] + params = [reshape(self % kernel, [size(self % kernel)]), reshape(self % biases, [size(self % biases)])] end function get_params module function get_gradients(self) result(gradients) class(locally_connected_1d_layer), intent(in), target :: self real, allocatable :: gradients(:) - real, pointer :: dw_(:) => null() - dw_(1:size(self % dw)) => self % dw - gradients = [ dw_, self % db ] + gradients = [reshape(self % dw, [size(self % dw)]), reshape(self % db, [size(self % db)])] end function get_gradients module subroutine set_params(self, params) @@ -179,7 +144,7 @@ module subroutine set_params(self, params) self % kernel = reshape(params(:product(shape(self % kernel))), shape(self % kernel)) associate(n => product(shape(self % kernel))) - self % biases = params(n + 1 : n + self % filters) + self % biases = reshape(params(n + 1 :), shape(self % biases)) end associate end subroutine set_params From 9082db898f7ecc2df9f9950c4a695c9c3125d827 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Sun, 23 Feb 2025 18:22:06 +0100 Subject: [PATCH 10/26] Bug fix --- src/nf/nf_layer_constructors.f90 | 35 ++++++++++++++++---------------- src/nf/nf_network_submodule.f90 | 8 +++----- test/test_conv1d_network.f90 | 5 +++-- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index acd35d8c..5704ec81 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -208,22 +208,21 @@ module function conv2d(filters, kernel_size, activation) result(res) end function conv2d module function locally_connected_1d(filters, kernel_size, activation) result(res) - !! CHANGE THE COMMENTS!!! - !! 2-d convolutional layer constructor. + !! 1-d locally connected layer constructor. !! - !! This layer is for building 2-d convolutional network. - !! Although the established convention is to call these layers 2-d, - !! the shape of the data is actuall 3-d: image width, image height, + !! This layer is for building 1-d locally connected network. + !! Although the established convention is to call these layers 1-d, + !! the shape of the data is actually 2-d: image width, !! and the number of channels. - !! A conv2d layer must not be the first layer in the network. + !! A locally connected 1d layer must not be the first layer in the network. !! !! 
Example: !! !! ``` - !! use nf, only :: conv2d, layer - !! type(layer) :: conv2d_layer - !! conv2d_layer = dense(filters=32, kernel_size=3) - !! conv2d_layer = dense(filters=32, kernel_size=3, activation='relu') + !! use nf, only :: locally_connected_1d, layer + !! type(layer) :: locally_connected_1d_layer + !! locally_connected_1d_layer = locally_connected_1d(filters=32, kernel_size=3) + !! locally_connected_1d_layer = locally_connected_1d(filters=32, kernel_size=3, activation='relu') !! ``` integer, intent(in) :: filters !! Number of filters in the output of the layer @@ -236,17 +235,17 @@ module function locally_connected_1d(filters, kernel_size, activation) result(re end function locally_connected_1d module function maxpool1d(pool_size, stride) result(res) - !! 2-d maxpooling layer constructor. + !! 1-d maxpooling layer constructor. !! - !! This layer is for downscaling other layers, typically `conv2d`. + !! This layer is for downscaling other layers, typically `conv1d`. !! !! Example: !! !! ``` - !! use nf, only :: maxpool2d, layer - !! type(layer) :: maxpool2d_layer - !! maxpool2d_layer = maxpool2d(pool_size=2) - !! maxpool2d_layer = maxpool2d(pool_size=2, stride=3) + !! use nf, only :: maxpool1d, layer + !! type(layer) :: maxpool1d_layer + !! maxpool1d_layer = maxpool1d(pool_size=2) + !! maxpool1d_layer = maxpool1d(pool_size=2, stride=3) !! ``` integer, intent(in) :: pool_size !! Width of the pooling window, commonly 2 @@ -292,9 +291,9 @@ end function reshape module function reshape2d(output_shape) result(res) !! Rank-1 to rank-any reshape layer constructor. - !! Currently implemented is only rank-3 for the output of the reshape. + !! Currently implemented is only rank-2 for the output of the reshape. !! - !! This layer is for connecting 1-d inputs to conv2d or similar layers. + !! This layer is for connecting 1-d inputs to conv1d or similar layers. integer, intent(in) :: output_shape(:) !! Shape of the output type(layer) :: res diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 3837217c..2b085e8f 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -92,9 +92,9 @@ module function network_from_layers(layers) result(res) type is(conv1d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 - !type is(reshape2d_layer) - ! res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] - ! 
n = n + 1 + type is(reshape2d_layer) + res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] + n = n + 1 class default n = n + 1 end select @@ -163,7 +163,6 @@ module subroutine backward(self, output, loss) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(conv2d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(flatten_layer) if (size(self % layers(n) % layer_shape) == 2) then call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d) @@ -172,7 +171,6 @@ module subroutine backward(self, output, loss) end if type is(maxpool2d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(reshape3d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(linear2d_layer) diff --git a/test/test_conv1d_network.f90 b/test/test_conv1d_network.f90 index cefcf327..417e773f 100644 --- a/test/test_conv1d_network.f90 +++ b/test/test_conv1d_network.f90 @@ -41,7 +41,7 @@ program test_conv1d_network real :: y(1) real :: tolerance = 1e-4 integer :: n - integer, parameter :: num_iterations = 1500 + integer, parameter :: num_iterations = 1000 ! Test training of a minimal constant mapping allocate(sample_input(1, 5)) @@ -60,6 +60,7 @@ program test_conv1d_network call cnn % forward(sample_input) call cnn % backward(y) call cnn % update(optimizer=sgd(learning_rate=1.)) + print *, cnn % predict(sample_input), y if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit end do @@ -79,7 +80,7 @@ program test_conv1d_network real :: y(1) real :: tolerance = 1e-4 integer :: n - integer, parameter :: num_iterations = 1500 + integer, parameter :: num_iterations = 1000 call random_number(x) y = [0.1234567] From d1cffaeb99f2307d1acb88e9f33036f6d8fcbfd9 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Mon, 24 Feb 2025 17:05:26 +0100 Subject: [PATCH 11/26] Bug fix --- src/nf/nf_conv1d_layer_submodule.f90 | 2 +- test/CMakeLists.txt | 1 - test/test_conv1d_network.f90 | 155 --------------------------- 3 files changed, 1 insertion(+), 157 deletions(-) delete mode 100644 test/test_conv1d_network.f90 diff --git a/src/nf/nf_conv1d_layer_submodule.f90 b/src/nf/nf_conv1d_layer_submodule.f90 index 48ec7901..c82188d6 100644 --- a/src/nf/nf_conv1d_layer_submodule.f90 +++ b/src/nf/nf_conv1d_layer_submodule.f90 @@ -35,7 +35,7 @@ module subroutine init(self, input_shape) ! 
Kernel of shape: filters x channels x kernel_size allocate(self % kernel(self % filters, self % channels, self % kernel_size)) call random_normal(self % kernel) - self % kernel = self % kernel / real(self % kernel_size**2) + self % kernel = self % kernel / self % kernel_size**2 allocate(self % biases(self % filters)) self % biases = 0 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6f8e45fc..1aef81fc 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -7,7 +7,6 @@ foreach(execid parametric_activation dense_layer conv1d_layer - conv1d_network conv2d_layer locally_connected_1d_layer maxpool1d_layer diff --git a/test/test_conv1d_network.f90 b/test/test_conv1d_network.f90 deleted file mode 100644 index 417e773f..00000000 --- a/test/test_conv1d_network.f90 +++ /dev/null @@ -1,155 +0,0 @@ -program test_conv1d_network - - use iso_fortran_env, only: stderr => error_unit - use nf, only: conv1d, input, network, dense, sgd, maxpool1d - - implicit none - - type(network) :: net - real, allocatable :: sample_input(:,:), output(:,:) - logical :: ok = .true. - - ! 3-layer convolutional network - net = network([ & - input(3, 32), & - conv1d(filters=16, kernel_size=3), & - conv1d(filters=32, kernel_size=3) & - ]) - - if (.not. size(net % layers) == 3) then - write(stderr, '(a)') 'conv1d network should have 3 layers.. failed' - ok = .false. - end if - - ! Test for output shape - allocate(sample_input(3, 32)) - sample_input = 0 - - call net % forward(sample_input) - call net % layers(3) % get_output(output) - - if (.not. all(shape(output) == [32, 28])) then - write(stderr, '(a)') 'conv1d network output should have correct shape.. failed' - ok = .false. - end if - - deallocate(sample_input, output) - - training1: block - - type(network) :: cnn - real :: y(1) - real :: tolerance = 1e-4 - integer :: n - integer, parameter :: num_iterations = 1000 - - ! Test training of a minimal constant mapping - allocate(sample_input(1, 5)) - call random_number(sample_input) - - cnn = network([ & - input(1, 5), & - conv1d(filters=1, kernel_size=3), & - conv1d(filters=1, kernel_size=3), & - dense(1) & - ]) - - y = [0.1234567] - - do n = 1, num_iterations - call cnn % forward(sample_input) - call cnn % backward(y) - call cnn % update(optimizer=sgd(learning_rate=1.)) - print *, cnn % predict(sample_input), y - if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit - end do - - if (.not. n <= num_iterations) then - - write(stderr, '(a)') & - 'convolutional network 1 should converge in simple training.. failed' - ok = .false. - end if - - end block training1 - - training2: block - - type(network) :: cnn - real :: x(1, 8) - real :: y(1) - real :: tolerance = 1e-4 - integer :: n - integer, parameter :: num_iterations = 1000 - - call random_number(x) - y = [0.1234567] - - cnn = network([ & - input(1, 8), & - conv1d(filters=1, kernel_size=3), & - maxpool1d(pool_size=2), & - conv1d(filters=1, kernel_size=3), & - dense(1) & - ]) - - do n = 1, num_iterations - call cnn % forward(x) - call cnn % backward(y) - call cnn % update(optimizer=sgd(learning_rate=1.)) - if (all(abs(cnn % predict(x) - y) < tolerance)) exit - end do - - if (.not. n <= num_iterations) then - write(stderr, '(a)') & - 'convolutional network 2 should converge in simple training.. failed' - ok = .false. 
- end if - - end block training2 - - training3: block - - type(network) :: cnn - real :: x(1, 12) - real :: y(9) - real :: tolerance = 1e-4 - integer :: n - integer, parameter :: num_iterations = 5000 - - call random_number(x) - y = [0.12345, 0.23456, 0.34567, 0.45678, 0.56789, 0.67890, 0.78901, 0.89012, 0.90123] - - cnn = network([ & - input(1, 12), & - conv1d(filters=1, kernel_size=3), & ! 1x12 input, 1x10 output - maxpool1d(pool_size=2), & ! 1x10 input, 1x5 output - conv1d(filters=1, kernel_size=3), & ! 1x5 input, 1x3 output - dense(9) & ! 9 outputs - ]) - - do n = 1, num_iterations - call cnn % forward(x) - call cnn % backward(y) - call cnn % update(optimizer=sgd(learning_rate=1.)) - if (all(abs(cnn % predict(x) - y) < tolerance)) exit - end do - - if (.not. n <= num_iterations) then - write(stderr, '(a)') & - 'convolutional network 3 should converge in simple training.. failed' - ok = .false. - end if - - end block training3 - - - if (ok) then - print '(a)', 'test_conv1d_network: All tests passed.' - else - write(stderr, '(a)') 'test_conv1d_network: One or more tests failed.' - stop 1 - end if - - end program test_conv1d_network - \ No newline at end of file From a055b20a717b6a1733b149a840c35a14d534f065 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Tue, 25 Feb 2025 16:24:22 +0100 Subject: [PATCH 12/26] New bugs --- src/nf/nf_conv1d_layer_submodule.f90 | 8 +- src/nf/nf_conv2d_layer_submodule.f90 | 6 +- test/CMakeLists.txt | 1 + test/test_conv1d_network.f90 | 155 +++++++++++++++++++++++++++ test/test_conv2d_network.f90 | 4 +- 5 files changed, 166 insertions(+), 8 deletions(-) create mode 100644 test/test_conv1d_network.f90 diff --git a/src/nf/nf_conv1d_layer_submodule.f90 b/src/nf/nf_conv1d_layer_submodule.f90 index c82188d6..97508f57 100644 --- a/src/nf/nf_conv1d_layer_submodule.f90 +++ b/src/nf/nf_conv1d_layer_submodule.f90 @@ -35,7 +35,7 @@ module subroutine init(self, input_shape) ! Kernel of shape: filters x channels x kernel_size allocate(self % kernel(self % filters, self % channels, self % kernel_size)) call random_normal(self % kernel) - self % kernel = self % kernel / self % kernel_size**2 + self % kernel = self % kernel / self % kernel_size allocate(self % biases(self % filters)) self % biases = 0 @@ -124,7 +124,7 @@ pure module subroutine backward(self, input, gradient) !--- Initialize weight gradient and input gradient accumulators. dw_local = 0.0 self % gradient = 0.0 - + !--- Accumulate gradients over each output position. ! In the forward pass the window for output index j was: ! iws = j, iwe = j + kernel_size - 1. 
@@ -157,8 +157,8 @@ module function get_params(self) result(params) class(conv1d_layer), intent(in), target :: self real, allocatable :: params(:) real, pointer :: w_(:) => null() - w_(1:size(self % kernel)) => self % kernel - params = [ w_, self % biases ] + w_(1:size(self % z)) => self % z + params = [ w_] end function get_params module function get_gradients(self) result(gradients) diff --git a/src/nf/nf_conv2d_layer_submodule.f90 b/src/nf/nf_conv2d_layer_submodule.f90 index 45a2c1da..24a381f2 100644 --- a/src/nf/nf_conv2d_layer_submodule.f90 +++ b/src/nf/nf_conv2d_layer_submodule.f90 @@ -195,11 +195,11 @@ module function get_params(self) result(params) real, pointer :: w_(:) => null() - w_(1:size(self % kernel)) => self % kernel + w_(1:size(self % z)) => self % z params = [ & - w_, & - self % biases & + w_ & + !self % biases & ] end function get_params diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1aef81fc..d1f87e04 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,6 +18,7 @@ foreach(execid multihead_attention_layer dense_network get_set_network_params + conv1d_network conv2d_network optimizers loss diff --git a/test/test_conv1d_network.f90 b/test/test_conv1d_network.f90 new file mode 100644 index 00000000..0c3893ae --- /dev/null +++ b/test/test_conv1d_network.f90 @@ -0,0 +1,155 @@ +program test_conv1d_network + + use iso_fortran_env, only: stderr => error_unit + use nf, only: conv1d, input, network, dense, sgd, maxpool1d + + implicit none + + type(network) :: net + real, allocatable :: sample_input(:,:), output(:,:), o(:) + logical :: ok = .true. + + ! 3-layer convolutional network + net = network([ & + input(3, 32), & + conv1d(filters=16, kernel_size=3), & + conv1d(filters=32, kernel_size=3) & + ]) + + if (.not. size(net % layers) == 3) then + write(stderr, '(a)') 'conv2d network should have 3 layers.. failed' + ok = .false. + end if + + ! Test for output shape + allocate(sample_input(3, 32)) + sample_input = 0 + + call net % forward(sample_input) + call net % layers(3) % get_output(output) + + if (.not. all(shape(output) == [32, 28])) then + write(stderr, '(a)') 'conv1d network output should have correct shape.. failed' + ok = .false. + end if + + deallocate(sample_input, output) + + training1: block + + type(network) :: cnn + real :: y(1) + real :: tolerance = 1e-4 + integer :: n + integer, parameter :: num_iterations = 1000 + + ! Test training of a minimal constant mapping + allocate(sample_input(1, 5)) + call random_number(sample_input) + + cnn = network([ & + input(1, 5), & + conv1d(filters=1, kernel_size=3), & + conv1d(filters=1, kernel_size=3), & + dense(1) & + ]) + + y = [0.1234567] + + do n = 1, num_iterations + call cnn % forward(sample_input) + call cnn % backward(y) + call cnn % update(optimizer=sgd(learning_rate=1.)) + o = cnn % layers(2) % get_params() + print *, o + if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit + end do + + if (.not. n <= num_iterations) then + write(stderr, '(a)') & + 'convolutional network 1 should converge in simple training.. failed' + ok = .false. 
+ end if + + end block training1 + + training2: block + + type(network) :: cnn + real :: x(1, 8) + real :: y(1) + real :: tolerance = 1e-4 + integer :: n + integer, parameter :: num_iterations = 1000 + + call random_number(x) + y = [0.1234567] + + cnn = network([ & + input(1, 8), & + conv1d(filters=1, kernel_size=3), & + maxpool1d(pool_size=2), & + conv1d(filters=1, kernel_size=3), & + dense(1) & + ]) + + do n = 1, num_iterations + call cnn % forward(x) + call cnn % backward(y) + call cnn % update(optimizer=sgd(learning_rate=1.)) + if (all(abs(cnn % predict(x) - y) < tolerance)) exit + end do + + if (.not. n <= num_iterations) then + write(stderr, '(a)') & + 'convolutional network 2 should converge in simple training.. failed' + ok = .false. + end if + + end block training2 + + training3: block + + type(network) :: cnn + real :: x(1, 12) + real :: y(9) + real :: tolerance = 1e-4 + integer :: n + integer, parameter :: num_iterations = 5000 + + call random_number(x) + y = [0.12345, 0.23456, 0.34567, 0.45678, 0.56789, 0.67890, 0.78901, 0.89012, 0.90123] + + cnn = network([ & + input(1, 12), & + conv1d(filters=1, kernel_size=3), & ! 1x12x12 input, 1x10x10 output + maxpool1d(pool_size=2), & ! 1x10x10 input, 1x5x5 output + conv1d(filters=1, kernel_size=3), & ! 1x5x5 input, 1x3x3 output + dense(9) & ! 9 outputs + ]) + + do n = 1, num_iterations + call cnn % forward(x) + call cnn % backward(y) + call cnn % update(optimizer=sgd(learning_rate=1.)) + if (all(abs(cnn % predict(x) - y) < tolerance)) exit + end do + + if (.not. n <= num_iterations) then + write(stderr, '(a)') & + 'convolutional network 3 should converge in simple training.. failed' + ok = .false. + end if + + end block training3 + + + if (ok) then + print '(a)', 'test_conv1d_network: All tests passed.' + else + write(stderr, '(a)') 'test_conv1d_network: One or more tests failed.' + stop 1 + end if + + end program test_conv1d_network + \ No newline at end of file diff --git a/test/test_conv2d_network.f90 b/test/test_conv2d_network.f90 index 1bdfc677..a612539f 100644 --- a/test/test_conv2d_network.f90 +++ b/test/test_conv2d_network.f90 @@ -6,7 +6,7 @@ program test_conv2d_network implicit none type(network) :: net - real, allocatable :: sample_input(:,:,:), output(:,:,:) + real, allocatable :: sample_input(:,:,:), output(:,:,:), o(:) logical :: ok = .true. ! 
3-layer convolutional network @@ -60,6 +60,8 @@ program test_conv2d_network call cnn % forward(sample_input) call cnn % backward(y) call cnn % update(optimizer=sgd(learning_rate=1.)) + o = cnn % layers(2) % get_params() + print *, o if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit end do From c6b4d878b8dee85bd0f0286586b3d4e81f6bd6e7 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Tue, 25 Feb 2025 16:41:25 +0100 Subject: [PATCH 13/26] Bug fix --- src/nf/nf_conv1d_layer_submodule.f90 | 4 ++-- src/nf/nf_conv2d_layer_submodule.f90 | 6 +++--- test/test_conv2d_network.f90 | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/nf/nf_conv1d_layer_submodule.f90 b/src/nf/nf_conv1d_layer_submodule.f90 index 97508f57..ce64e557 100644 --- a/src/nf/nf_conv1d_layer_submodule.f90 +++ b/src/nf/nf_conv1d_layer_submodule.f90 @@ -157,8 +157,8 @@ module function get_params(self) result(params) class(conv1d_layer), intent(in), target :: self real, allocatable :: params(:) real, pointer :: w_(:) => null() - w_(1:size(self % z)) => self % z - params = [ w_] + w_(1:size(self % kernel)) => self % kernel + params = [ w_, self % biases] end function get_params module function get_gradients(self) result(gradients) diff --git a/src/nf/nf_conv2d_layer_submodule.f90 b/src/nf/nf_conv2d_layer_submodule.f90 index 24a381f2..45a2c1da 100644 --- a/src/nf/nf_conv2d_layer_submodule.f90 +++ b/src/nf/nf_conv2d_layer_submodule.f90 @@ -195,11 +195,11 @@ module function get_params(self) result(params) real, pointer :: w_(:) => null() - w_(1:size(self % z)) => self % z + w_(1:size(self % kernel)) => self % kernel params = [ & - w_ & - !self % biases & + w_, & + self % biases & ] end function get_params diff --git a/test/test_conv2d_network.f90 b/test/test_conv2d_network.f90 index a612539f..84c07a9f 100644 --- a/test/test_conv2d_network.f90 +++ b/test/test_conv2d_network.f90 @@ -60,7 +60,7 @@ program test_conv2d_network call cnn % forward(sample_input) call cnn % backward(y) call cnn % update(optimizer=sgd(learning_rate=1.)) - o = cnn % layers(2) % get_params() + o = cnn % layers(3) % get_params() print *, o if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit end do From 2e64151c77ed76d88b288edc2d691723f1acd8bd Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Tue, 25 Feb 2025 19:38:46 +0100 Subject: [PATCH 14/26] Definitive bug fixes --- example/cnn_mnist.f90 | 4 ++-- example/cnn_mnist_1d.f90 | 6 +++--- src/nf/nf_conv1d_layer_submodule.f90 | 4 ++-- src/nf/nf_datasets_mnist_submodule.f90 | 6 +++--- src/nf/nf_network.f90 | 1 + src/nf/nf_network_submodule.f90 | 14 ++++++++++++-- test/test_conv1d_network.f90 | 3 +-- 7 files changed, 24 insertions(+), 14 deletions(-) diff --git a/example/cnn_mnist.f90 b/example/cnn_mnist.f90 index bf918c8b..ef22f986 100644 --- a/example/cnn_mnist.f90 +++ b/example/cnn_mnist.f90 @@ -12,7 +12,7 @@ program cnn_mnist real, allocatable :: validation_images(:,:), validation_labels(:) real, allocatable :: testing_images(:,:), testing_labels(:) integer :: n - integer, parameter :: num_epochs = 10 + integer, parameter :: num_epochs = 250 call load_mnist(training_images, training_labels, & validation_images, validation_labels, & @@ -37,7 +37,7 @@ program cnn_mnist label_digits(training_labels), & batch_size=16, & epochs=1, & - optimizer=sgd(learning_rate=0.003) & + optimizer=sgd(learning_rate=0.001) & ) print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( & diff --git a/example/cnn_mnist_1d.f90 
b/example/cnn_mnist_1d.f90 index 3a80a81a..8368db7a 100644 --- a/example/cnn_mnist_1d.f90 +++ b/example/cnn_mnist_1d.f90 @@ -1,7 +1,7 @@ program cnn_mnist_1d use nf, only: network, sgd, & - input, conv1d, conv2d, maxpool1d, maxpool2d, flatten, dense, reshape, reshape2d, locally_connected_1d, & + input, conv1d, maxpool1d, flatten, dense, reshape, reshape2d, locally_connected_1d, & load_mnist, label_digits, softmax, relu implicit none @@ -12,7 +12,7 @@ program cnn_mnist_1d real, allocatable :: validation_images(:,:), validation_labels(:) real, allocatable :: testing_images(:,:), testing_labels(:) integer :: n - integer, parameter :: num_epochs = 25 + integer, parameter :: num_epochs = 250 call load_mnist(training_images, training_labels, & validation_images, validation_labels, & @@ -37,7 +37,7 @@ program cnn_mnist_1d label_digits(training_labels), & batch_size=16, & epochs=1, & - optimizer=sgd(learning_rate=0.005) & + optimizer=sgd(learning_rate=0.01) & ) print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( & diff --git a/src/nf/nf_conv1d_layer_submodule.f90 b/src/nf/nf_conv1d_layer_submodule.f90 index ce64e557..43cb690f 100644 --- a/src/nf/nf_conv1d_layer_submodule.f90 +++ b/src/nf/nf_conv1d_layer_submodule.f90 @@ -59,7 +59,7 @@ pure module subroutine forward(self, input) class(conv1d_layer), intent(in out) :: self real, intent(in) :: input(:,:) integer :: input_channels, input_width - integer :: j, n + integer :: j, n, a, b integer :: iws, iwe, half_window input_channels = size(input, dim=1) @@ -95,7 +95,7 @@ pure module subroutine backward(self, input, gradient) real, intent(in) :: gradient(:,:) integer :: input_channels, input_width, output_width - integer :: j, n, k + integer :: j, n, k, a, b, c integer :: iws, iwe, half_window real :: gdz_val diff --git a/src/nf/nf_datasets_mnist_submodule.f90 b/src/nf/nf_datasets_mnist_submodule.f90 index a0bed0a8..842cafe1 100644 --- a/src/nf/nf_datasets_mnist_submodule.f90 +++ b/src/nf/nf_datasets_mnist_submodule.f90 @@ -50,9 +50,9 @@ module subroutine load_mnist(training_images, training_labels, & real, allocatable, intent(in out), optional :: testing_labels(:) integer, parameter :: dtype = 4, image_size = 784 - integer, parameter :: num_training_images = 500 - integer, parameter :: num_validation_images = 100 - integer, parameter :: num_testing_images = 100 + integer, parameter :: num_training_images = 50000 + integer, parameter :: num_validation_images = 10000 + integer, parameter :: num_testing_images = 10000 logical :: file_exists ! Check if MNIST data is present and download it if not. diff --git a/src/nf/nf_network.f90 b/src/nf/nf_network.f90 index 5916924e..f7be1959 100644 --- a/src/nf/nf_network.f90 +++ b/src/nf/nf_network.f90 @@ -201,6 +201,7 @@ module integer function get_num_params(self) !! Network instance end function get_num_params + module function get_params(self) result(params) !! Get the network parameters (weights and biases). 
class(network), intent(in) :: self diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 2b085e8f..7a0a65b0 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -460,7 +460,6 @@ module function get_num_params(self) end function get_num_params - module function get_params(self) result(params) class(network), intent(in) :: self real, allocatable :: params(:) @@ -480,7 +479,6 @@ module function get_params(self) result(params) end function get_params - module function get_gradients(self) result(gradients) class(network), intent(in) :: self real, allocatable :: gradients(:) @@ -640,6 +638,12 @@ module subroutine update(self, optimizer, batch_size) type is(conv2d_layer) call co_sum(this_layer % dw) call co_sum(this_layer % db) + type is(conv1d_layer) + call co_sum(this_layer % dw) + call co_sum(this_layer % db) + type is(locally_connected_1d_layer) + call co_sum(this_layer % dw) + call co_sum(this_layer % db) end select end do #endif @@ -657,6 +661,12 @@ module subroutine update(self, optimizer, batch_size) type is(conv2d_layer) this_layer % dw = 0 this_layer % db = 0 + type is(conv1d_layer) + this_layer % dw = 0 + this_layer % db = 0 + type is(locally_connected_1d_layer) + this_layer % dw = 0 + this_layer % db = 0 end select end do diff --git a/test/test_conv1d_network.f90 b/test/test_conv1d_network.f90 index 0c3893ae..bb079655 100644 --- a/test/test_conv1d_network.f90 +++ b/test/test_conv1d_network.f90 @@ -60,8 +60,7 @@ program test_conv1d_network call cnn % forward(sample_input) call cnn % backward(y) call cnn % update(optimizer=sgd(learning_rate=1.)) - o = cnn % layers(2) % get_params() - print *, o + if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit end do From 2b7c548eb5d439ee109e8e82701191d80ac97c35 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Tue, 25 Feb 2025 19:43:52 +0100 Subject: [PATCH 15/26] Adding jvdp1's review --- src/nf/nf_conv1d_layer_submodule.f90 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nf/nf_conv1d_layer_submodule.f90 b/src/nf/nf_conv1d_layer_submodule.f90 index 43cb690f..50a06d22 100644 --- a/src/nf/nf_conv1d_layer_submodule.f90 +++ b/src/nf/nf_conv1d_layer_submodule.f90 @@ -118,7 +118,7 @@ pure module subroutine backward(self, input, gradient) !--- Compute bias gradients: db(n) = sum_j gdz(n, j) do n = 1, self % filters - db_local(n) = sum(gdz(n, :)) + db_local(n) = sum(gdz(n, :), dim=1) end do !--- Initialize weight gradient and input gradient accumulators. 
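For reference, the packing convention that the get_params fixes above rely on (flattened kernel first, biases second) can be exercised directly through the layer API. The snippet below is a minimal sketch only; the program name and the layer sizes are invented for illustration.

! Sketch: checks that a conv1d layer reports its parameters as the
! flattened kernel followed by the biases. Sizes are arbitrary.
program check_conv1d_params
  use nf, only: conv1d, input, layer, relu
  implicit none
  type(layer) :: input_layer, conv_layer
  real, allocatable :: params(:)

  ! 3 channels, width 32; 4 filters of kernel size 3.
  input_layer = input(3, 32)
  conv_layer = conv1d(filters=4, kernel_size=3, activation=relu())
  call conv_layer % init(input_layer)

  params = conv_layer % get_params()
  ! Both counts should agree: 4*3*3 kernel values plus 4 biases.
  print *, size(params), conv_layer % get_num_params()
end program check_conv1d_params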
From be5bb762937784b7582357902902d0db20dd2e1c Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Tue, 25 Feb 2025 20:34:32 +0100 Subject: [PATCH 16/26] Implemented OneAdder's suggestions --- src/nf/nf_locally_connected_1d_submodule.f90 | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/nf/nf_locally_connected_1d_submodule.f90 b/src/nf/nf_locally_connected_1d_submodule.f90 index e3903b54..0589f8bf 100644 --- a/src/nf/nf_locally_connected_1d_submodule.f90 +++ b/src/nf/nf_locally_connected_1d_submodule.f90 @@ -65,7 +65,7 @@ pure module subroutine forward(self, input) do j = 1, self % width iws = j iwe = j + self % kernel_size - 1 - do concurrent (n = 1:self % filters) + do n = 1, self % filters self % z(n, j) = sum(self % kernel(n, j, :, :) * input(:, iws:iwe)) + self % biases(n, j) end do end do @@ -125,13 +125,13 @@ end function get_num_params module function get_params(self) result(params) class(locally_connected_1d_layer), intent(in), target :: self real, allocatable :: params(:) - params = [reshape(self % kernel, [size(self % kernel)]), reshape(self % biases, [size(self % biases)])] + params = [self % kernel, self % biases] end function get_params module function get_gradients(self) result(gradients) class(locally_connected_1d_layer), intent(in), target :: self real, allocatable :: gradients(:) - gradients = [reshape(self % dw, [size(self % dw)]), reshape(self % db, [size(self % db)])] + gradients = [self % dw, self % db] end function get_gradients module subroutine set_params(self, params) From 33a6549ddb24d08def1d443c8ec9846eea167d98 Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Tue, 25 Feb 2025 20:48:22 +0100 Subject: [PATCH 17/26] Deleting useless variables --- src/nf/nf_conv1d_layer_submodule.f90 | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/nf/nf_conv1d_layer_submodule.f90 b/src/nf/nf_conv1d_layer_submodule.f90 index 50a06d22..b36b49b1 100644 --- a/src/nf/nf_conv1d_layer_submodule.f90 +++ b/src/nf/nf_conv1d_layer_submodule.f90 @@ -59,12 +59,11 @@ pure module subroutine forward(self, input) class(conv1d_layer), intent(in out) :: self real, intent(in) :: input(:,:) integer :: input_channels, input_width - integer :: j, n, a, b - integer :: iws, iwe, half_window + integer :: j, n + integer :: iws, iwe input_channels = size(input, dim=1) input_width = size(input, dim=2) - half_window = self % kernel_size / 2 ! Loop over output positions. do j = 1, self % width @@ -95,9 +94,8 @@ pure module subroutine backward(self, input, gradient) real, intent(in) :: gradient(:,:) integer :: input_channels, input_width, output_width - integer :: j, n, k, a, b, c - integer :: iws, iwe, half_window - real :: gdz_val + integer :: j, n, k + integer :: iws, iwe ! Local arrays to accumulate gradients. real :: gdz(self % filters, self % width) ! local gradient (dL/dz) @@ -109,8 +107,6 @@ pure module subroutine backward(self, input, gradient) input_width = size(input, dim=2) output_width = self % width ! Note: output_width = input_width - kernel_size + 1 - half_window = self % kernel_size / 2 - !--- Compute the local gradient gdz = (dL/dy) * sigma'(z) for each output. 
do j = 1, output_width gdz(:, j) = gradient(:, j) * self % activation % eval_prime(self % z(:, j)) From b69ba9ac5dc3798e58a72fb9c9cbd7e415c9a0ea Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <riccardoorsi@icloud.com> Date: Tue, 25 Feb 2025 20:52:12 +0100 Subject: [PATCH 18/26] again --- test/test_conv1d_network.f90 | 2 +- test/test_conv2d_network.f90 | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_conv1d_network.f90 b/test/test_conv1d_network.f90 index bb079655..5a353cf9 100644 --- a/test/test_conv1d_network.f90 +++ b/test/test_conv1d_network.f90 @@ -6,7 +6,7 @@ program test_conv1d_network implicit none type(network) :: net - real, allocatable :: sample_input(:,:), output(:,:), o(:) + real, allocatable :: sample_input(:,:), output(:,:) logical :: ok = .true. ! 3-layer convolutional network diff --git a/test/test_conv2d_network.f90 b/test/test_conv2d_network.f90 index 84c07a9f..73c4595a 100644 --- a/test/test_conv2d_network.f90 +++ b/test/test_conv2d_network.f90 @@ -6,7 +6,7 @@ program test_conv2d_network implicit none type(network) :: net - real, allocatable :: sample_input(:,:,:), output(:,:,:), o(:) + real, allocatable :: sample_input(:,:,:), output(:,:,:) logical :: ok = .true. ! 3-layer convolutional network @@ -60,8 +60,7 @@ program test_conv2d_network call cnn % forward(sample_input) call cnn % backward(y) call cnn % update(optimizer=sgd(learning_rate=1.)) - o = cnn % layers(3) % get_params() - print *, o + if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit end do From d4a87e2ed26e2099a057529ead26a8b183dde61a Mon Sep 17 00:00:00 2001 From: Riccardo Orsi <104301293+ricor07@users.noreply.github.com> Date: Thu, 27 Feb 2025 20:11:51 +0100 Subject: [PATCH 19/26] Update src/nf/nf_conv1d_layer_submodule.f90 Co-authored-by: Jeremie Vandenplas <jeremie.vandenplas@gmail.com> --- src/nf/nf_conv1d_layer_submodule.f90 | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nf/nf_conv1d_layer_submodule.f90 b/src/nf/nf_conv1d_layer_submodule.f90 index b36b49b1..0e6d6e63 100644 --- a/src/nf/nf_conv1d_layer_submodule.f90 +++ b/src/nf/nf_conv1d_layer_submodule.f90 @@ -8,7 +8,6 @@ contains module function conv1d_layer_cons(filters, kernel_size, activation) result(res) - implicit none integer, intent(in) :: filters integer, intent(in) :: kernel_size class(activation_function), intent(in) :: activation From 95c4a2ab9e67785d37e6ae25991af011961b0feb Mon Sep 17 00:00:00 2001 From: milancurcic <caomaco@gmail.com> Date: Thu, 13 Mar 2025 14:33:01 -0400 Subject: [PATCH 20/26] locally_connected_1d -> locally_connected1d --- example/cnn_mnist_1d.f90 | 8 +-- src/nf.f90 | 2 +- src/nf/nf_layer_constructors.f90 | 14 ++--- src/nf/nf_layer_constructors_submodule.f90 | 10 ++-- src/nf/nf_layer_submodule.f90 | 46 +++++++-------- ...d.f90 => nf_locally_connected1d_layer.f90} | 56 +++++++++---------- ...f_locally_connected1d_layer_submodule.f90} | 26 ++++----- src/nf/nf_network_submodule.f90 | 10 ++-- ...f90 => test_locally_connected1d_layer.f90} | 30 +++++----- 9 files changed, 101 insertions(+), 101 deletions(-) rename src/nf/{nf_locally_connected_1d.f90 => nf_locally_connected1d_layer.f90} (65%) rename src/nf/{nf_locally_connected_1d_submodule.f90 => nf_locally_connected1d_layer_submodule.f90} (82%) rename test/{test_locally_connected_1d_layer.f90 => test_locally_connected1d_layer.f90} (58%) diff --git a/example/cnn_mnist_1d.f90 b/example/cnn_mnist_1d.f90 index 8368db7a..7e978034 100644 --- a/example/cnn_mnist_1d.f90 +++ b/example/cnn_mnist_1d.f90 @@ -1,7 +1,7 @@ program 
cnn_mnist_1d use nf, only: network, sgd, & - input, conv1d, maxpool1d, flatten, dense, reshape, reshape2d, locally_connected_1d, & + input, conv1d, maxpool1d, flatten, dense, reshape, reshape2d, locally_connected1d, & load_mnist, label_digits, softmax, relu implicit none @@ -20,10 +20,10 @@ program cnn_mnist_1d net = network([ & input(784), & - reshape2d([28,28]), & - locally_connected_1d(filters=8, kernel_size=3, activation=relu()), & + reshape2d([28, 28]), & + locally_connected1d(filters=8, kernel_size=3, activation=relu()), & maxpool1d(pool_size=2), & - locally_connected_1d(filters=16, kernel_size=3, activation=relu()), & + locally_connected1d(filters=16, kernel_size=3, activation=relu()), & maxpool1d(pool_size=2), & dense(10, activation=softmax()) & ]) diff --git a/src/nf.f90 b/src/nf.f90 index d86c5c56..172eafb3 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -12,7 +12,7 @@ module nf input, & layernorm, & linear2d, & - locally_connected_1d, & + locally_connected1d, & maxpool1d, & maxpool2d, & reshape, & diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index 67dc93c3..e5f92f64 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -16,7 +16,7 @@ module nf_layer_constructors flatten, & input, & linear2d, & - locally_connected_1d, & + locally_connected1d, & maxpool1d, & maxpool2d, & reshape, & @@ -212,7 +212,7 @@ module function conv2d(filters, kernel_size, activation) result(res) !! Resulting layer instance end function conv2d - module function locally_connected_1d(filters, kernel_size, activation) result(res) + module function locally_connected1d(filters, kernel_size, activation) result(res) !! 1-d locally connected network constructor !! !! This layer is for building 1-d locally connected network. @@ -224,10 +224,10 @@ module function locally_connected_1d(filters, kernel_size, activation) result(re !! Example: !! !! ``` - !! use nf, only :: locally_connected_1d, layer - !! type(layer) :: locally_connected_1d_layer - !! locally_connected_1d_layer = dense(filters=32, kernel_size=3) - !! locally_connected_1d_layer = dense(filters=32, kernel_size=3, activation='relu') + !! use nf, only :: locally_connected1d, layer + !! type(layer) :: locally_connected1d_layer + !! locally_connected1d_layer = dense(filters=32, kernel_size=3) + !! locally_connected1d_layer = dense(filters=32, kernel_size=3, activation='relu') !! ``` integer, intent(in) :: filters !! Number of filters in the output of the layer @@ -237,7 +237,7 @@ module function locally_connected_1d(filters, kernel_size, activation) result(re !! Activation function (default sigmoid) type(layer) :: res !! Resulting layer instance - end function locally_connected_1d + end function locally_connected1d module function maxpool1d(pool_size, stride) result(res) !! 1-d maxpooling layer constructor. 
diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90 index 432e9a21..48fcd8a5 100644 --- a/src/nf/nf_layer_constructors_submodule.f90 +++ b/src/nf/nf_layer_constructors_submodule.f90 @@ -9,7 +9,7 @@ use nf_input1d_layer, only: input1d_layer use nf_input2d_layer, only: input2d_layer use nf_input3d_layer, only: input3d_layer - use nf_locally_connected_1d_layer, only: locally_connected_1d_layer + use nf_locally_connected1d_layer, only: locally_connected1d_layer use nf_maxpool1d_layer, only: maxpool1d_layer use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape_layer, only: reshape3d_layer @@ -74,7 +74,7 @@ module function conv2d(filters, kernel_size, activation) result(res) end function conv2d - module function locally_connected_1d(filters, kernel_size, activation) result(res) + module function locally_connected1d(filters, kernel_size, activation) result(res) integer, intent(in) :: filters integer, intent(in) :: kernel_size class(activation_function), intent(in), optional :: activation @@ -82,7 +82,7 @@ module function locally_connected_1d(filters, kernel_size, activation) result(re class(activation_function), allocatable :: activation_tmp - res % name = 'locally_connected_1d' + res % name = 'locally_connected1d' if (present(activation)) then allocate(activation_tmp, source=activation) @@ -94,10 +94,10 @@ module function locally_connected_1d(filters, kernel_size, activation) result(re allocate( & res % p, & - source=locally_connected_1d_layer(filters, kernel_size, activation_tmp) & + source=locally_connected1d_layer(filters, kernel_size, activation_tmp) & ) - end function locally_connected_1d + end function locally_connected1d module function dense(layer_size, activation) result(res) diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90 index ad695602..63af7264 100644 --- a/src/nf/nf_layer_submodule.f90 +++ b/src/nf/nf_layer_submodule.f90 @@ -9,7 +9,7 @@ use nf_input1d_layer, only: input1d_layer use nf_input2d_layer, only: input2d_layer use nf_input3d_layer, only: input3d_layer - use nf_locally_connected_1d_layer, only: locally_connected_1d_layer + use nf_locally_connected1d_layer, only: locally_connected1d_layer use nf_maxpool1d_layer, only: maxpool1d_layer use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape2d_layer, only: reshape2d_layer @@ -52,11 +52,11 @@ pure module subroutine backward_1d(self, previous, gradient) type is(flatten_layer) - ! Upstream layers permitted: input2d, input3d, conv1d, conv2d, locally_connected_1d, maxpool1d, maxpool2d + ! 
Upstream layers permitted: input2d, input3d, conv1d, conv2d, locally_connected1d, maxpool1d, maxpool2d select type(prev_layer => previous % p) type is(input2d_layer) call this_layer % backward(prev_layer % output, gradient) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) call this_layer % backward(prev_layer % output, gradient) type is(maxpool1d_layer) call this_layer % backward(prev_layer % output, gradient) @@ -145,13 +145,13 @@ pure module subroutine backward_2d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(input2d_layer) call this_layer % backward(prev_layer % output, gradient) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) call this_layer % backward(prev_layer % output, gradient) type is(conv1d_layer) call this_layer % backward(prev_layer % output, gradient) end select - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) select type(prev_layer => previous % p) type is(maxpool1d_layer) @@ -160,7 +160,7 @@ pure module subroutine backward_2d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(input2d_layer) call this_layer % backward(prev_layer % output, gradient) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) call this_layer % backward(prev_layer % output, gradient) type is(conv1d_layer) call this_layer % backward(prev_layer % output, gradient) @@ -173,7 +173,7 @@ pure module subroutine backward_2d(self, previous, gradient) call this_layer % backward(prev_layer % output, gradient) type is(reshape2d_layer) call this_layer % backward(prev_layer % output, gradient) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) call this_layer % backward(prev_layer % output, gradient) type is(input2d_layer) call this_layer % backward(prev_layer % output, gradient) @@ -294,13 +294,13 @@ module subroutine forward(self, input) call this_layer % forward(prev_layer % output) end select - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) - ! Upstream layers permitted: input2d, locally_connected_1d, maxpool1d, reshape2d + ! Upstream layers permitted: input2d, locally_connected1d, maxpool1d, reshape2d select type(prev_layer => input % p) type is(input2d_layer) call this_layer % forward(prev_layer % output) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) call this_layer % forward(prev_layer % output) type is(maxpool1d_layer) call this_layer % forward(prev_layer % output) @@ -312,11 +312,11 @@ module subroutine forward(self, input) type is(conv1d_layer) - ! Upstream layers permitted: input2d, locally_connected_1d, maxpool1d, reshape2d + ! Upstream layers permitted: input2d, locally_connected1d, maxpool1d, reshape2d select type(prev_layer => input % p) type is(input2d_layer) call this_layer % forward(prev_layer % output) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) call this_layer % forward(prev_layer % output) type is(maxpool1d_layer) call this_layer % forward(prev_layer % output) @@ -328,11 +328,11 @@ module subroutine forward(self, input) type is(maxpool1d_layer) - ! Upstream layers permitted: input1d, locally_connected_1d, maxpool1d, reshape2d + ! 
Upstream layers permitted: input1d, locally_connected1d, maxpool1d, reshape2d select type(prev_layer => input % p) type is(input2d_layer) call this_layer % forward(prev_layer % output) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) call this_layer % forward(prev_layer % output) type is(maxpool1d_layer) call this_layer % forward(prev_layer % output) @@ -358,7 +358,7 @@ module subroutine forward(self, input) type is(flatten_layer) - ! Upstream layers permitted: input2d, input3d, conv2d, maxpool1d, maxpool2d, reshape2d, reshape3d, locally_connected_2d + ! Upstream layers permitted: input2d, input3d, conv2d, maxpool1d, maxpool2d, reshape2d, reshape3d, locally_connected2d select type(prev_layer => input % p) type is(input2d_layer) call this_layer % forward(prev_layer % output) @@ -368,7 +368,7 @@ module subroutine forward(self, input) call this_layer % forward(prev_layer % output) type is(conv2d_layer) call this_layer % forward(prev_layer % output) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) call this_layer % forward(prev_layer % output) type is(maxpool1d_layer) call this_layer % forward(prev_layer % output) @@ -481,7 +481,7 @@ pure module subroutine get_output_2d(self, output) allocate(output, source=this_layer % output) type is(maxpool1d_layer) allocate(output, source=this_layer % output) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) allocate(output, source=this_layer % output) type is(conv1d_layer) allocate(output, source=this_layer % output) @@ -497,7 +497,7 @@ pure module subroutine get_output_2d(self, output) allocate(output, source=this_layer % output) class default error stop '2-d output can only be read from a input2d, maxpool1d, ' & - // 'locally_connected_1d, conv1d, reshape2d, embedding, linear2d, ' & + // 'locally_connected1d, conv1d, reshape2d, embedding, linear2d, ' & // 'self_attention, or layernorm layer.' end select @@ -549,7 +549,7 @@ impure elemental module subroutine init(self, input) self % layer_shape = shape(this_layer % output) type is(dropout_layer) self % layer_shape = shape(this_layer % output) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) self % layer_shape = shape(this_layer % output) type is(maxpool1d_layer) self % layer_shape = shape(this_layer % output) @@ -611,7 +611,7 @@ elemental module function get_num_params(self) result(num_params) num_params = this_layer % get_num_params() type is (conv2d_layer) num_params = this_layer % get_num_params() - type is (locally_connected_1d_layer) + type is (locally_connected1d_layer) num_params = this_layer % get_num_params() type is (maxpool1d_layer) num_params = 0 @@ -656,7 +656,7 @@ module function get_params(self) result(params) params = this_layer % get_params() type is (conv2d_layer) params = this_layer % get_params() - type is (locally_connected_1d_layer) + type is (locally_connected1d_layer) params = this_layer % get_params() type is (maxpool1d_layer) ! No parameters to get. @@ -701,7 +701,7 @@ module function get_gradients(self) result(gradients) gradients = this_layer % get_gradients() type is (conv2d_layer) gradients = this_layer % get_gradients() - type is (locally_connected_1d_layer) + type is (locally_connected1d_layer) gradients = this_layer % get_gradients() type is (maxpool1d_layer) ! No gradients to get. 
@@ -776,7 +776,7 @@ module subroutine set_params(self, params) type is (conv2d_layer) call this_layer % set_params(params) - type is (locally_connected_1d_layer) + type is (locally_connected1d_layer) call this_layer % set_params(params) type is (maxpool1d_layer) diff --git a/src/nf/nf_locally_connected_1d.f90 b/src/nf/nf_locally_connected1d_layer.f90 similarity index 65% rename from src/nf/nf_locally_connected_1d.f90 rename to src/nf/nf_locally_connected1d_layer.f90 index 1dc3b4a1..beca76d5 100644 --- a/src/nf/nf_locally_connected_1d.f90 +++ b/src/nf/nf_locally_connected1d_layer.f90 @@ -1,14 +1,14 @@ -module nf_locally_connected_1d_layer - !! This modules provides a 1-d convolutional `locally_connected_1d` type. +module nf_locally_connected1d_layer + !! This modules provides a 1-d convolutional `locally_connected1d` type. use nf_activation, only: activation_function use nf_base_layer, only: base_layer implicit none private - public :: locally_connected_1d_layer + public :: locally_connected1d_layer - type, extends(base_layer) :: locally_connected_1d_layer + type, extends(base_layer) :: locally_connected1d_layer integer :: width integer :: height @@ -37,18 +37,18 @@ module nf_locally_connected_1d_layer procedure :: init procedure :: set_params - end type locally_connected_1d_layer + end type locally_connected1d_layer - interface locally_connected_1d_layer - module function locally_connected_1d_layer_cons(filters, kernel_size, activation) & + interface locally_connected1d_layer + module function locally_connected1d_layer_cons(filters, kernel_size, activation) & result(res) - !! `locally_connected_1d_layer` constructor function + !! `locally_connected1d_layer` constructor function integer, intent(in) :: filters integer, intent(in) :: kernel_size class(activation_function), intent(in) :: activation - type(locally_connected_1d_layer) :: res - end function locally_connected_1d_layer_cons - end interface locally_connected_1d_layer + type(locally_connected1d_layer) :: res + end function locally_connected1d_layer_cons + end interface locally_connected1d_layer interface @@ -56,24 +56,24 @@ module subroutine init(self, input_shape) !! Initialize the layer data structures. !! !! This is a deferred procedure from the `base_layer` abstract type. - class(locally_connected_1d_layer), intent(in out) :: self - !! A `locally_connected_1d_layer` instance + class(locally_connected1d_layer), intent(in out) :: self + !! A `locally_connected1d_layer` instance integer, intent(in) :: input_shape(:) !! Input layer dimensions end subroutine init pure module subroutine forward(self, input) - !! Apply a forward pass on the `locally_connected_1d` layer. - class(locally_connected_1d_layer), intent(in out) :: self - !! A `locally_connected_1d_layer` instance + !! Apply a forward pass on the `locally_connected1d` layer. + class(locally_connected1d_layer), intent(in out) :: self + !! A `locally_connected1d_layer` instance real, intent(in) :: input(:,:) !! Input data end subroutine forward pure module subroutine backward(self, input, gradient) - !! Apply a backward pass on the `locally_connected_1d` layer. - class(locally_connected_1d_layer), intent(in out) :: self - !! A `locally_connected_1d_layer` instance + !! Apply a backward pass on the `locally_connected1d` layer. + class(locally_connected1d_layer), intent(in out) :: self + !! A `locally_connected1d_layer` instance real, intent(in) :: input(:,:) !! 
Input data (previous layer) real, intent(in) :: gradient(:,:) @@ -82,8 +82,8 @@ end subroutine backward pure module function get_num_params(self) result(num_params) !! Get the number of parameters in the layer. - class(locally_connected_1d_layer), intent(in) :: self - !! A `locally_connected_1d_layer` instance + class(locally_connected1d_layer), intent(in) :: self + !! A `locally_connected1d_layer` instance integer :: num_params !! Number of parameters end function get_num_params @@ -91,8 +91,8 @@ end function get_num_params module function get_params(self) result(params) !! Return the parameters (weights and biases) of this layer. !! The parameters are ordered as weights first, biases second. - class(locally_connected_1d_layer), intent(in), target :: self - !! A `locally_connected_1d_layer` instance + class(locally_connected1d_layer), intent(in), target :: self + !! A `locally_connected1d_layer` instance real, allocatable :: params(:) !! Parameters to get end function get_params @@ -100,20 +100,20 @@ end function get_params module function get_gradients(self) result(gradients) !! Return the gradients of this layer. !! The gradients are ordered as weights first, biases second. - class(locally_connected_1d_layer), intent(in), target :: self - !! A `locally_connected_1d_layer` instance + class(locally_connected1d_layer), intent(in), target :: self + !! A `locally_connected1d_layer` instance real, allocatable :: gradients(:) !! Gradients to get end function get_gradients module subroutine set_params(self, params) !! Set the parameters of the layer. - class(locally_connected_1d_layer), intent(in out) :: self - !! A `locally_connected_1d_layer` instance + class(locally_connected1d_layer), intent(in out) :: self + !! A `locally_connected1d_layer` instance real, intent(in) :: params(:) !! 
Parameters to set end subroutine set_params end interface -end module nf_locally_connected_1d_layer +end module nf_locally_connected1d_layer diff --git a/src/nf/nf_locally_connected_1d_submodule.f90 b/src/nf/nf_locally_connected1d_layer_submodule.f90 similarity index 82% rename from src/nf/nf_locally_connected_1d_submodule.f90 rename to src/nf/nf_locally_connected1d_layer_submodule.f90 index 0589f8bf..053c520b 100644 --- a/src/nf/nf_locally_connected_1d_submodule.f90 +++ b/src/nf/nf_locally_connected1d_layer_submodule.f90 @@ -1,4 +1,4 @@ -submodule(nf_locally_connected_1d_layer) nf_locally_connected_1d_layer_submodule +submodule(nf_locally_connected1d_layer) nf_locally_connected1d_layer_submodule use nf_activation, only: activation_function use nf_random, only: random_normal @@ -7,22 +7,22 @@ contains - module function locally_connected_1d_layer_cons(filters, kernel_size, activation) result(res) + module function locally_connected1d_layer_cons(filters, kernel_size, activation) result(res) implicit none integer, intent(in) :: filters integer, intent(in) :: kernel_size class(activation_function), intent(in) :: activation - type(locally_connected_1d_layer) :: res + type(locally_connected1d_layer) :: res res % kernel_size = kernel_size res % filters = filters res % activation_name = activation % get_name() allocate(res % activation, source = activation) - end function locally_connected_1d_layer_cons + end function locally_connected1d_layer_cons module subroutine init(self, input_shape) implicit none - class(locally_connected_1d_layer), intent(in out) :: self + class(locally_connected1d_layer), intent(in out) :: self integer, intent(in) :: input_shape(:) self % channels = input_shape(1) @@ -53,7 +53,7 @@ end subroutine init pure module subroutine forward(self, input) implicit none - class(locally_connected_1d_layer), intent(in out) :: self + class(locally_connected1d_layer), intent(in out) :: self real, intent(in) :: input(:,:) integer :: input_channels, input_width integer :: j, n @@ -74,7 +74,7 @@ end subroutine forward pure module subroutine backward(self, input, gradient) implicit none - class(locally_connected_1d_layer), intent(in out) :: self + class(locally_connected1d_layer), intent(in out) :: self real, intent(in) :: input(:,:) real, intent(in) :: gradient(:,:) integer :: input_channels, input_width, output_width @@ -117,29 +117,29 @@ pure module subroutine backward(self, input, gradient) end subroutine backward pure module function get_num_params(self) result(num_params) - class(locally_connected_1d_layer), intent(in) :: self + class(locally_connected1d_layer), intent(in) :: self integer :: num_params num_params = product(shape(self % kernel)) + product(shape(self % biases)) end function get_num_params module function get_params(self) result(params) - class(locally_connected_1d_layer), intent(in), target :: self + class(locally_connected1d_layer), intent(in), target :: self real, allocatable :: params(:) params = [self % kernel, self % biases] end function get_params module function get_gradients(self) result(gradients) - class(locally_connected_1d_layer), intent(in), target :: self + class(locally_connected1d_layer), intent(in), target :: self real, allocatable :: gradients(:) gradients = [self % dw, self % db] end function get_gradients module subroutine set_params(self, params) - class(locally_connected_1d_layer), intent(in out) :: self + class(locally_connected1d_layer), intent(in out) :: self real, intent(in) :: params(:) if (size(params) /= self % get_num_params()) then - error 
stop 'locally_connected_1d_layer % set_params: Number of parameters does not match' + error stop 'locally_connected1d_layer % set_params: Number of parameters does not match' end if self % kernel = reshape(params(:product(shape(self % kernel))), shape(self % kernel)) @@ -149,4 +149,4 @@ module subroutine set_params(self, params) end subroutine set_params -end submodule nf_locally_connected_1d_layer_submodule +end submodule nf_locally_connected1d_layer_submodule diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 2dd91261..1752f1f8 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -8,7 +8,7 @@ use nf_input1d_layer, only: input1d_layer use nf_input2d_layer, only: input2d_layer use nf_input3d_layer, only: input3d_layer - use nf_locally_connected_1d_layer, only: locally_connected_1d_layer + use nf_locally_connected1d_layer, only: locally_connected1d_layer use nf_maxpool1d_layer, only: maxpool1d_layer use nf_maxpool2d_layer, only: maxpool2d_layer use nf_reshape2d_layer, only: reshape2d_layer @@ -79,7 +79,7 @@ module function network_from_layers(layers) result(res) type is(conv2d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) res % layers = [res % layers(:n-1), flatten(), res % layers(n:)] n = n + 1 type is(maxpool2d_layer) @@ -185,7 +185,7 @@ module subroutine backward(self, output, loss) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(conv1d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) type is(layernorm_layer) call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) @@ -686,7 +686,7 @@ module subroutine update(self, optimizer, batch_size) type is(conv1d_layer) call co_sum(this_layer % dw) call co_sum(this_layer % db) - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) call co_sum(this_layer % dw) call co_sum(this_layer % db) end select @@ -709,7 +709,7 @@ module subroutine update(self, optimizer, batch_size) type is(conv1d_layer) this_layer % dw = 0 this_layer % db = 0 - type is(locally_connected_1d_layer) + type is(locally_connected1d_layer) this_layer % dw = 0 this_layer % db = 0 end select diff --git a/test/test_locally_connected_1d_layer.f90 b/test/test_locally_connected1d_layer.f90 similarity index 58% rename from test/test_locally_connected_1d_layer.f90 rename to test/test_locally_connected1d_layer.f90 index 50489128..e8a30cfc 100644 --- a/test/test_locally_connected_1d_layer.f90 +++ b/test/test_locally_connected1d_layer.f90 @@ -1,7 +1,7 @@ -program test_locally_connected_1d_layer +program test_locally_connected1d_layer use iso_fortran_env, only: stderr => error_unit - use nf, only: locally_connected_1d, input, layer + use nf, only: locally_connected1d, input, layer use nf_input2d_layer, only: input2d_layer implicit none @@ -12,21 +12,21 @@ program test_locally_connected_1d_layer real, parameter :: tolerance = 1e-7 logical :: ok = .true. - locally_connected_1d_layer = locally_connected_1d(filters, kernel_size) + locally_connected_1d_layer = locally_connected1d(filters, kernel_size) - if (.not. locally_connected_1d_layer % name == 'locally_connected_1d') then + if (.not. 
locally_connected_1d_layer % name == 'locally_connected1d') then ok = .false. - write(stderr, '(a)') 'locally_connected_1d layer has its name set correctly.. failed' + write(stderr, '(a)') 'locally_connected1d layer has its name set correctly.. failed' end if if (locally_connected_1d_layer % initialized) then ok = .false. - write(stderr, '(a)') 'locally_connected_1d layer should not be marked as initialized yet.. failed' + write(stderr, '(a)') 'locally_connected1d layer should not be marked as initialized yet.. failed' end if if (.not. locally_connected_1d_layer % activation == 'relu') then ok = .false. - write(stderr, '(a)') 'locally_connected_1d layer defaults to relu activation.. failed' + write(stderr, '(a)') 'locally_connected1d layer defaults to relu activation.. failed' end if input_layer = input(3, 32) @@ -34,17 +34,17 @@ program test_locally_connected_1d_layer if (.not. locally_connected_1d_layer % initialized) then ok = .false. - write(stderr, '(a)') 'locally_connected_1d layer should now be marked as initialized.. failed' + write(stderr, '(a)') 'locally_connected1d layer should now be marked as initialized.. failed' end if if (.not. all(locally_connected_1d_layer % input_layer_shape == [3, 32])) then ok = .false. - write(stderr, '(a)') 'locally_connected_1d layer input layer shape should be correct.. failed' + write(stderr, '(a)') 'locally_connected1d layer input layer shape should be correct.. failed' end if if (.not. all(locally_connected_1d_layer % layer_shape == [filters, 30])) then ok = .false. - write(stderr, '(a)') 'locally_connected_1d layer input layer shape should be correct.. failed' + write(stderr, '(a)') 'locally_connected1d layer input layer shape should be correct.. failed' end if ! Minimal locally_connected_1d layer: 1 channel, 3x3 pixel image; @@ -52,7 +52,7 @@ program test_locally_connected_1d_layer sample_input = 0 input_layer = input(1, 3) - locally_connected_1d_layer = locally_connected_1d(filters, kernel_size) + locally_connected_1d_layer = locally_connected1d(filters, kernel_size) call locally_connected_1d_layer % init(input_layer) select type(this_layer => input_layer % p); type is(input2d_layer) @@ -65,14 +65,14 @@ program test_locally_connected_1d_layer if (.not. all(abs(output) < tolerance)) then ok = .false. - write(stderr, '(a)') 'locally_connected_1d layer with zero input and sigmoid function must forward to all 0.5.. failed' + write(stderr, '(a)') 'locally_connected1d layer with zero input and sigmoid function must forward to all 0.5.. failed' end if if (ok) then - print '(a)', 'test_locally_connected_1d_layer: All tests passed.' + print '(a)', 'test_locally_connected1d_layer: All tests passed.' else - write(stderr, '(a)') 'test_locally_connected_1d_layer: One or more tests failed.' + write(stderr, '(a)') 'test_locally_connected1d_layer: One or more tests failed.' stop 1 end if -end program test_locally_connected_1d_layer +end program test_locally_connected1d_layer From f3daf43803427e720251604874d7ad94a168b659 Mon Sep 17 00:00:00 2001 From: milancurcic <caomaco@gmail.com> Date: Thu, 13 Mar 2025 14:40:01 -0400 Subject: [PATCH 21/26] Update features table --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e94296a3..50824acb 100644 --- a/README.md +++ b/README.md @@ -33,16 +33,18 @@ Read the paper [here](https://arxiv.org/abs/1902.06714). 
| Embedding | `embedding` | n/a | 2 | ✅ | ✅ | | Dense (fully-connected) | `dense` | `input1d`, `dense`, `dropout`, `flatten` | 1 | ✅ | ✅ | | Dropout | `dropout` | `dense`, `flatten`, `input1d` | 1 | ✅ | ✅ | -| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) | +| Locally connected (1-d) | `locally_connected1d` | `input2d`, `locally_connected1d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 | ✅ | ✅ | +| Convolutional (1-d) | `conv1d` | `input2d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 | ✅ | ✅ | +| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ | +| Max-pooling (1-d) | `maxpool1d` | `input2d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 | ✅ | ✅ | | Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ | | Linear (2-d) | `linear2d` | `input2d`, `layernorm`, `linear2d`, `self_attention` | 2 | ✅ | ✅ | | Self-attention | `self_attention` | `input2d`, `layernorm`, `linear2d`, `self_attention` | 2 | ✅ | ✅ | | Layer Normalization | `layernorm` | `linear2d`, `self_attention` | 2 | ✅ | ✅ | | Flatten | `flatten` | `input2d`, `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ | +| Reshape (1-d to 2-d) | `reshape2d` | `input2d`, `conv1d`, `locally_connected1d`, `maxpool1d` | 2 | ✅ | ✅ | | Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 | ✅ | ✅ | -(*) See Issue [#145](https://github.com/modern-fortran/neural-fortran/issues/145) regarding non-converging CNN training on the MNIST dataset. - ## Getting started Get the code: From 7819b551392698388fce3745f30257c3c22ade58 Mon Sep 17 00:00:00 2001 From: milancurcic <caomaco@gmail.com> Date: Thu, 13 Mar 2025 14:43:58 -0400 Subject: [PATCH 22/26] Fix CMakeLists --- test/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 40d56f00..ec4e139e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -8,7 +8,7 @@ foreach(execid dense_layer conv1d_layer conv2d_layer - locally_connected_1d_layer + locally_connected1d_layer maxpool1d_layer maxpool2d_layer flatten_layer From 5b46efb10dcbdd9fcc326b651a3cdb56d76fd0a1 Mon Sep 17 00:00:00 2001 From: milancurcic <caomaco@gmail.com> Date: Thu, 13 Mar 2025 14:45:13 -0400 Subject: [PATCH 23/26] Fix CmakeListst --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 37014771..e56de44b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,8 +43,8 @@ add_library(neural-fortran src/nf/nf_layernorm_submodule.f90 src/nf/nf_layer.f90 src/nf/nf_layer_submodule.f90 - src/nf/nf_locally_connected_1d_submodule.f90 - src/nf/nf_locally_connected_1d.f90 + src/nf/nf_locally_connected1d_submodule.f90 + src/nf/nf_locally_connected1d.f90 src/nf/nf_linear2d_layer.f90 src/nf/nf_linear2d_layer_submodule.f90 src/nf/nf_embedding_layer.f90 From 473fcf3fc413b1475217e812a9a5704882fdf27a Mon Sep 17 00:00:00 2001 From: milancurcic <caomaco@gmail.com> Date: Thu, 13 Mar 2025 14:47:08 -0400 Subject: [PATCH 24/26] Another one --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e56de44b..fcd93915 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,8 +43,8 @@ add_library(neural-fortran src/nf/nf_layernorm_submodule.f90 src/nf/nf_layer.f90 src/nf/nf_layer_submodule.f90 - src/nf/nf_locally_connected1d_submodule.f90 - src/nf/nf_locally_connected1d.f90 + src/nf/nf_locally_connected1d_layer_submodule.f90 + 
src/nf/nf_locally_connected1d_layer.f90 src/nf/nf_linear2d_layer.f90 src/nf/nf_linear2d_layer_submodule.f90 src/nf/nf_embedding_layer.f90 From 174a4216d122e88dfe42a46acc09da004b2df396 Mon Sep 17 00:00:00 2001 From: milancurcic <caomaco@gmail.com> Date: Thu, 13 Mar 2025 15:10:55 -0400 Subject: [PATCH 25/26] Tidy up --- src/nf/nf_conv1d_layer_submodule.f90 | 38 +++++++++++++--------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/src/nf/nf_conv1d_layer_submodule.f90 b/src/nf/nf_conv1d_layer_submodule.f90 index 0e6d6e63..5404b9c7 100644 --- a/src/nf/nf_conv1d_layer_submodule.f90 +++ b/src/nf/nf_conv1d_layer_submodule.f90 @@ -62,7 +62,7 @@ pure module subroutine forward(self, input) integer :: iws, iwe input_channels = size(input, dim=1) - input_width = size(input, dim=2) + input_width = size(input, dim=2) ! Loop over output positions. do j = 1, self % width @@ -73,11 +73,11 @@ pure module subroutine forward(self, input) ! For each filter, compute the convolution (inner product over channels and kernel width). do concurrent (n = 1:self % filters) - self % z(n, j) = sum(self % kernel(n, :, :) * input(:, iws:iwe)) + self % z(n, j) = sum(self % kernel(n,:,:) * input(:,iws:iwe)) end do ! Add the bias for each filter. - self % z(:, j) = self % z(:, j) + self % biases + self % z(:,j) = self % z(:,j) + self % biases end do ! Apply the activation function. @@ -103,18 +103,14 @@ pure module subroutine backward(self, input, gradient) ! Determine dimensions. input_channels = size(input, dim=1) - input_width = size(input, dim=2) - output_width = self % width ! Note: output_width = input_width - kernel_size + 1 + input_width = size(input, dim=2) + output_width = self % width ! Note: output_width = input_width - kernel_size + 1 !--- Compute the local gradient gdz = (dL/dy) * sigma'(z) for each output. - do j = 1, output_width - gdz(:, j) = gradient(:, j) * self % activation % eval_prime(self % z(:, j)) - end do + gdz = gradient * self % activation % eval_prime(self % z) !--- Compute bias gradients: db(n) = sum_j gdz(n, j) - do n = 1, self % filters - db_local(n) = sum(gdz(n, :), dim=1) - end do + db_local = sum(gdz, dim=2) !--- Initialize weight gradient and input gradient accumulators. dw_local = 0.0 @@ -124,16 +120,16 @@ pure module subroutine backward(self, input, gradient) ! In the forward pass the window for output index j was: ! iws = j, iwe = j + kernel_size - 1. do n = 1, self % filters - do j = 1, output_width - iws = j - iwe = j + self % kernel_size - 1 - do k = 1, self % channels - ! Weight gradient: accumulate contribution from the input window. - dw_local(n, k, :) = dw_local(n, k, :) + input(k, iws:iwe) * gdz(n, j) - ! Input gradient: propagate gradient back to the input window. - self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, k, :) * gdz(n, j) - end do - end do + do j = 1, output_width + iws = j + iwe = j + self % kernel_size - 1 + do k = 1, self % channels + ! Weight gradient: accumulate contribution from the input window. + dw_local(n,k,:) = dw_local(n,k,:) + input(k,iws:iwe) * gdz(n,j) + ! Input gradient: propagate gradient back to the input window. + self % gradient(k,iws:iwe) = self % gradient(k,iws:iwe) + self % kernel(n,k,:) * gdz(n,j) + end do + end do end do !--- Update stored gradients. 
From a4866484aa0b8c94836da0973fc0b79a06a905fc Mon Sep 17 00:00:00 2001
From: milancurcic <caomaco@gmail.com>
Date: Thu, 13 Mar 2025 22:02:05 -0400
Subject: [PATCH 26/26] Acknowledge contributors

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index 50824acb..75da6525 100644
--- a/README.md
+++ b/README.md
@@ -269,7 +269,9 @@ Thanks to all open-source contributors to neural-fortran:
 [jvdp1](https://github.com/jvdp1),
 [jvo203](https://github.com/jvo203),
 [milancurcic](https://github.com/milancurcic),
+[OneAdder](https://github.com/OneAdder),
 [pirpyn](https://github.com/pirpyn),
+[ricor07](https://github.com/ricor07),
 [rouson](https://github.com/rouson),
 [rweed](https://github.com/rweed),
 [Spnetic-5](https://github.com/Spnetic-5),
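Read alongside the updated features table, the new 1-d layers added in this series can be chained as in the sketch below; the program name, channel count, width, and filter sizes are invented for illustration and are not prescribed by the patches.

! Sketch: chains the 1-d layers (conv1d, maxpool1d, locally_connected1d)
! according to the upstream-layer rules in the features table.
! All dimensions are arbitrary.
program stack_1d_sketch
  use nf, only: network, input, conv1d, maxpool1d, locally_connected1d, &
                flatten, dense, relu, softmax
  implicit none
  type(network) :: net

  net = network([ &
    input(1, 64), &
    conv1d(filters=8, kernel_size=3, activation=relu()), &
    maxpool1d(pool_size=2), &
    locally_connected1d(filters=4, kernel_size=3, activation=relu()), &
    flatten(), &
    dense(10, activation=softmax()) &
  ])

  print *, 'Trainable parameters:', net % get_num_params()
end program stack_1d_sketch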