Skip to content

Commit

Permalink
Call cleanup if assertion fails in operator overload tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 committed Jan 14, 2025
1 parent 49b0717 commit 0d18ee8
Showing 1 changed file with 55 additions and 22 deletions.
77 changes: 55 additions & 22 deletions src/test/unit/test_tensor_operator_overloads.pf
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@ subroutine test_torch_tensor_assign()
! array
call torch_tensor_to_array(tensor2, out_data, shape(in_data))
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_assign")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded assignment operator"
stop 999
end if

call clean_up()

Expand Down Expand Up @@ -130,8 +133,11 @@ subroutine test_torch_tensor_add()
call torch_tensor_to_array(tensor3, out_data, shape(in_data1))
expected(:,:) = in_data1 + in_data2
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_add")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded addition operator"
stop 999
end if

call clean_up()

Expand Down Expand Up @@ -207,8 +213,11 @@ subroutine test_torch_tensor_subtract()
call torch_tensor_to_array(tensor3, out_data, shape(in_data1))
expected(:,:) = in_data1 - in_data2
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_subtract")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded subtraction operator"
stop 999
end if

call clean_up()

Expand Down Expand Up @@ -284,8 +293,11 @@ subroutine test_torch_tensor_multiply()
call torch_tensor_to_array(tensor3, out_data, shape(in_data1))
expected(:,:) = in_data1 * in_data2
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_multiply")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded multiplication operator"
stop 999
end if

call clean_up()

Expand Down Expand Up @@ -368,11 +380,17 @@ subroutine test_torch_tensor_scalar_multiply()
call torch_tensor_to_array(tensor3, out_data3, shape(in_data))
expected(:,:) = scalar * in_data
test_pass = assert_allclose(out_data2, expected, test_name="test_torch_tensor_premultiply")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data2))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded post-multiplication operator"
stop 999
end if
test_pass = assert_allclose(out_data3, expected, test_name="test_torch_tensor_postmultiply")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data3))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded pre-multiplication operator"
stop 999
end if

call clean_up()

Expand Down Expand Up @@ -449,8 +467,11 @@ subroutine test_torch_tensor_divide()
call torch_tensor_to_array(tensor3, out_data, shape(in_data1))
expected(:,:) = in_data1 / in_data2
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_divide")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded division operator"
stop 999
end if

call clean_up()

Expand Down Expand Up @@ -519,8 +540,11 @@ subroutine test_torch_tensor_scalar_divide()
call torch_tensor_to_array(tensor2, out_data, shape(in_data))
expected(:,:) = in_data / scalar
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_postdivide")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded scalar division operator"
stop 999
end if

call clean_up()

Expand Down Expand Up @@ -601,11 +625,17 @@ subroutine test_torch_tensor_square()
call torch_tensor_to_array(tensor3, out_data3, shape(in_data))
expected(:,:) = in_data ** 2
test_pass = assert_allclose(out_data2, expected, test_name="test_torch_tensor_square_int")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data2))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded integer exponentation operator"
stop 999
end if
test_pass = assert_allclose(out_data3, expected, test_name="test_torch_tensor_square_float")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data3))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded floating point exponentation operator"
stop 999
end if

call clean_up()

Expand Down Expand Up @@ -673,8 +703,11 @@ subroutine test_torch_tensor_sqrt()
call torch_tensor_to_array(tensor2, out_data, shape(in_data))
expected(:,:) = in_data ** 0.5
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_sqrt")
@assertTrue(test_pass)
@assertEqual(shape(expected), shape(out_data))
if (.not. test_pass) then
call clean_up()
print *, "Error :: incorrect output from overloaded square root operator"
stop 999
end if

call clean_up()

Expand Down

0 comments on commit 0d18ee8

Please sign in to comment.