diff --git a/src/test/unit/test_constructors.pf b/src/test/unit/test_constructors.pf index 2b74432f..7e6c160e 100644 --- a/src/test/unit/test_constructors.pf +++ b/src/test/unit/test_constructors.pf @@ -81,12 +81,21 @@ subroutine test_torch_tensor_zeros() ! Check that the tensor values are all zero expected(:,:) = 0.0 test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_zeros") - @assertTrue(test_pass) - @assertEqual(shape(out_data), shape(expected)) + if (.not. test_pass) then + call clean_up() + print *, "Error :: incorrect output from torch_tensor_zeros subroutine" + stop 999 + end if - ! Cleanup - nullify(out_data) - call torch_tensor_delete(tensor) + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(tensor) + end subroutine clean_up end subroutine test_torch_tensor_zeros @@ -132,12 +141,21 @@ subroutine test_torch_tensor_ones() ! Check that the tensor values are all one expected(:,:) = 1.0 test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_ones") - @assertTrue(test_pass) - @assertEqual(shape(out_data), shape(expected)) + if (.not. test_pass) then + call clean_up() + print *, "Error :: incorrect output from torch_tensor_ones subroutine" + stop 999 + end if - ! Cleanup - nullify(out_data) - call torch_tensor_delete(tensor) + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(tensor) + end subroutine clean_up end subroutine test_torch_tensor_ones @@ -184,12 +202,21 @@ subroutine test_torch_from_array_1d() ! Compare the data in the tensor to the input data expected(:) = in_data test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_array") - @assertTrue(test_pass) - @assertEqual(shape(out_data), shape(expected)) + if (.not. test_pass) then + call clean_up() + print *, "Error :: incorrect output from torch_tensor_from_array subroutine" + stop 999 + end if - ! Cleanup - nullify(out_data) - call torch_tensor_delete(tensor) + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(tensor) + end subroutine clean_up end subroutine test_torch_from_array_1d @@ -236,12 +263,21 @@ subroutine test_torch_from_array_2d() ! Compare the data in the tensor to the input data expected(:,:) = in_data test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_array") - @assertTrue(test_pass) - @assertEqual(shape(out_data), shape(expected)) + if (.not. test_pass) then + call clean_up() + print *, "Error :: incorrect output from torch_tensor_from_array subroutine" + stop 999 + end if - ! Cleanup - nullify(out_data) - call torch_tensor_delete(tensor) + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(tensor) + end subroutine clean_up end subroutine test_torch_from_array_2d @@ -286,12 +322,21 @@ subroutine test_torch_from_array_3d() ! Compare the data in the tensor to the input data expected(:,:,:) = in_data test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_array") - @assertTrue(test_pass) - @assertEqual(shape(out_data), shape(expected)) + if (.not. test_pass) then + call clean_up() + print *, "Error :: incorrect output from torch_tensor_from_array subroutine" + stop 999 + end if - ! Cleanup - nullify(out_data) - call torch_tensor_delete(tensor) + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(tensor) + end subroutine clean_up end subroutine test_torch_from_array_3d @@ -340,11 +385,20 @@ subroutine test_torch_from_blob() ! Compare the data in the tensor to the input data expected(:,:) = in_data test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_from_blob") - @assertTrue(test_pass) - @assertEqual(shape(out_data), shape(expected)) + if (.not. test_pass) then + call clean_up() + print *, "Error :: incorrect output from torch_tensor_from_array subroutine" + stop 999 + end if - ! Cleanup - nullify(out_data) - call torch_tensor_delete(tensor) + call clean_up() + + contains + + ! Subroutine for freeing memory and nullifying pointers used in the unit test + subroutine clean_up() + nullify(out_data) + call torch_tensor_delete(tensor) + end subroutine clean_up end subroutine test_torch_from_blob