forked from modern-fortran/neural-fortran
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnf_loss.f90
101 lines (87 loc) · 3.15 KB
/
nf_loss.f90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
module nf_loss

  !! This module provides a collection of loss functions and their derivatives.
  !! The implementation is based on an abstract loss derived type
  !! which has the required eval and derivative methods.
  !! An implementation of a new loss type thus requires writing a concrete
  !! loss type that extends the abstract loss derived type, and that
  !! implements concrete eval and derivative methods that accept vectors.
  !!
  !! Both methods are bound with `nopass` and take their arguments in the
  !! order `(true, predicted)`: target values first, network output second.
  !! The procedure bodies are separate module procedures; only their
  !! interfaces are declared here.

  use nf_metrics, only: metric_type
  implicit none

  private
  public :: loss_type
  public :: mse
  public :: quadratic

  type, extends(metric_type), abstract :: loss_type
    !! Abstract base type for all loss functions.
    !!
    !! Extends `metric_type` and adds a deferred `derivative` binding.
    !! The concrete types below also bind `eval`; that binding is
    !! presumably declared (deferred) in `metric_type` -- confirm in nf_metrics.
  contains
    procedure(loss_derivative_interface), nopass, deferred :: derivative
  end type loss_type

  abstract interface

    pure function loss_derivative_interface(true, predicted) result(res)
      !! Signature every concrete loss derivative must conform to:
      !! the elementwise derivative of the loss with respect to `predicted`.
      real, intent(in) :: true(:)
        !! True values, i.e. labels from training datasets
      real, intent(in) :: predicted(:)
        !! Values predicted by the network
      real :: res(size(true))
        !! Resulting derivative values, one per element.
        !! NOTE(review): the result is sized from `true`; `predicted` is
        !! presumably expected to have the same size -- nothing here checks it.
    end function loss_derivative_interface

  end interface

  type, extends(loss_type) :: mse
    !! Mean Square Error loss function
  contains
    procedure, nopass :: eval => mse_eval
    procedure, nopass :: derivative => mse_derivative
  end type mse

  type, extends(loss_type) :: quadratic
    !! Quadratic loss function
  contains
    procedure, nopass :: eval => quadratic_eval
    procedure, nopass :: derivative => quadratic_derivative
  end type quadratic

  interface

    pure module function mse_eval(true, predicted) result(res)
      !! Mean Square Error loss function:
      !!
      !!   L = sum((predicted - true)**2) / size(true)
      !!
      real, intent(in) :: true(:)
        !! True values, i.e. labels from training datasets
      real, intent(in) :: predicted(:)
        !! Values predicted by the network
      real :: res
        !! Resulting loss value
    end function mse_eval

    pure module function mse_derivative(true, predicted) result(res)
      !! First derivative of the Mean Square Error loss function
      !! with respect to `predicted`:
      !!
      !!   L' = 2 * (predicted - true) / size(true)
      !!
      real, intent(in) :: true(:)
        !! True values, i.e. labels from training datasets
      real, intent(in) :: predicted(:)
        !! Values predicted by the network
      real :: res(size(true))
        !! Resulting loss derivative values, one per element
    end function mse_derivative

    pure module function quadratic_eval(true, predicted) result(res)
      !! Quadratic loss function:
      !!
      !!   L = sum((predicted - true)**2) / 2
      !!
      real, intent(in) :: true(:)
        !! True values, i.e. labels from training datasets
      real, intent(in) :: predicted(:)
        !! Values predicted by the network
      real :: res
        !! Resulting loss value
    end function quadratic_eval

    pure module function quadratic_derivative(true, predicted) result(res)
      !! First derivative of the quadratic loss function
      !! with respect to `predicted`:
      !!
      !!   L' = predicted - true
      !!
      real, intent(in) :: true(:)
        !! True values, i.e. labels from training datasets
      real, intent(in) :: predicted(:)
        !! Values predicted by the network
      real :: res(size(true))
        !! Resulting loss derivative values, one per element
    end function quadratic_derivative

  end interface

end module nf_loss