diff --git a/CMakeLists.txt b/CMakeLists.txt
index bd48e99..0fd1be6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -26,6 +26,8 @@ add_library(neural
   src/nf.f90
   src/nf/nf_activation.f90
   src/nf/nf_base_layer.f90
+  src/nf/nf_batchnorm_layer.f90
+  src/nf/nf_batchnorm_layer_submodule.f90
   src/nf/nf_conv2d_layer.f90
   src/nf/nf_conv2d_layer_submodule.f90
   src/nf/nf_datasets.f90
diff --git a/src/nf.f90 b/src/nf.f90
index eb2a903..b6b90b4 100644
--- a/src/nf.f90
+++ b/src/nf.f90
@@ -3,7 +3,7 @@ module nf
   use nf_datasets_mnist, only: label_digits, load_mnist
   use nf_layer, only: layer
   use nf_layer_constructors, only: &
-    conv2d, dense, flatten, input, maxpool2d, reshape
+    batchnorm, conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_network, only: network
   use nf_optimizers, only: sgd, rmsprop, adam, adagrad
   use nf_activation, only: activation_function, elu, exponential, &
diff --git a/src/nf/nf_batchnorm_layer.f90 b/src/nf/nf_batchnorm_layer.f90
new file mode 100644
index 0000000..193d5ef
--- /dev/null
+++ b/src/nf/nf_batchnorm_layer.f90
@@ -0,0 +1,109 @@
+module nf_batchnorm_layer
+
+  !! This module provides a batch normalization `batchnorm_layer` type.
+
+  use nf_base_layer, only: base_layer
+  implicit none
+
+  private
+  public :: batchnorm_layer
+
+  type, extends(base_layer) :: batchnorm_layer
+
+    integer :: num_features
+    real, allocatable :: gamma(:)
+    real, allocatable :: beta(:)
+    real, allocatable :: running_mean(:)
+    real, allocatable :: running_var(:)
+    real, allocatable :: input(:,:)
+    real, allocatable :: output(:,:)
+    real, allocatable :: gamma_grad(:)
+    real, allocatable :: beta_grad(:)
+    real, allocatable :: input_grad(:,:)
+    real :: epsilon = 1e-5
+
+  contains
+
+    procedure :: forward
+    procedure :: backward
+    procedure :: get_gradients
+    procedure :: get_num_params
+    procedure :: get_params
+    procedure :: init
+    procedure :: set_params
+
+  end type batchnorm_layer
+
+  interface batchnorm_layer
+    pure module function batchnorm_layer_cons(num_features) result(res)
+      !! `batchnorm_layer` constructor function
+      integer, intent(in) :: num_features
+      type(batchnorm_layer) :: res
+    end function batchnorm_layer_cons
+  end interface batchnorm_layer
+
+  interface
+
+    module subroutine init(self, input_shape)
+      !! Initialize the layer data structures.
+      !!
+      !! This is a deferred procedure from the `base_layer` abstract type.
+      class(batchnorm_layer), intent(in out) :: self
+        !! A `batchnorm_layer` instance
+      integer, intent(in) :: input_shape(:)
+        !! Input layer dimensions
+    end subroutine init
+
+    pure module subroutine forward(self, input)
+      !! Apply a forward pass on the `batchnorm_layer`.
+      class(batchnorm_layer), intent(in out) :: self
+        !! A `batchnorm_layer` instance
+      real, intent(in) :: input(:,:)
+        !! Input data
+    end subroutine forward
+
+    pure module subroutine backward(self, input, gradient)
+      !! Apply a backward pass on the `batchnorm_layer`.
+      class(batchnorm_layer), intent(in out) :: self
+        !! A `batchnorm_layer` instance
+      real, intent(in) :: input(:,:)
+        !! Input data (previous layer)
+      real, intent(in) :: gradient(:,:)
+        !! Gradient (next layer)
+    end subroutine backward
+
+    pure module function get_num_params(self) result(num_params)
+      !! Get the number of parameters in the layer.
+      class(batchnorm_layer), intent(in) :: self
+        !! A `batchnorm_layer` instance
+      integer :: num_params
+        !! Number of parameters
+    end function get_num_params
+
+    pure module function get_params(self) result(params)
+      !! Return the parameters (gamma, beta) of this layer.
+      class(batchnorm_layer), intent(in) :: self
+        !! A `batchnorm_layer` instance
+      real, allocatable :: params(:)
+        !! Parameters to get
+    end function get_params
+
+    pure module function get_gradients(self) result(gradients)
+      !! Return the gradients of this layer.
+      class(batchnorm_layer), intent(in) :: self
+        !! A `batchnorm_layer` instance
+      real, allocatable :: gradients(:)
+        !! Gradients to get
+    end function get_gradients
+
+    module subroutine set_params(self, params)
+      !! Set the parameters of the layer.
+      class(batchnorm_layer), intent(in out) :: self
+        !! A `batchnorm_layer` instance
+      real, intent(in) :: params(:)
+        !! Parameters to set
+    end subroutine set_params
+
+  end interface
+
+end module nf_batchnorm_layer
diff --git a/src/nf/nf_batchnorm_layer_submodule.f90 b/src/nf/nf_batchnorm_layer_submodule.f90
new file mode 100644
index 0000000..9f3d2a8
--- /dev/null
+++ b/src/nf/nf_batchnorm_layer_submodule.f90
@@ -0,0 +1,105 @@
+submodule(nf_batchnorm_layer) nf_batchnorm_layer_submodule
+
+  implicit none
+
+contains
+
+  pure module function batchnorm_layer_cons(num_features) result(res)
+    implicit none
+    integer, intent(in) :: num_features
+    type(batchnorm_layer) :: res
+
+    res % num_features = num_features
+    allocate(res % gamma(num_features), source=1.0)
+    allocate(res % beta(num_features))
+    allocate(res % running_mean(num_features), source=0.0)
+    allocate(res % running_var(num_features), source=1.0)
+    allocate(res % input(num_features, num_features))
+    allocate(res % output(num_features, num_features))
+    allocate(res % gamma_grad(num_features))
+    allocate(res % beta_grad(num_features))
+    allocate(res % input_grad(num_features, num_features))
+
+  end function batchnorm_layer_cons
+
+  module subroutine init(self, input_shape)
+    implicit none
+    class(batchnorm_layer), intent(in out) :: self
+    integer, intent(in) :: input_shape(:)
+
+    self % input = 0
+    self % output = 0
+
+    ! Initialize gamma, beta, running_mean, and running_var
+    self % gamma = 1.0
+    self % beta = 0.0
+    self % running_mean = 0.0
+    self % running_var = 1.0
+
+  end subroutine init
+
+  pure module subroutine forward(self, input)
+    implicit none
+    class(batchnorm_layer), intent(in out) :: self
+    real, intent(in) :: input(:,:)
+
+    ! Store input for backward pass
+    self % input = input
+
+    associate( &
+      ! Normalize the input with the running statistics, broadcast over the batch
+      normalized_input => (input - spread(self % running_mean, dim=2, ncopies=size(input, 2))) &
+        / sqrt(spread(self % running_var, dim=2, ncopies=size(input, 2)) + self % epsilon) &
+    )
+
+      ! Batch normalization forward pass
+      self % output = spread(self % gamma, dim=2, ncopies=size(input, 2)) * normalized_input &
+        + spread(self % beta, dim=2, ncopies=size(input, 2))
+
+    end associate
+
+  end subroutine forward
+
+  pure module subroutine backward(self, input, gradient)
+    implicit none
+    class(batchnorm_layer), intent(in out) :: self
+    real, intent(in) :: input(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    ! Calculate gradients for gamma, beta
+    self % gamma_grad = sum(gradient * (input - spread(self % running_mean, dim=2, ncopies=size(input, 2))) &
+      / sqrt(spread(self % running_var, dim=2, ncopies=size(input, 2)) + self % epsilon), dim=2)
+    self % beta_grad = sum(gradient, dim=2)
+
+    ! Calculate gradients for input
+    self % input_grad = gradient * spread(self % gamma, dim=2, ncopies=size(input, 2)) &
+      / sqrt(spread(self % running_var, dim=2, ncopies=size(input, 2)) + self % epsilon)
+
+  end subroutine backward
+
+  pure module function get_num_params(self) result(num_params)
+    class(batchnorm_layer), intent(in) :: self
+    integer :: num_params
+    num_params = 2 * self % num_features
+  end function get_num_params
+
+  pure module function get_params(self) result(params)
+    class(batchnorm_layer), intent(in) :: self
+    real, allocatable :: params(:)
+    params = [self % gamma, self % beta]
+  end function get_params
+
+  pure module function get_gradients(self) result(gradients)
+    class(batchnorm_layer), intent(in) :: self
+    real, allocatable :: gradients(:)
+    gradients = [self % gamma_grad, self % beta_grad]
+  end function get_gradients
+
+  module subroutine set_params(self, params)
+    class(batchnorm_layer), intent(in out) :: self
+    real, intent(in) :: params(:)
+    self % gamma = params(1:self % num_features)
+    self % beta = params(self % num_features+1:2*self % num_features)
+  end subroutine set_params
+
+end submodule nf_batchnorm_layer_submodule
diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90
index ce9a724..b036f1b 100644
--- a/src/nf/nf_layer_constructors.f90
+++ b/src/nf/nf_layer_constructors.f90
@@ -8,7 +8,7 @@ module nf_layer_constructors
   implicit none
 
   private
-  public :: conv2d, dense, flatten, input, maxpool2d, reshape
+  public :: batchnorm, conv2d, dense, flatten, input, maxpool2d, reshape
 
   interface input
 
@@ -106,6 +106,25 @@ pure module function flatten() result(res)
       !! Resulting layer instance
   end function flatten
 
+  pure module function batchnorm(num_features) result(res)
+    !! Batch normalization layer constructor.
+    !!
+    !! This layer adds batch normalization to the network.
+    !! A batch normalization layer can be used after conv2d or dense layers.
+    !!
+    !! Example:
+    !!
+    !! ```
+    !! use nf, only: batchnorm, layer
+    !! type(layer) :: batchnorm_layer
+    !! batchnorm_layer = batchnorm(num_features = 64)
+    !! ```
+    integer, intent(in) :: num_features
+      !! Number of features in the layer
+    type(layer) :: res
+      !! Resulting layer instance
+  end function batchnorm
+
   pure module function conv2d(filters, kernel_size, activation) result(res)
     !! 2-d convolutional layer constructor.
     !!
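For reference, the `forward` and `backward` procedures in `nf_batchnorm_layer_submodule.f90` above implement inference-style batch normalization with the stored running statistics. The following summary of that computation is an editorial aid, not part of the patch; for feature i and batch sample j:

```latex
% Forward pass, using the running statistics (mu_i, sigma_i^2) and the layer's epsilon:
\hat{x}_{ij} = \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \varepsilon}}, \qquad
y_{ij} = \gamma_i \, \hat{x}_{ij} + \beta_i

% Backward pass, treating mu_i and sigma_i^2 as constants
% (the code does not backpropagate through the batch statistics):
\frac{\partial L}{\partial \gamma_i} = \sum_j \frac{\partial L}{\partial y_{ij}} \, \hat{x}_{ij}, \qquad
\frac{\partial L}{\partial \beta_i} = \sum_j \frac{\partial L}{\partial y_{ij}}, \qquad
\frac{\partial L}{\partial x_{ij}} = \frac{\partial L}{\partial y_{ij}} \, \frac{\gamma_i}{\sqrt{\sigma_i^2 + \varepsilon}}
```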
diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90
index 002a83b..914df2f 100644
--- a/src/nf/nf_layer_constructors_submodule.f90
+++ b/src/nf/nf_layer_constructors_submodule.f90
@@ -1,6 +1,7 @@
 submodule(nf_layer_constructors) nf_layer_constructors_submodule
 
   use nf_layer, only: layer
+  use nf_batchnorm_layer, only: batchnorm_layer
   use nf_conv2d_layer, only: conv2d_layer
   use nf_dense_layer, only: dense_layer
   use nf_flatten_layer, only: flatten_layer
@@ -14,6 +15,13 @@
 
 contains
 
+  pure module function batchnorm(num_features) result(res)
+    integer, intent(in) :: num_features
+    type(layer) :: res
+    res % name = 'batchnorm'
+    allocate(res % p, source=batchnorm_layer(num_features))
+  end function batchnorm
+
   pure module function conv2d(filters, kernel_size, activation) result(res)
     integer, intent(in) :: filters
     integer, intent(in) :: kernel_size
diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90
index 0746764..94d9d17 100644
--- a/src/nf/nf_layer_submodule.f90
+++ b/src/nf/nf_layer_submodule.f90
@@ -1,6 +1,7 @@
 submodule(nf_layer) nf_layer_submodule
 
   use iso_fortran_env, only: stderr => error_unit
+  use nf_batchnorm_layer, only: batchnorm_layer
   use nf_conv2d_layer, only: conv2d_layer
   use nf_dense_layer, only: dense_layer
   use nf_flatten_layer, only: flatten_layer
diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90
index 5bafb7c..ecff74d 100644
--- a/src/nf/nf_network_submodule.f90
+++ b/src/nf/nf_network_submodule.f90
@@ -10,7 +10,7 @@
   use nf_io_hdf5, only: get_hdf5_dataset
   use nf_keras, only: get_keras_h5_layers, keras_layer
   use nf_layer, only: layer
-  use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
+  use nf_layer_constructors, only: batchnorm, conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_loss, only: quadratic_derivative
   use nf_optimizers, only: optimizer_base_type, sgd
   use nf_parallel, only: tile_indices
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 26646ec..b4ee820 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -16,6 +16,7 @@ foreach(execid
   cnn_from_keras
   conv2d_network
   optimizers
+  batchnorm_layer
 )
   add_executable(test_${execid} test_${execid}.f90)
   target_link_libraries(test_${execid} PRIVATE neural h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS})
diff --git a/test/test_batchnorm_layer.f90 b/test/test_batchnorm_layer.f90
new file mode 100644
index 0000000..473b22d
--- /dev/null
+++ b/test/test_batchnorm_layer.f90
@@ -0,0 +1,65 @@
+program test_batchnorm_layer
+
+  use iso_fortran_env, only: stderr => error_unit
+  use nf, only: batchnorm, layer
+  use nf_batchnorm_layer, only: batchnorm_layer
+
+  implicit none
+
+  type(layer) :: bn_layer
+  integer, parameter :: num_features = 64
+  real, allocatable :: sample_input(:,:)
+  real, allocatable :: output(:,:)
+  real, allocatable :: gradient(:,:)
+  integer, parameter :: input_shape(1) = [num_features]
+  real, allocatable :: gamma_grad(:), beta_grad(:)
+  real, parameter :: tolerance = 1e-7
+  logical :: ok = .true.
+
+  bn_layer = batchnorm(num_features)
+
+  if (.not. bn_layer % name == 'batchnorm') then
+    ok = .false.
+    write(stderr, '(a)') 'batchnorm layer has its name set correctly.. failed'
+  end if
+
+  if (bn_layer % initialized) then
+    ok = .false.
+    write(stderr, '(a)') 'batchnorm layer should not be marked as initialized yet.. failed'
+  end if
+
+  ! Initialize sample input and gradient
+  allocate(sample_input(num_features, 1))
+  allocate(gradient(num_features, 1))
+  sample_input = 1.0
+  gradient = 2.0
+
+  !TODO Run forward and backward passes directly on the batchnorm_layer instance,
+  !TODO since tying it in with an input layer is not yet supported.
+
+  !TODO Retrieve output and check normalization
+  !call bn_layer % get_output(output)
+  !if (.not. all(abs(output - sample_input) < tolerance)) then
+  !  ok = .false.
+  !  write(stderr, '(a)') 'batchnorm layer output should be close to input.. failed'
+  !end if
+
+  !TODO Retrieve gamma and beta gradients
+  !allocate(gamma_grad(num_features))
+  !allocate(beta_grad(num_features))
+  !call bn_layer % get_gradients(gamma_grad, beta_grad)
+
+  !if (.not. all(beta_grad == sum(gradient))) then
+  !  ok = .false.
+  !  write(stderr, '(a)') 'batchnorm layer beta gradients are incorrect.. failed'
+  !end if
+
+  ! Report test results
+  if (ok) then
+    print '(a)', 'test_batchnorm_layer: All tests passed.'
+  else
+    write(stderr, '(a)') 'test_batchnorm_layer: One or more tests failed.'
+    stop 1
+  end if
+
+end program test_batchnorm_layer
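To complement the TODOs in the test above, here is a minimal sketch of how the forward and backward passes could be exercised directly on the concrete `batchnorm_layer` type, bypassing the generic `layer` wrapper. The program name and expected values are illustrative only and assume the default statistics set by `init` (running mean 0, running variance 1, gamma 1, beta 0); it uses only procedures and components defined in the patch.

```fortran
! Hypothetical driver, not part of the patch: exercises the concrete
! batchnorm_layer type directly, as the TODOs in the test suggest.
program batchnorm_layer_direct
  use nf_batchnorm_layer, only: batchnorm_layer
  implicit none

  integer, parameter :: num_features = 4
  integer, parameter :: batch_size = 2
  type(batchnorm_layer) :: bn
  real :: sample_input(num_features, batch_size)
  real :: gradient(num_features, batch_size)

  ! Construct and initialize: gamma = 1, beta = 0,
  ! running_mean = 0, running_var = 1
  bn = batchnorm_layer(num_features)
  call bn % init([num_features])

  sample_input = 1.0
  gradient = 2.0

  ! With the default running statistics the output equals the input,
  ! up to the epsilon term in the denominator.
  call bn % forward(sample_input)
  print *, maxval(abs(bn % output - sample_input))  ! small, about 0.5 * epsilon

  ! beta gradients are the upstream gradient summed over the batch dimension.
  call bn % backward(sample_input, gradient)
  print *, bn % beta_grad          ! 2.0 * batch_size for every feature
  print *, bn % get_num_params()   ! 2 * num_features (gamma and beta)

end program batchnorm_layer_direct
```

Once the batchnorm case is wired into the `layer` type's forward and backward dispatch, the commented-out checks in `test_batchnorm_layer.f90` could assert the same values through the `bn_layer` wrapper instead.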