From 2da4b7878ded32ca69fffffb915f0d93da7442a4 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Wed, 24 Jul 2024 21:12:22 +0900 Subject: [PATCH 1/6] replace assert with std::runtime_error --- common/src/KokkosFFT_Helpers.hpp | 10 ++++++++-- common/src/KokkosFFT_layouts.hpp | 16 ++++++++++++---- common/src/KokkosFFT_padding.hpp | 10 ++++++++-- common/src/KokkosFFT_utils.hpp | 5 ++++- 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/common/src/KokkosFFT_Helpers.hpp b/common/src/KokkosFFT_Helpers.hpp index f0a9ec25..76d40edf 100644 --- a/common/src/KokkosFFT_Helpers.hpp +++ b/common/src/KokkosFFT_Helpers.hpp @@ -23,8 +23,14 @@ auto get_shift(const ViewType& inout, axis_type _axes, int direction = 1) { // Assert if the elements are overlapped constexpr int rank = ViewType::rank(); - assert(!KokkosFFT::Impl::has_duplicate_values(axes)); - assert(!KokkosFFT::Impl::is_out_of_range_value_included(axes, rank)); + if (KokkosFFT::Impl::has_duplicate_values(axes)) { + throw std::runtime_error("get_shift: axes are overlapped."); + } + if (KokkosFFT::Impl::is_out_of_range_value_included(axes, rank)) { + throw std::runtime_error( + "get_shift: axes include out of range index." + "axes should be in the range of [-rank, rank-1]."); + } axis_type shift = {0}; for (int i = 0; i < static_cast(DIM); i++) { diff --git a/common/src/KokkosFFT_layouts.hpp b/common/src/KokkosFFT_layouts.hpp index a555c990..85441edc 100644 --- a/common/src/KokkosFFT_layouts.hpp +++ b/common/src/KokkosFFT_layouts.hpp @@ -67,8 +67,12 @@ auto get_extents(const InViewType& in, const OutViewType& out, if (is_real_v) { // Then R2C if (is_complex_v) { - assert(_out_extents.at(inner_most_axis) == - _in_extents.at(inner_most_axis) / 2 + 1); + if (_out_extents.at(inner_most_axis) != + _in_extents.at(inner_most_axis) / 2 + 1) { + throw std::runtime_error( + "For R2C, the output extent of transform should be input extent / " + "2 + 1"); + } } else { throw std::runtime_error( "If the input type is real, the output type should be complex"); @@ -78,8 +82,12 @@ auto get_extents(const InViewType& in, const OutViewType& out, if (is_real_v) { // Then C2R if (is_complex_v) { - assert(_in_extents.at(inner_most_axis) == - _out_extents.at(inner_most_axis) / 2 + 1); + if (_in_extents.at(inner_most_axis) != + _out_extents.at(inner_most_axis) / 2 + 1) { + throw std::runtime_error( + "For C2R, the input extent of transform should be output extent / " + "2 + 1"); + } } else { throw std::runtime_error( "If the output type is real, the input type should be complex"); diff --git a/common/src/KokkosFFT_padding.hpp b/common/src/KokkosFFT_padding.hpp index b1a72a27..b672d1c2 100644 --- a/common/src/KokkosFFT_padding.hpp +++ b/common/src/KokkosFFT_padding.hpp @@ -51,8 +51,14 @@ auto get_modified_shape(const InViewType in, const OutViewType /* out */, } // Assert if the elements are overlapped - assert(!KokkosFFT::Impl::has_duplicate_values(positive_axes)); - assert(!KokkosFFT::Impl::is_out_of_range_value_included(positive_axes, rank)); + if (KokkosFFT::Impl::has_duplicate_values(positive_axes)) { + throw std::runtime_error("get_modified_shape: axes are overlapped."); + } + if (KokkosFFT::Impl::is_out_of_range_value_included(positive_axes, rank)) { + throw std::runtime_error( + "get_modified_shape: axes include out of range index." + "axes should be in the range of [-rank, rank-1]."); + } using full_shape_type = shape_type; full_shape_type modified_shape; diff --git a/common/src/KokkosFFT_utils.hpp b/common/src/KokkosFFT_utils.hpp index dbc51bfd..b7ff0c67 100644 --- a/common/src/KokkosFFT_utils.hpp +++ b/common/src/KokkosFFT_utils.hpp @@ -26,7 +26,10 @@ auto convert_negative_axis(ViewType, int _axis = -1) { static_assert(Kokkos::is_view::value, "convert_negative_axis: ViewType is not a Kokkos::View."); int rank = static_cast(ViewType::rank()); - assert(_axis >= -rank && _axis < rank); // axis should be in [-rank, rank-1] + if (_axis < -rank || _axis >= rank) { + throw std::runtime_error("axis should be in [-rank, rank-1]"); + } + int axis = _axis < 0 ? rank + _axis : _axis; return axis; } From 0af65e3dcecbbb5eca9d3bd027524f06d3b02765 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Wed, 24 Jul 2024 21:13:27 +0900 Subject: [PATCH 2/6] Add corresponding throw tests --- common/unit_test/Test_Layouts.cpp | 158 ++++++++++++++++++++++++++++++ common/unit_test/Test_Utils.cpp | 44 +++++++++ 2 files changed, 202 insertions(+) diff --git a/common/unit_test/Test_Layouts.cpp b/common/unit_test/Test_Layouts.cpp index e8a3d4ca..ebf4c220 100644 --- a/common/unit_test/Test_Layouts.cpp +++ b/common/unit_test/Test_Layouts.cpp @@ -52,6 +52,10 @@ void test_layouts_1d() { EXPECT_TRUE(fft_extents_r2c == ref_fft_extents_r2c); EXPECT_EQ(howmany_r2c, 1); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xr, xcout, axes_type({0})); }, + std::runtime_error); + // C2R std::vector ref_in_extents_c2r(1), ref_out_extents_c2r(1), ref_fft_extents_c2r(1); @@ -66,6 +70,10 @@ void test_layouts_1d() { EXPECT_TRUE(fft_extents_c2r == ref_fft_extents_c2r); EXPECT_EQ(howmany_c2r, 1); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xcin, xr, axes_type({0})); }, + std::runtime_error); + // C2C std::vector ref_in_extents_c2c(1), ref_out_extents_c2c(1), ref_fft_extents_c2c(1); @@ -111,6 +119,10 @@ void test_layouts_1d_batched_FFT_2d() { EXPECT_TRUE(out_extents_r2c_axis0 == ref_out_extents_r2c_axis0); EXPECT_EQ(howmany_r2c_axis0, ref_howmany_r2c_axis0); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xr2, xcout2, axes_type({0})); }, + std::runtime_error); + auto [in_extents_r2c_axis1, out_extents_r2c_axis1, fft_extents_r2c_axis1, howmany_r2c_axis1] = KokkosFFT::Impl::get_extents(xr2, xc2_axis1, axes_type({1})); @@ -119,6 +131,10 @@ void test_layouts_1d_batched_FFT_2d() { EXPECT_TRUE(out_extents_r2c_axis1 == ref_out_extents_r2c_axis1); EXPECT_EQ(howmany_r2c_axis1, ref_howmany_r2c_axis1); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xr2, xcout2, axes_type({1})); }, + std::runtime_error); + // C2R auto [in_extents_c2r_axis0, out_extents_c2r_axis0, fft_extents_c2r_axis0, howmany_c2r_axis0] = @@ -128,6 +144,10 @@ void test_layouts_1d_batched_FFT_2d() { EXPECT_TRUE(out_extents_c2r_axis0 == ref_in_extents_r2c_axis0); EXPECT_EQ(howmany_c2r_axis0, ref_howmany_r2c_axis0); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xcin2, xr2, axes_type({0})); }, + std::runtime_error); + auto [in_extents_c2r_axis1, out_extents_c2r_axis1, fft_extents_c2r_axis1, howmany_c2r_axis1] = KokkosFFT::Impl::get_extents(xc2_axis1, xr2, axes_type({1})); @@ -136,6 +156,10 @@ void test_layouts_1d_batched_FFT_2d() { EXPECT_TRUE(out_extents_c2r_axis1 == ref_in_extents_r2c_axis1); EXPECT_EQ(howmany_c2r_axis1, ref_howmany_r2c_axis1); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xcin2, xr2, axes_type({1})); }, + std::runtime_error); + // C2C auto [in_extents_c2c_axis0, out_extents_c2c_axis0, fft_extents_c2c_axis0, howmany_c2c_axis0] = @@ -193,6 +217,10 @@ void test_layouts_1d_batched_FFT_3d() { EXPECT_TRUE(out_extents_r2c_axis0 == ref_out_extents_r2c_axis0); EXPECT_EQ(howmany_r2c_axis0, ref_howmany_r2c_axis0); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xr3, xcout3, axes_type({0})); }, + std::runtime_error); + auto [in_extents_r2c_axis1, out_extents_r2c_axis1, fft_extents_r2c_axis1, howmany_r2c_axis1] = KokkosFFT::Impl::get_extents(xr3, xc3_axis1, axes_type({1})); @@ -201,6 +229,10 @@ void test_layouts_1d_batched_FFT_3d() { EXPECT_TRUE(out_extents_r2c_axis1 == ref_out_extents_r2c_axis1); EXPECT_EQ(howmany_r2c_axis1, ref_howmany_r2c_axis1); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xr3, xcout3, axes_type({1})); }, + std::runtime_error); + auto [in_extents_r2c_axis2, out_extents_r2c_axis2, fft_extents_r2c_axis2, howmany_r2c_axis2] = KokkosFFT::Impl::get_extents(xr3, xc3_axis2, axes_type({2})); @@ -209,6 +241,10 @@ void test_layouts_1d_batched_FFT_3d() { EXPECT_TRUE(out_extents_r2c_axis2 == ref_out_extents_r2c_axis2); EXPECT_EQ(howmany_r2c_axis2, ref_howmany_r2c_axis2); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xr3, xcout3, axes_type({2})); }, + std::runtime_error); + // C2R auto [in_extents_c2r_axis0, out_extents_c2r_axis0, fft_extents_c2r_axis0, howmany_c2r_axis0] = @@ -218,6 +254,10 @@ void test_layouts_1d_batched_FFT_3d() { EXPECT_TRUE(out_extents_c2r_axis0 == ref_in_extents_r2c_axis0); EXPECT_EQ(howmany_c2r_axis0, ref_howmany_r2c_axis0); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xcin3, xr3, axes_type({0})); }, + std::runtime_error); + auto [in_extents_c2r_axis1, out_extents_c2r_axis1, fft_extents_c2r_axis1, howmany_c2r_axis1] = KokkosFFT::Impl::get_extents(xc3_axis1, xr3, axes_type({1})); @@ -226,6 +266,10 @@ void test_layouts_1d_batched_FFT_3d() { EXPECT_TRUE(out_extents_c2r_axis1 == ref_in_extents_r2c_axis1); EXPECT_EQ(howmany_c2r_axis1, ref_howmany_r2c_axis1); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xcin3, xr3, axes_type({1})); }, + std::runtime_error); + auto [in_extents_c2r_axis2, out_extents_c2r_axis2, fft_extents_c2r_axis2, howmany_c2r_axis2] = KokkosFFT::Impl::get_extents(xc3_axis2, xr3, axes_type({2})); @@ -234,6 +278,10 @@ void test_layouts_1d_batched_FFT_3d() { EXPECT_TRUE(out_extents_c2r_axis2 == ref_in_extents_r2c_axis2); EXPECT_EQ(howmany_c2r_axis2, ref_howmany_r2c_axis2); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW({ KokkosFFT::Impl::get_extents(xcin3, xr3, axes_type({2})); }, + std::runtime_error); + // C2C auto [in_extents_c2c_axis0, out_extents_c2c_axis0, fft_extents_c2c_axis0, howmany_c2c_axis0] = @@ -318,6 +366,19 @@ void test_layouts_2d() { EXPECT_EQ(howmany_r2c_axis01, 1); EXPECT_EQ(howmany_r2c_axis10, 1); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xr2, xcout2, axes_type({0, 1})); + }, + std::runtime_error); + + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xr2, xcout2, axes_type({1, 0})); + }, + std::runtime_error); + // C2R auto [in_extents_c2r_axis01, out_extents_c2r_axis01, fft_extents_c2r_axis01, howmany_c2r_axis01] = @@ -337,6 +398,19 @@ void test_layouts_2d() { EXPECT_EQ(howmany_c2r_axis01, 1); EXPECT_EQ(howmany_c2r_axis10, 1); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xcin2, xr2, axes_type({0, 1})); + }, + std::runtime_error); + + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xcin2, xr2, axes_type({1, 0})); + }, + std::runtime_error); + // C2C auto [in_extents_c2c_axis01, out_extents_c2c_axis01, fft_extents_c2c_axis01, howmany_c2c_axis01] = @@ -414,6 +488,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_r2c_axis_01 == ref_out_extents_r2c_axis_01); EXPECT_EQ(howmany_r2c_axis_01, ref_howmany_r2c_axis_01); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xr3, xcout3, axes_type({0, 1})); + }, + std::runtime_error); + auto [in_extents_r2c_axis_02, out_extents_r2c_axis_02, fft_extents_r2c_axis_02, howmany_r2c_axis_02] = KokkosFFT::Impl::get_extents(xr3, xc3_axis_02, axes_type({0, 2})); @@ -422,6 +503,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_r2c_axis_02 == ref_out_extents_r2c_axis_02); EXPECT_EQ(howmany_r2c_axis_02, ref_howmany_r2c_axis_02); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xr3, xcout3, axes_type({0, 2})); + }, + std::runtime_error); + auto [in_extents_r2c_axis_10, out_extents_r2c_axis_10, fft_extents_r2c_axis_10, howmany_r2c_axis_10] = KokkosFFT::Impl::get_extents(xr3, xc3_axis_10, axes_type({1, 0})); @@ -430,6 +518,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_r2c_axis_10 == ref_out_extents_r2c_axis_10); EXPECT_EQ(howmany_r2c_axis_10, ref_howmany_r2c_axis_10); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xr3, xcout3, axes_type({1, 0})); + }, + std::runtime_error); + auto [in_extents_r2c_axis_12, out_extents_r2c_axis_12, fft_extents_r2c_axis_12, howmany_r2c_axis_12] = KokkosFFT::Impl::get_extents(xr3, xc3_axis_12, axes_type({1, 2})); @@ -438,6 +533,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_r2c_axis_12 == ref_out_extents_r2c_axis_12); EXPECT_EQ(howmany_r2c_axis_12, ref_howmany_r2c_axis_12); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xr3, xcout3, axes_type({1, 2})); + }, + std::runtime_error); + auto [in_extents_r2c_axis_20, out_extents_r2c_axis_20, fft_extents_r2c_axis_20, howmany_r2c_axis_20] = KokkosFFT::Impl::get_extents(xr3, xc3_axis_20, axes_type({2, 0})); @@ -446,6 +548,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_r2c_axis_20 == ref_out_extents_r2c_axis_20); EXPECT_EQ(howmany_r2c_axis_20, ref_howmany_r2c_axis_20); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xr3, xcout3, axes_type({2, 0})); + }, + std::runtime_error); + auto [in_extents_r2c_axis_21, out_extents_r2c_axis_21, fft_extents_r2c_axis_21, howmany_r2c_axis_21] = KokkosFFT::Impl::get_extents(xr3, xc3_axis_21, axes_type({2, 1})); @@ -454,6 +563,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_r2c_axis_21 == ref_out_extents_r2c_axis_21); EXPECT_EQ(howmany_r2c_axis_21, ref_howmany_r2c_axis_21); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xr3, xcout3, axes_type({2, 1})); + }, + std::runtime_error); + // C2R auto [in_extents_c2r_axis_01, out_extents_c2r_axis_01, fft_extents_c2r_axis_01, howmany_c2r_axis_01] = @@ -463,6 +579,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_c2r_axis_01 == ref_in_extents_r2c_axis_01); EXPECT_EQ(howmany_c2r_axis_01, ref_howmany_r2c_axis_01); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xcin3, xr3, axes_type({0, 1})); + }, + std::runtime_error); + auto [in_extents_c2r_axis_02, out_extents_c2r_axis_02, fft_extents_c2r_axis_02, howmany_c2r_axis_02] = KokkosFFT::Impl::get_extents(xc3_axis_02, xr3, axes_type({0, 2})); @@ -471,6 +594,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_c2r_axis_02 == ref_in_extents_r2c_axis_02); EXPECT_EQ(howmany_c2r_axis_02, ref_howmany_r2c_axis_02); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xcin3, xr3, axes_type({0, 2})); + }, + std::runtime_error); + auto [in_extents_c2r_axis_10, out_extents_c2r_axis_10, fft_extents_c2r_axis_10, howmany_c2r_axis_10] = KokkosFFT::Impl::get_extents(xc3_axis_10, xr3, axes_type({1, 0})); @@ -479,6 +609,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_c2r_axis_10 == ref_in_extents_r2c_axis_10); EXPECT_EQ(howmany_c2r_axis_10, ref_howmany_r2c_axis_10); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xcin3, xr3, axes_type({1, 0})); + }, + std::runtime_error); + auto [in_extents_c2r_axis_12, out_extents_c2r_axis_12, fft_extents_c2r_axis_12, howmany_c2r_axis_12] = KokkosFFT::Impl::get_extents(xc3_axis_12, xr3, axes_type({1, 2})); @@ -487,6 +624,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_c2r_axis_12 == ref_in_extents_r2c_axis_12); EXPECT_EQ(howmany_c2r_axis_12, ref_howmany_r2c_axis_12); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xcin3, xr3, axes_type({1, 2})); + }, + std::runtime_error); + auto [in_extents_c2r_axis_20, out_extents_c2r_axis_20, fft_extents_c2r_axis_20, howmany_c2r_axis_20] = KokkosFFT::Impl::get_extents(xc3_axis_20, xr3, axes_type({2, 0})); @@ -495,6 +639,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_c2r_axis_20 == ref_in_extents_r2c_axis_20); EXPECT_EQ(howmany_c2r_axis_20, ref_howmany_r2c_axis_20); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xcin3, xr3, axes_type({2, 0})); + }, + std::runtime_error); + auto [in_extents_c2r_axis_21, out_extents_c2r_axis_21, fft_extents_c2r_axis_21, howmany_c2r_axis_21] = KokkosFFT::Impl::get_extents(xc3_axis_21, xr3, axes_type({2, 1})); @@ -503,6 +654,13 @@ void test_layouts_2d_batched_FFT_3d() { EXPECT_TRUE(out_extents_c2r_axis_21 == ref_in_extents_r2c_axis_21); EXPECT_EQ(howmany_c2r_axis_21, ref_howmany_r2c_axis_21); + // Check if errors are correctly raised aginst invalid extents + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xcin3, xr3, axes_type({2, 1})); + }, + std::runtime_error); + // C2C auto [in_extents_c2c_axis_01, out_extents_c2c_axis_01, fft_extents_c2c_axis_01, howmany_c2c_axis_01] = diff --git a/common/unit_test/Test_Utils.cpp b/common/unit_test/Test_Utils.cpp index 2322443c..b443a112 100644 --- a/common/unit_test/Test_Utils.cpp +++ b/common/unit_test/Test_Utils.cpp @@ -38,6 +38,14 @@ void test_convert_negative_axes_1d() { EXPECT_EQ(converted_axis_0, ref_converted_axis_0); EXPECT_EQ(converted_axis_minus1, ref_converted_axis_minus1); + + // Check if errors are correctly raised aginst invalid axis + // axis must be in [-1, 1) + EXPECT_THROW({ KokkosFFT::Impl::convert_negative_axis(x, /*axis=*/1); }, + std::runtime_error); + + EXPECT_THROW({ KokkosFFT::Impl::convert_negative_axis(x, /*axis=*/-2); }, + std::runtime_error); } template @@ -58,6 +66,14 @@ void test_convert_negative_axes_2d() { EXPECT_EQ(converted_axis_0, ref_converted_axis_0); EXPECT_EQ(converted_axis_1, ref_converted_axis_1); EXPECT_EQ(converted_axis_minus1, ref_converted_axis_minus1); + + // Check if errors are correctly raised aginst invalid axis + // axis must be in [-2, 2) + EXPECT_THROW({ KokkosFFT::Impl::convert_negative_axis(x, /*axis=*/2); }, + std::runtime_error); + + EXPECT_THROW({ KokkosFFT::Impl::convert_negative_axis(x, /*axis=*/-3); }, + std::runtime_error); } template @@ -85,6 +101,14 @@ void test_convert_negative_axes_3d() { EXPECT_EQ(converted_axis_2, ref_converted_axis_2); EXPECT_EQ(converted_axis_minus1, ref_converted_axis_minus1); EXPECT_EQ(converted_axis_minus2, ref_converted_axis_minus2); + + // Check if errors are correctly raised aginst invalid axis + // axis must be in [-3, 3) + EXPECT_THROW({ KokkosFFT::Impl::convert_negative_axis(x, /*axis=*/3); }, + std::runtime_error); + + EXPECT_THROW({ KokkosFFT::Impl::convert_negative_axis(x, /*axis=*/-4); }, + std::runtime_error); } template @@ -119,6 +143,14 @@ void test_convert_negative_axes_4d() { EXPECT_EQ(converted_axis_minus1, ref_converted_axis_minus1); EXPECT_EQ(converted_axis_minus2, ref_converted_axis_minus2); EXPECT_EQ(converted_axis_minus3, ref_converted_axis_minus3); + + // Check if errors are correctly raised aginst invalid axis + // axis must be in [-4, 4) + EXPECT_THROW({ KokkosFFT::Impl::convert_negative_axis(x, /*axis=*/4); }, + std::runtime_error); + + EXPECT_THROW({ KokkosFFT::Impl::convert_negative_axis(x, /*axis=*/-5); }, + std::runtime_error); } // Tests for 1D View @@ -249,6 +281,18 @@ TEST(GetIndex, Vectors) { EXPECT_THROW(KokkosFFT::Impl::get_index(v, 5), std::runtime_error); } +TEST(HasDuplicateValues, Array) { + std::vector v0 = {0, 1, 1}; + std::vector v1 = {0, 1, 1, 1}; + std::vector v2 = {0, 1, 2, 3}; + std::vector v3 = {0}; + + EXPECT_TRUE(KokkosFFT::Impl::has_duplicate_values(v0)); + EXPECT_TRUE(KokkosFFT::Impl::has_duplicate_values(v1)); + EXPECT_FALSE(KokkosFFT::Impl::has_duplicate_values(v2)); + EXPECT_FALSE(KokkosFFT::Impl::has_duplicate_values(v3)); +} + TEST(IsOutOfRangeValueIncluded, Array) { std::vector v = {0, 1, 2, 3}; From 1528955e474e841dc8b5429409707b9781bf6683 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Thu, 25 Jul 2024 16:19:55 +0900 Subject: [PATCH 3/6] use check_precondition in get_shift function --- common/src/KokkosFFT_Helpers.hpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/common/src/KokkosFFT_Helpers.hpp b/common/src/KokkosFFT_Helpers.hpp index 76d40edf..86e5f95c 100644 --- a/common/src/KokkosFFT_Helpers.hpp +++ b/common/src/KokkosFFT_Helpers.hpp @@ -23,14 +23,12 @@ auto get_shift(const ViewType& inout, axis_type _axes, int direction = 1) { // Assert if the elements are overlapped constexpr int rank = ViewType::rank(); - if (KokkosFFT::Impl::has_duplicate_values(axes)) { - throw std::runtime_error("get_shift: axes are overlapped."); - } - if (KokkosFFT::Impl::is_out_of_range_value_included(axes, rank)) { - throw std::runtime_error( - "get_shift: axes include out of range index." - "axes should be in the range of [-rank, rank-1]."); - } + check_precondition(!KokkosFFT::Impl::has_duplicate_values(axes), + "axes are overlapped"); + check_precondition( + !KokkosFFT::Impl::is_out_of_range_value_included(axes, rank), + "axes include out of range index." + "axes should be in the range of [-rank, rank-1]."); axis_type shift = {0}; for (int i = 0; i < static_cast(DIM); i++) { From edbd4836d959861f6b1724acd1273c690a23acf1 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 30 Jul 2024 01:58:05 +0900 Subject: [PATCH 4/6] fix conflicts --- common/src/KokkosFFT_utils.hpp | 41 ++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/common/src/KokkosFFT_utils.hpp b/common/src/KokkosFFT_utils.hpp index b7ff0c67..00bd6764 100644 --- a/common/src/KokkosFFT_utils.hpp +++ b/common/src/KokkosFFT_utils.hpp @@ -14,13 +14,54 @@ #if defined(KOKKOS_ENABLE_CXX17) #include +#define KOKKOSFFT_EXPECTS(expression, msg) \ + KokkosFFT::Impl::check_precondition((expression), msg, __FILE__, __LINE__, \ + __FUNCTION__) #else #include +#define KOKKOSFFT_EXPECTS(expression, msg) \ + KokkosFFT::Impl::check_precondition( \ + (expression), msg, std::source_location::current().file_name(), \ + std::source_location::current().line(), \ + std::source_location::current().function_name(), \ + std::source_location::current().column()) #endif namespace KokkosFFT { namespace Impl { +inline void check_precondition(const bool expression, + [[maybe_unused]] const std::string& msg, + [[maybe_unused]] const char* file_name, int line, + [[maybe_unused]] const char* function_name, + [[maybe_unused]] const int column = -1) { + // Quick return if possible + if (expression) return; + + std::stringstream ss("file: "); + if (column == -1) { + // For C++ 17 + ss << file_name << '(' << line << ") `" << function_name << "`: " << msg + << '\n'; + } else { + // For C++ 20 and later + ss << file_name << '(' << line << ':' << column << ") `" << function_name + << "`: " << msg << '\n'; + } + throw std::runtime_error(ss.str()); +} +inline void check_precondition(const bool expression, const std::string& msg, + const char* file_name, int line, + const char* function_name) { + std::stringstream ss("file: "); + ss << file_name << '(' << line << ") `" << function_name << "`: " << msg + << '\n'; + if (!expression) { + throw std::runtime_error(ss.str()); + } +} +#endif + template auto convert_negative_axis(ViewType, int _axis = -1) { static_assert(Kokkos::is_view::value, From 07282b99dcf664985adb32eba732acf6f4aa5b61 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 30 Jul 2024 02:03:50 +0900 Subject: [PATCH 5/6] fix conflicts --- common/src/KokkosFFT_Helpers.hpp | 10 +++++----- common/src/KokkosFFT_layouts.hpp | 22 ++++++++++------------ common/src/KokkosFFT_padding.hpp | 14 ++++++-------- common/src/KokkosFFT_utils.hpp | 17 ++++++----------- 4 files changed, 27 insertions(+), 36 deletions(-) diff --git a/common/src/KokkosFFT_Helpers.hpp b/common/src/KokkosFFT_Helpers.hpp index 86e5f95c..d05d2cbc 100644 --- a/common/src/KokkosFFT_Helpers.hpp +++ b/common/src/KokkosFFT_Helpers.hpp @@ -23,12 +23,12 @@ auto get_shift(const ViewType& inout, axis_type _axes, int direction = 1) { // Assert if the elements are overlapped constexpr int rank = ViewType::rank(); - check_precondition(!KokkosFFT::Impl::has_duplicate_values(axes), - "axes are overlapped"); - check_precondition( + KOKKOSFFT_EXPECTS(!KokkosFFT::Impl::has_duplicate_values(axes), + "Axes overlap"); + KOKKOSFFT_EXPECTS( !KokkosFFT::Impl::is_out_of_range_value_included(axes, rank), - "axes include out of range index." - "axes should be in the range of [-rank, rank-1]."); + "Axes include an out-of-range index." + "Axes must be in the range of [-rank, rank-1]."); axis_type shift = {0}; for (int i = 0; i < static_cast(DIM); i++) { diff --git a/common/src/KokkosFFT_layouts.hpp b/common/src/KokkosFFT_layouts.hpp index 85441edc..065b898c 100644 --- a/common/src/KokkosFFT_layouts.hpp +++ b/common/src/KokkosFFT_layouts.hpp @@ -67,12 +67,11 @@ auto get_extents(const InViewType& in, const OutViewType& out, if (is_real_v) { // Then R2C if (is_complex_v) { - if (_out_extents.at(inner_most_axis) != - _in_extents.at(inner_most_axis) / 2 + 1) { - throw std::runtime_error( - "For R2C, the output extent of transform should be input extent / " - "2 + 1"); - } + KOKKOSFFT_EXPECTS( + _out_extents.at(inner_most_axis) == + _in_extents.at(inner_most_axis) / 2 + 1, + "For R2C, the 'output extent' of transform must be equal to " + "'input extent'/2 + 1"); } else { throw std::runtime_error( "If the input type is real, the output type should be complex"); @@ -82,12 +81,11 @@ auto get_extents(const InViewType& in, const OutViewType& out, if (is_real_v) { // Then C2R if (is_complex_v) { - if (_in_extents.at(inner_most_axis) != - _out_extents.at(inner_most_axis) / 2 + 1) { - throw std::runtime_error( - "For C2R, the input extent of transform should be output extent / " - "2 + 1"); - } + KOKKOSFFT_EXPECTS( + _in_extents.at(inner_most_axis) == + _out_extents.at(inner_most_axis) / 2 + 1, + "For C2R, the 'input extent' of transform must be equal to " + "'output extent' / 2 + 1"); } else { throw std::runtime_error( "If the output type is real, the input type should be complex"); diff --git a/common/src/KokkosFFT_padding.hpp b/common/src/KokkosFFT_padding.hpp index b672d1c2..3fc059be 100644 --- a/common/src/KokkosFFT_padding.hpp +++ b/common/src/KokkosFFT_padding.hpp @@ -51,14 +51,12 @@ auto get_modified_shape(const InViewType in, const OutViewType /* out */, } // Assert if the elements are overlapped - if (KokkosFFT::Impl::has_duplicate_values(positive_axes)) { - throw std::runtime_error("get_modified_shape: axes are overlapped."); - } - if (KokkosFFT::Impl::is_out_of_range_value_included(positive_axes, rank)) { - throw std::runtime_error( - "get_modified_shape: axes include out of range index." - "axes should be in the range of [-rank, rank-1]."); - } + KOKKOSFFT_EXPECTS(!KokkosFFT::Impl::has_duplicate_values(positive_axes), + "Axes overlap"); + KOKKOSFFT_EXPECTS( + !KokkosFFT::Impl::is_out_of_range_value_included(positive_axes, rank), + "Axes include an out-of-range index." + "Axes must be in the range of [-rank, rank-1]."); using full_shape_type = shape_type; full_shape_type modified_shape; diff --git a/common/src/KokkosFFT_utils.hpp b/common/src/KokkosFFT_utils.hpp index 00bd6764..48492759 100644 --- a/common/src/KokkosFFT_utils.hpp +++ b/common/src/KokkosFFT_utils.hpp @@ -50,26 +50,21 @@ inline void check_precondition(const bool expression, } throw std::runtime_error(ss.str()); } -inline void check_precondition(const bool expression, const std::string& msg, - const char* file_name, int line, - const char* function_name) { - std::stringstream ss("file: "); - ss << file_name << '(' << line << ") `" << function_name << "`: " << msg - << '\n'; - if (!expression) { - throw std::runtime_error(ss.str()); - } -} -#endif template auto convert_negative_axis(ViewType, int _axis = -1) { static_assert(Kokkos::is_view::value, "convert_negative_axis: ViewType is not a Kokkos::View."); int rank = static_cast(ViewType::rank()); +<<<<<<< HEAD if (_axis < -rank || _axis >= rank) { throw std::runtime_error("axis should be in [-rank, rank-1]"); } +======= + + KOKKOSFFT_EXPECTS(_axis >= -rank && _axis < rank, + "Axis must be in [-rank, rank-1]"); +>>>>>>> a786585 (improve assertion) int axis = _axis < 0 ? rank + _axis : _axis; return axis; From dd8c5f9541d86c7d661a43b923ce33702ee58ebb Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 30 Jul 2024 02:08:26 +0900 Subject: [PATCH 6/6] fix conflicts --- common/src/KokkosFFT_utils.hpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/common/src/KokkosFFT_utils.hpp b/common/src/KokkosFFT_utils.hpp index 48492759..4159991f 100644 --- a/common/src/KokkosFFT_utils.hpp +++ b/common/src/KokkosFFT_utils.hpp @@ -56,15 +56,9 @@ auto convert_negative_axis(ViewType, int _axis = -1) { static_assert(Kokkos::is_view::value, "convert_negative_axis: ViewType is not a Kokkos::View."); int rank = static_cast(ViewType::rank()); -<<<<<<< HEAD - if (_axis < -rank || _axis >= rank) { - throw std::runtime_error("axis should be in [-rank, rank-1]"); - } -======= KOKKOSFFT_EXPECTS(_axis >= -rank && _axis < rank, "Axis must be in [-rank, rank-1]"); ->>>>>>> a786585 (improve assertion) int axis = _axis < 0 ? rank + _axis : _axis; return axis;