Skip to content

Commit

Permalink
Add simple tests for metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
milancurcic committed Jun 13, 2024
1 parent 0563395 commit f506901
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ foreach(execid
conv2d_network
optimizers
loss
metrics
)
add_executable(test_${execid} test_${execid}.f90)
target_link_libraries(test_${execid} PRIVATE neural-fortran h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS})
Expand Down
70 changes: 70 additions & 0 deletions test/test_metrics.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
program test_metrics
use iso_fortran_env, only: stderr => error_unit
use nf, only: dense, input, network, sgd, mse
implicit none
type(network) :: net
logical :: ok = .true.

! Minimal 2-layer network
net = network([ &
input(1), &
dense(1) &
])

training: block
real :: x(1), y(1)
real :: tolerance = 1e-3
integer :: n
integer, parameter :: num_iterations = 1000
real :: quadratic_loss, mse_metric
real, allocatable :: metrics(:,:)

x = [0.1234567]
y = [0.7654321]

do n = 1, num_iterations
call net % forward(x)
call net % backward(y)
call net % update(sgd(learning_rate=1.))
if (all(abs(net % predict(x) - y) < tolerance)) exit
end do

! Returns only one metric, based on the default loss function (quadratic).
metrics = net % evaluate(reshape(x, [1, 1]), reshape(y, [1, 1]))
quadratic_loss = metrics(1,1)

if (.not. all(shape(metrics) == [1, 1])) then
write(stderr, '(a)') 'metrics array is the correct shape (1, 1).. failed'
ok = .false.
end if

! Returns two metrics, one from the loss function and another specified by the user.
metrics = net % evaluate(reshape(x, [1, 1]), reshape(y, [1, 1]), metric=mse())

if (.not. all(shape(metrics) == [1, 2])) then
write(stderr, '(a)') 'metrics array is the correct shape (1, 2).. failed'
ok = .false.
end if

mse_metric = metrics(1,2)

if (.not. all(metrics < 1e-5)) then
write(stderr, '(a)') 'value for all metrics is expected.. failed'
ok = .false.
end if

if (.not. metrics(1,1) == quadratic_loss) then
write(stderr, '(a)') 'first metric should be the same as that of the loss function.. failed'
ok = .false.
end if

end block training

if (ok) then
print '(a)', 'test_metrics: All tests passed.'
else
write(stderr, '(a)') 'test_metrics: One or more tests failed.'
stop 1
end if

end program test_metrics

0 comments on commit f506901

Please sign in to comment.