Skip to content

Commit 5713084

Browse files
yasahi-hpcYuuichi Asahi
and
Yuuichi Asahi
authored
Check View rank and FFT rank consistency (#121)
* Add a constant for maximum FFT dimension * Check view rank and fft rank consistency in all APIs --------- Co-authored-by: Yuuichi Asahi <[email protected]>
1 parent 73b2c77 commit 5713084

File tree

4 files changed

+73
-10
lines changed

4 files changed

+73
-10
lines changed

common/src/KokkosFFT_Helpers.hpp

+18-10
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@ namespace KokkosFFT {
1414
namespace Impl {
1515
template <typename ViewType, std::size_t DIM = 1>
1616
auto get_shift(const ViewType& inout, axis_type<DIM> _axes, int direction = 1) {
17-
static_assert(DIM > 0,
18-
"get_shift: Rank of shift axes must be "
19-
"larger than or equal to 1.");
20-
2117
// Convert the input axes to be in the range of [0, rank-1]
2218
std::vector<int> axes;
2319
for (std::size_t i = 0; i < DIM; i++) {
@@ -132,19 +128,13 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift,
132128
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
133129
void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
134130
axis_type<DIM> axes) {
135-
static_assert(ViewType::rank() >= DIM,
136-
"fftshift_impl: Rank of View must be larger thane "
137-
"or equal to the Rank of shift axes.");
138131
auto shift = get_shift(inout, axes);
139132
roll(exec_space, inout, shift, axes);
140133
}
141134

142135
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
143136
void ifftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
144137
axis_type<DIM> axes) {
145-
static_assert(ViewType::rank() >= DIM,
146-
"ifftshift_impl: Rank of View must be larger "
147-
"thane or equal to the Rank of shift axes.");
148138
auto shift = get_shift(inout, axes, -1);
149139
roll(exec_space, inout, shift, axes);
150140
}
@@ -229,6 +219,9 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
229219
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
230220
"Layout must be either LayoutLeft or LayoutRight. "
231221
"ExecutionSpace must be able to access data in ViewType");
222+
static_assert(ViewType::rank() >= 1,
223+
"fftshift: View rank must be larger than or equal to 1");
224+
232225
if (axes) {
233226
axis_type<1> _axes{axes.value()};
234227
KokkosFFT::Impl::fftshift_impl(exec_space, inout, _axes);
@@ -253,6 +246,12 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
253246
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
254247
"Layout must be either LayoutLeft or LayoutRight. "
255248
"ExecutionSpace must be able to access data in ViewType");
249+
static_assert(
250+
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
251+
"fftshift: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
252+
static_assert(ViewType::rank() >= DIM,
253+
"fftshift: View rank must be larger than or equal to the Rank "
254+
"of FFT axes");
256255
KokkosFFT::Impl::fftshift_impl(exec_space, inout, axes);
257256
}
258257

@@ -269,6 +268,8 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
269268
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
270269
"Layout must be either LayoutLeft or LayoutRight. "
271270
"ExecutionSpace must be able to access data in ViewType");
271+
static_assert(ViewType::rank() >= 1,
272+
"ifftshift: View rank must be larger than or equal to 1");
272273
if (axes) {
273274
axis_type<1> _axes{axes.value()};
274275
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, _axes);
@@ -293,6 +294,13 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
293294
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
294295
"Layout must be either LayoutLeft or LayoutRight. "
295296
"ExecutionSpace must be able to access data in ViewType");
297+
static_assert(
298+
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
299+
"ifftshift: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
300+
static_assert(ViewType::rank() >= DIM,
301+
"ifftshift: View rank must be larger than or equal to the Rank "
302+
"of FFT axes");
303+
296304
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, axes);
297305
}
298306
} // namespace KokkosFFT

common/src/KokkosFFT_common_types.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ enum class Direction {
3434
backward,
3535
};
3636

37+
//! Maximum FFT dimension allowed in KokkosFFT
38+
constexpr std::size_t MAX_FFT_DIM = 3;
39+
3740
} // namespace KokkosFFT
3841

3942
#endif

fft/src/KokkosFFT_Plans.hpp

+8
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ class Plan {
167167
"(LayoutLeft/LayoutRight), "
168168
"and the same rank. ExecutionSpace must be accessible to the data in "
169169
"InViewType and OutViewType.");
170+
static_assert(InViewType::rank() >= 1,
171+
"Plan::Plan: View rank must be larger than or equal to 1");
170172

171173
if (KokkosFFT::Impl::is_real_v<in_value_type> &&
172174
m_direction != KokkosFFT::Direction::forward) {
@@ -220,6 +222,12 @@ class Plan {
220222
"(LayoutLeft/LayoutRight), "
221223
"and the same rank. ExecutionSpace must be accessible to the data in "
222224
"InViewType and OutViewType.");
225+
static_assert(
226+
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
227+
"Plan::Plan: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
228+
static_assert(InViewType::rank() >= DIM,
229+
"Plan::Plan: View rank must be larger than or equal to the "
230+
"Rank of FFT axes");
223231

224232
if (std::is_floating_point<in_value_type>::value &&
225233
m_direction != KokkosFFT::Direction::forward) {

fft/src/KokkosFFT_Transform.hpp

+44
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in,
139139
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
140140
"same rank. ExecutionSpace must be accessible to the data in InViewType "
141141
"and OutViewType.");
142+
static_assert(InViewType::rank() >= 1,
143+
"fft: View rank must be larger than or equal to 1");
142144

143145
KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
144146
axis, n);
@@ -165,6 +167,8 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
165167
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
166168
"same rank. ExecutionSpace must be accessible to the data in InViewType "
167169
"and OutViewType.");
170+
static_assert(InViewType::rank() >= 1,
171+
"ifft: View rank must be larger than or equal to 1");
168172

169173
KokkosFFT::Impl::Plan plan(exec_space, in, out,
170174
KokkosFFT::Direction::backward, axis, n);
@@ -191,6 +195,8 @@ void rfft(const ExecutionSpace& exec_space, const InViewType& in,
191195
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
192196
"same rank. ExecutionSpace must be accessible to the data in InViewType "
193197
"and OutViewType.");
198+
static_assert(InViewType::rank() >= 1,
199+
"rfft: View rank must be larger than or equal to 1");
194200

195201
using in_value_type = typename InViewType::non_const_value_type;
196202
using out_value_type = typename OutViewType::non_const_value_type;
@@ -224,6 +230,8 @@ void irfft(const ExecutionSpace& exec_space, const InViewType& in,
224230
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
225231
"same rank. ExecutionSpace must be accessible to the data in InViewType "
226232
"and OutViewType.");
233+
static_assert(InViewType::rank() >= 1,
234+
"irfft: View rank must be larger than or equal to 1");
227235

228236
using in_value_type = typename InViewType::non_const_value_type;
229237
using out_value_type = typename OutViewType::non_const_value_type;
@@ -255,6 +263,8 @@ void hfft(const ExecutionSpace& exec_space, const InViewType& in,
255263
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
256264
"same rank. ExecutionSpace must be accessible to the data in InViewType "
257265
"and OutViewType.");
266+
static_assert(InViewType::rank() >= 1,
267+
"hfft: View rank must be larger than or equal to 1");
258268

259269
// [TO DO]
260270
// allow real type as input, need to obtain complex view type from in view
@@ -295,6 +305,8 @@ void ihfft(const ExecutionSpace& exec_space, const InViewType& in,
295305
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
296306
"same rank. ExecutionSpace must be accessible to the data in InViewType "
297307
"and OutViewType.");
308+
static_assert(InViewType::rank() >= 1,
309+
"ihfft: View rank must be larger than or equal to 1");
298310

299311
using in_value_type = typename InViewType::non_const_value_type;
300312
using out_value_type = typename OutViewType::non_const_value_type;
@@ -332,6 +344,8 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in,
332344
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
333345
"same rank. ExecutionSpace must be accessible to the data in InViewType "
334346
"and OutViewType.");
347+
static_assert(InViewType::rank() >= 2,
348+
"fft2: View rank must be larger than or equal to 2");
335349

336350
KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
337351
axes, s);
@@ -359,6 +373,8 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in,
359373
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
360374
"same rank. ExecutionSpace must be accessible to the data in InViewType "
361375
"and OutViewType.");
376+
static_assert(InViewType::rank() >= 2,
377+
"ifft2: View rank must be larger than or equal to 2");
362378

363379
KokkosFFT::Impl::Plan plan(exec_space, in, out,
364380
KokkosFFT::Direction::backward, axes, s);
@@ -386,6 +402,9 @@ void rfft2(const ExecutionSpace& exec_space, const InViewType& in,
386402
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
387403
"same rank. ExecutionSpace must be accessible to the data in InViewType "
388404
"and OutViewType.");
405+
static_assert(InViewType::rank() >= 2,
406+
"rfft2: View rank must be larger than or equal to 2");
407+
389408
using in_value_type = typename InViewType::non_const_value_type;
390409
using out_value_type = typename OutViewType::non_const_value_type;
391410

@@ -418,6 +437,8 @@ void irfft2(const ExecutionSpace& exec_space, const InViewType& in,
418437
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
419438
"same rank. ExecutionSpace must be accessible to the data in InViewType "
420439
"and OutViewType.");
440+
static_assert(InViewType::rank() >= 2,
441+
"irfft2: View rank must be larger than or equal to 2");
421442

422443
using in_value_type = typename InViewType::non_const_value_type;
423444
using out_value_type = typename OutViewType::non_const_value_type;
@@ -453,6 +474,11 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
453474
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
454475
"same rank. ExecutionSpace must be accessible to the data in InViewType "
455476
"and OutViewType.");
477+
static_assert(DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
478+
"fftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
479+
static_assert(
480+
InViewType::rank() >= DIM,
481+
"fftn: View rank must be larger than or equal to the Rank of FFT axes");
456482

457483
KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
458484
axes, s);
@@ -481,6 +507,12 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
481507
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
482508
"same rank. ExecutionSpace must be accessible to the data in InViewType "
483509
"and OutViewType.");
510+
static_assert(
511+
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
512+
"ifftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
513+
static_assert(
514+
InViewType::rank() >= DIM,
515+
"ifftn: View rank must be larger than or equal to the Rank of FFT axes");
484516

485517
KokkosFFT::Impl::Plan plan(exec_space, in, out,
486518
KokkosFFT::Direction::backward, axes, s);
@@ -509,6 +541,12 @@ void rfftn(const ExecutionSpace& exec_space, const InViewType& in,
509541
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
510542
"same rank. ExecutionSpace must be accessible to the data in InViewType "
511543
"and OutViewType.");
544+
static_assert(
545+
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
546+
"rfftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
547+
static_assert(
548+
InViewType::rank() >= DIM,
549+
"rfftn: View rank must be larger than or equal to the Rank of FFT axes");
512550

513551
using in_value_type = typename InViewType::non_const_value_type;
514552
using out_value_type = typename OutViewType::non_const_value_type;
@@ -543,6 +581,12 @@ void irfftn(const ExecutionSpace& exec_space, const InViewType& in,
543581
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
544582
"same rank. ExecutionSpace must be accessible to the data in InViewType "
545583
"and OutViewType.");
584+
static_assert(
585+
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
586+
"irfftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
587+
static_assert(
588+
InViewType::rank() >= DIM,
589+
"irfftn: View rank must be larger than or equal to the Rank of FFT axes");
546590

547591
using in_value_type = typename InViewType::non_const_value_type;
548592
using out_value_type = typename OutViewType::non_const_value_type;

0 commit comments

Comments
 (0)