Addition of the Loss derived type and of the MSE loss function (#175)
* Add the abstract derived type loss_type and the concrete derived type quadratic

* Support the loss_type derivative method in the backward pass

* Add the MSE loss function

* Add documentation

* Add a test program placeholder

* Add the loss test to the CMake config

* Add a minimal test for expected values

* Bump version and copyright years

---------

Co-authored-by: Vandenplas, Jeremie <[email protected]>
Co-authored-by: milancurcic <[email protected]>
3 people authored Apr 19, 2024
1 parent cf47114 commit f7b6006
Showing 9 changed files with 185 additions and 21 deletions.
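
The user-facing change: train (and backward) now accept an optional loss argument, with quadratic() remaining the default. Below is a minimal training sketch against the post-commit API; the layer sizes, toy data, and learning rate are invented for illustration, and only the train signature comes from the diff of src/nf/nf_network.f90 further down.

program example_mse_training
  ! Sketch only: toy shapes and data, invented for illustration.
  use nf, only: dense, input, network, sgd, mse
  implicit none
  type(network) :: net
  real :: x(3, 100), y(1, 100)

  call random_number(x)
  y(1,:) = sum(x, dim=1)  ! arbitrary regression target

  net = network([input(3), dense(8), dense(1)])

  ! The loss argument is optional; omitting it keeps the old behavior,
  ! since quadratic() is the default.
  call net % train(x, y, batch_size=10, epochs=5, &
                   optimizer=sgd(learning_rate=0.1), loss=mse())

end program example_mse_training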
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2018-2023 neural-fortran contributors
+Copyright (c) 2018-2024 neural-fortran contributors
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
4 changes: 2 additions & 2 deletions fpm.toml
@@ -1,9 +1,9 @@
 name = "neural-fortran"
-version = "0.15.1"
+version = "0.16.0"
 license = "MIT"
 author = "Milan Curcic"
 maintainer = "[email protected]"
-copyright = "Copyright 2018-2023, neural-fortran contributors"
+copyright = "Copyright 2018-2024, neural-fortran contributors"
 
 [build]
 external-modules = "hdf5"
1 change: 1 addition & 0 deletions src/nf.f90
@@ -4,6 +4,7 @@ module nf
   use nf_layer, only: layer
   use nf_layer_constructors, only: &
     conv2d, dense, flatten, input, maxpool2d, reshape
+  use nf_loss, only: mse, quadratic
   use nf_network, only: network
   use nf_optimizers, only: sgd, rmsprop, adam, adagrad
   use nf_activation, only: activation_function, elu, exponential, &
80 changes: 72 additions & 8 deletions src/nf/nf_loss.f90
@@ -1,28 +1,92 @@
 module nf_loss
 
-  !! This module will eventually provide a collection of loss functions and
-  !! their derivatives. For the time being it provides only the quadratic
-  !! function.
+  !! This module provides a collection of loss functions and their derivatives.
+  !! The implementation is based on an abstract loss derived type
+  !! which has the required eval and derivative methods.
+  !! An implementation of a new loss type thus requires writing a concrete
+  !! loss type that extends the abstract loss derived type, and that
+  !! implements concrete eval and derivative methods that accept vectors.
 
   implicit none
 
   private
-  public :: quadratic, quadratic_derivative
+  public :: loss_type
+  public :: mse
+  public :: quadratic
 
+  type, abstract :: loss_type
+  contains
+    procedure(loss_interface), nopass, deferred :: eval
+    procedure(loss_derivative_interface), nopass, deferred :: derivative
+  end type loss_type
+
+  abstract interface
+    pure function loss_interface(true, predicted) result(res)
+      real, intent(in) :: true(:)
+      real, intent(in) :: predicted(:)
+      real :: res
+    end function loss_interface
+    pure function loss_derivative_interface(true, predicted) result(res)
+      real, intent(in) :: true(:)
+      real, intent(in) :: predicted(:)
+      real :: res(size(true))
+    end function loss_derivative_interface
+  end interface
+
+  type, extends(loss_type) :: mse
+    !! Mean Square Error loss function
+  contains
+    procedure, nopass :: eval => mse_eval
+    procedure, nopass :: derivative => mse_derivative
+  end type mse
+
+  type, extends(loss_type) :: quadratic
+    !! Quadratic loss function
+  contains
+    procedure, nopass :: eval => quadratic_eval
+    procedure, nopass :: derivative => quadratic_derivative
+  end type quadratic
+
   interface
 
-    pure module function quadratic(true, predicted) result(res)
-      !! Quadratic loss function:
+    pure module function mse_eval(true, predicted) result(res)
+      !! Mean Square Error loss function:
       !!
+      !! L = sum((predicted - true)**2) / size(true)
+      !!
+      real, intent(in) :: true(:)
+        !! True values, i.e. labels from training datasets
+      real, intent(in) :: predicted(:)
+        !! Values predicted by the network
+      real :: res
+        !! Resulting loss value
+    end function mse_eval
+
+    pure module function mse_derivative(true, predicted) result(res)
+      !! First derivative of the Mean Square Error loss function:
+      !!
-      !! L = (predicted - true)**2 / 2
+      !! L = 2 * (predicted - true) / size(true)
       !!
       real, intent(in) :: true(:)
         !! True values, i.e. labels from training datasets
       real, intent(in) :: predicted(:)
         !! Values predicted by the network
       real :: res(size(true))
         !! Resulting loss values
-    end function quadratic
+    end function mse_derivative
 
+    pure module function quadratic_eval(true, predicted) result(res)
+      !! Quadratic loss function:
+      !!
+      !! L = sum((predicted - true)**2) / 2
+      !!
+      real, intent(in) :: true(:)
+        !! True values, i.e. labels from training datasets
+      real, intent(in) :: predicted(:)
+        !! Values predicted by the network
+      real :: res
+        !! Resulting loss value
+    end function quadratic_eval
+
     pure module function quadratic_derivative(true, predicted) result(res)
       !! First derivative of the quadratic loss function:
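
The module docstring above spells out the extension recipe: extend loss_type and implement eval and derivative. As an illustration only, a hypothetical mean absolute error loss following that recipe might look like this (no such type exists in this commit):

module custom_loss
  ! Hypothetical example, not part of this commit: a custom MAE loss
  ! written against the new abstract loss_type.
  use nf_loss, only: loss_type
  implicit none

  type, extends(loss_type) :: mae
  contains
    procedure, nopass :: eval => mae_eval
    procedure, nopass :: derivative => mae_derivative
  end type mae

contains

  pure function mae_eval(true, predicted) result(res)
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res
    res = sum(abs(predicted - true)) / size(true)
  end function mae_eval

  pure function mae_derivative(true, predicted) result(res)
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res(size(true))
    ! Subgradient of the mean absolute error w.r.t. predicted
    res = sign(1., predicted - true) / size(true)
  end function mae_derivative

end module custom_loss

Because backward and train accept any class(loss_type), an instance of such a type could then be passed as loss=mae() without touching the network internals.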
22 changes: 18 additions & 4 deletions src/nf/nf_loss_submodule.f90
@@ -4,12 +4,12 @@
 
 contains
 
-  pure module function quadratic(true, predicted) result(res)
+  pure module function quadratic_eval(true, predicted) result(res)
     real, intent(in) :: true(:)
     real, intent(in) :: predicted(:)
-    real :: res(size(true))
-    res = (predicted - true)**2 / 2
-  end function quadratic
+    real :: res
+    res = sum((predicted - true)**2) / 2
+  end function quadratic_eval
 
   pure module function quadratic_derivative(true, predicted) result(res)
     real, intent(in) :: true(:)
@@ -18,4 +18,18 @@ pure module function quadratic_derivative(true, predicted) result(res)
     res = predicted - true
   end function quadratic_derivative
 
+  pure module function mse_eval(true, predicted) result(res)
+    real, intent(in) :: true(:)
+    real, intent(in) :: predicted(:)
+    real :: res
+    res = sum((predicted - true)**2) / size(true)
+  end function mse_eval
+
+  pure module function mse_derivative(true, predicted) result(res)
+    real, intent(in) :: true(:)
+    real, intent(in) :: predicted(:)
+    real :: res(size(true))
+    res = 2 * (predicted - true) / size(true)
+  end function mse_derivative
+
 end submodule nf_loss_submodule
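
One consequence of the two implementations above worth noting: quadratic halves the sum of squared errors, while mse averages it, so the two values diverge as the vector grows. A small illustrative check (names invented; the values follow directly from the formulas):

program compare_losses
  ! Illustration only: the two losses scale differently with vector size.
  use nf, only: mse, quadratic
  implicit none
  type(mse) :: l_mse
  type(quadratic) :: l_quad
  real :: t(4) = [0., 0., 0., 0.]
  real :: p(4) = [1., 1., 1., 1.]
  print *, l_quad % eval(t, p)  ! sum of squares / 2       = 4 / 2 = 2.0
  print *, l_mse % eval(t, p)   ! sum of squares / size(t) = 4 / 4 = 1.0
end program compare_losses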
10 changes: 8 additions & 2 deletions src/nf/nf_network.f90
@@ -3,6 +3,7 @@ module nf_network
   !! This module provides the network type to create new models.
 
   use nf_layer, only: layer
+  use nf_loss, only: loss_type
   use nf_optimizers, only: optimizer_base_type
 
   implicit none
@@ -13,6 +14,7 @@ module nf_network
   type :: network
 
     type(layer), allocatable :: layers(:)
+    class(loss_type), allocatable :: loss
    class(optimizer_base_type), allocatable :: optimizer
 
   contains
@@ -138,7 +140,7 @@ end function predict_batch_3d
 
   interface
 
-    pure module subroutine backward(self, output)
+    pure module subroutine backward(self, output, loss)
       !! Apply one backward pass through the network.
       !! This changes the state of layers on the network.
       !! Typically used only internally from the `train` method,
@@ -147,6 +149,8 @@ pure module subroutine backward(self, output)
       class(network), intent(in out) :: self
         !! Network instance
       real, intent(in) :: output(:)
        !! Output data
+      class(loss_type), intent(in), optional :: loss
+        !! Loss instance to use. If not provided, the default is quadratic().
     end subroutine backward
 
     pure module integer function get_num_params(self)
@@ -185,7 +189,7 @@ module subroutine print_info(self)
     end subroutine print_info
 
     module subroutine train(self, input_data, output_data, batch_size, &
-      epochs, optimizer)
+      epochs, optimizer, loss)
       class(network), intent(in out) :: self
         !! Network instance
       real, intent(in) :: input_data(:,:)
@@ -204,6 +208,8 @@ module subroutine train(self, input_data, output_data, batch_size, &
        !! Number of epochs to run
       class(optimizer_base_type), intent(in), optional :: optimizer
         !! Optimizer instance to use. If not provided, the default is sgd().
+      class(loss_type), intent(in), optional :: loss
+        !! Loss instance to use. If not provided, the default is quadratic().
     end subroutine train
 
     module subroutine update(self, optimizer, batch_size)
33 changes: 29 additions & 4 deletions src/nf/nf_network_submodule.f90
@@ -11,7 +11,7 @@
   use nf_keras, only: get_keras_h5_layers, keras_layer
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
-  use nf_loss, only: quadratic_derivative
+  use nf_loss, only: quadratic
   use nf_optimizers, only: optimizer_base_type, sgd
   use nf_parallel, only: tile_indices
   use nf_activation, only: activation_function, &
@@ -280,11 +280,27 @@ pure function get_activation_by_name(activation_name) result(res)
 
   end function get_activation_by_name
 
-  pure module subroutine backward(self, output)
+  pure module subroutine backward(self, output, loss)
     class(network), intent(in out) :: self
     real, intent(in) :: output(:)
+    class(loss_type), intent(in), optional :: loss
     integer :: n, num_layers
 
+    ! Passing the loss instance is optional. If not provided, and if the
+    ! loss instance has not already been set, we default to quadratic().
+    ! The instantiation and initialization below of the loss instance is
+    ! normally done at the beginning of the network % train() method.
+    ! However, if the user wants to call network % backward() directly, for
+    ! example with their own custom mini-batching routine, we initialize the
+    ! loss instance here as well; if initialized already, this is a cheap no-op.
+    if (.not. allocated(self % loss)) then
+      if (present(loss)) then
+        self % loss = loss
+      else
+        self % loss = quadratic()
+      end if
+    end if
+
     num_layers = size(self % layers)
 
     ! Iterate backward over layers, from the output layer
@@ -297,7 +313,7 @@ pure module subroutine backward(self, output)
         type is(dense_layer)
           call self % layers(n) % backward( &
             self % layers(n - 1), &
-            quadratic_derivative(output, this_layer % output) &
+            self % loss % derivative(output, this_layer % output) &
           )
       end select
     else
@@ -542,13 +558,14 @@ end subroutine set_params
 
 
   module subroutine train(self, input_data, output_data, batch_size, &
-    epochs, optimizer)
+    epochs, optimizer, loss)
     class(network), intent(in out) :: self
     real, intent(in) :: input_data(:,:)
     real, intent(in) :: output_data(:,:)
     integer, intent(in) :: batch_size
     integer, intent(in) :: epochs
     class(optimizer_base_type), intent(in), optional :: optimizer
+    class(loss_type), intent(in), optional :: loss
     class(optimizer_base_type), allocatable :: optimizer_
 
     real :: pos
@@ -567,6 +584,14 @@ module subroutine train(self, input_data, output_data, batch_size, &
 
     call self % optimizer % init(self % get_num_params())
 
+    ! Passing the loss instance is optional.
+    ! If not provided, we default to quadratic().
+    if (present(loss)) then
+      self % loss = loss
+    else
+      self % loss = quadratic()
+    end if
+
     dataset_size = size(output_data, dim=2)
 
     epoch_loop: do n = 1, epochs
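
The comment added to backward above anticipates users who drive the passes themselves. A sketch of such a hand-rolled loop follows; the data, shapes, and hyperparameters are invented, and it assumes update() initializes the optimizer on first use in the same way backward() initializes the loss:

program manual_batches
  ! Sketch only: per-sample forward/backward with a manual update cadence,
  ! exercising the default-loss initialization path described above.
  use nf, only: dense, input, network, sgd, mse
  implicit none
  type(network) :: net
  integer, parameter :: batch_size = 10
  real :: x(3, 100), y(1, 100)
  integer :: i

  call random_number(x)
  y(1,:) = sum(x, dim=1)  ! arbitrary regression target
  net = network([input(3), dense(8), dense(1)])

  do i = 1, size(x, dim=2)
    call net % forward(x(:,i))
    ! Omitting loss here would fall back to quadratic() on first call.
    call net % backward(y(:,i), loss=mse())
    if (mod(i, batch_size) == 0) &
      call net % update(optimizer=sgd(learning_rate=0.1), batch_size=batch_size)
  end do

end program manual_batches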
1 change: 1 addition & 0 deletions test/CMakeLists.txt
@@ -16,6 +16,7 @@ foreach(execid
   cnn_from_keras
   conv2d_network
   optimizers
+  loss
   )
 add_executable(test_${execid} test_${execid}.f90)
 target_link_libraries(test_${execid} PRIVATE neural h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS})
53 changes: 53 additions & 0 deletions test/test_loss.f90
@@ -0,0 +1,53 @@
+program test_loss
+
+  use iso_fortran_env, only: stderr => error_unit
+  use nf, only: mse, quadratic
+
+  implicit none
+
+  logical :: ok = .true.
+
+  block
+
+    type(mse) :: loss
+    real :: true(2) = [1., 2.]
+    real :: pred(2) = [3., 4.]
+
+    if (.not. loss % eval(true, pred) == 4) then
+      write(stderr, '(a)') 'expected output of mse % eval().. failed'
+      ok = .false.
+    end if
+
+    if (.not. all(loss % derivative(true, pred) == [2, 2])) then
+      write(stderr, '(a)') 'expected output of mse % derivative().. failed'
+      ok = .false.
+    end if
+
+  end block
+
+  block
+
+    type(quadratic) :: loss
+    real :: true(4) = [1., 2., 3., 4.]
+    real :: pred(4) = [3., 4., 5., 6.]
+
+    if (.not. loss % eval(true, pred) == 8) then
+      write(stderr, '(a)') 'expected output of quadratic % eval().. failed'
+      ok = .false.
+    end if
+
+    if (.not. all(loss % derivative(true, pred) == [2, 2, 2, 2])) then
+      write(stderr, '(a)') 'expected output of quadratic % derivative().. failed'
+      ok = .false.
+    end if
+
+  end block
+
+  if (ok) then
+    print '(a)', 'test_loss: All tests passed.'
+  else
+    write(stderr, '(a)') 'test_loss: One or more tests failed.'
+    stop 1
+  end if
+
+end program test_loss
