Skip to content

Commit

Permalink
Call cleanup if assertion fails in constructor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 committed Jan 14, 2025
1 parent 0d18ee8 commit 325768b
Showing 1 changed file with 84 additions and 30 deletions.
114 changes: 84 additions & 30 deletions src/test/unit/test_constructors.pf
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,21 @@ subroutine test_torch_tensor_zeros()
! Check that the tensor values are all zero
expected(:,:) = 0.0
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_zeros")
@assertTrue(test_pass)
@assertEqual(shape(out_data), shape(expected))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from torch_tensor_zeros subroutine"
stop 999
end if

! Cleanup
nullify(out_data)
call torch_tensor_delete(tensor)
call clean_up()

contains

! Subroutine for freeing memory and nullifying pointers used in the unit test
subroutine clean_up()
nullify(out_data)
call torch_tensor_delete(tensor)
end subroutine clean_up

end subroutine test_torch_tensor_zeros

Expand Down Expand Up @@ -132,12 +141,21 @@ subroutine test_torch_tensor_ones()
! Check that the tensor values are all one
expected(:,:) = 1.0
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_ones")
@assertTrue(test_pass)
@assertEqual(shape(out_data), shape(expected))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from torch_tensor_ones subroutine"
stop 999
end if

! Cleanup
nullify(out_data)
call torch_tensor_delete(tensor)
call clean_up()

contains

! Subroutine for freeing memory and nullifying pointers used in the unit test
subroutine clean_up()
nullify(out_data)
call torch_tensor_delete(tensor)
end subroutine clean_up

end subroutine test_torch_tensor_ones

Expand Down Expand Up @@ -184,12 +202,21 @@ subroutine test_torch_from_array_1d()
! Compare the data in the tensor to the input data
expected(:) = in_data
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_array")
@assertTrue(test_pass)
@assertEqual(shape(out_data), shape(expected))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from torch_tensor_from_array subroutine"
stop 999
end if

! Cleanup
nullify(out_data)
call torch_tensor_delete(tensor)
call clean_up()

contains

! Subroutine for freeing memory and nullifying pointers used in the unit test
subroutine clean_up()
nullify(out_data)
call torch_tensor_delete(tensor)
end subroutine clean_up

end subroutine test_torch_from_array_1d

Expand Down Expand Up @@ -236,12 +263,21 @@ subroutine test_torch_from_array_2d()
! Compare the data in the tensor to the input data
expected(:,:) = in_data
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_array")
@assertTrue(test_pass)
@assertEqual(shape(out_data), shape(expected))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from torch_tensor_from_array subroutine"
stop 999
end if

! Cleanup
nullify(out_data)
call torch_tensor_delete(tensor)
call clean_up()

contains

! Subroutine for freeing memory and nullifying pointers used in the unit test
subroutine clean_up()
nullify(out_data)
call torch_tensor_delete(tensor)
end subroutine clean_up

end subroutine test_torch_from_array_2d

Expand Down Expand Up @@ -286,12 +322,21 @@ subroutine test_torch_from_array_3d()
! Compare the data in the tensor to the input data
expected(:,:,:) = in_data
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_array")
@assertTrue(test_pass)
@assertEqual(shape(out_data), shape(expected))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from torch_tensor_from_array subroutine"
stop 999
end if

! Cleanup
nullify(out_data)
call torch_tensor_delete(tensor)
call clean_up()

contains

! Subroutine for freeing memory and nullifying pointers used in the unit test
subroutine clean_up()
nullify(out_data)
call torch_tensor_delete(tensor)
end subroutine clean_up

end subroutine test_torch_from_array_3d

Expand Down Expand Up @@ -340,11 +385,20 @@ subroutine test_torch_from_blob()
! Compare the data in the tensor to the input data
expected(:,:) = in_data
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_blob")
@assertTrue(test_pass)
@assertEqual(shape(out_data), shape(expected))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from torch_tensor_from_array subroutine"
stop 999
end if

! Cleanup
nullify(out_data)
call torch_tensor_delete(tensor)
call clean_up()

contains

! Subroutine for freeing memory and nullifying pointers used in the unit test
subroutine clean_up()
nullify(out_data)
call torch_tensor_delete(tensor)
end subroutine clean_up

end subroutine test_torch_from_blob

0 comments on commit 325768b

Please sign in to comment.