Skip to content

Commit 73b2c77

Browse files
yasahi-hpcYuuichi Asahi
and
Yuuichi Asahi
authored
Apply traits to operations (#120)
* Strict type checks in uunary operations * Strict type checks in binary operations * Strict tpe checks in Plan creation * fix: missing execution space type --------- Co-authored-by: Yuuichi Asahi <[email protected]>
1 parent bab1620 commit 73b2c77

File tree

3 files changed

+137
-462
lines changed

3 files changed

+137
-462
lines changed

common/src/KokkosFFT_Helpers.hpp

+21-20
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <Kokkos_Core.hpp>
99
#include "KokkosFFT_common_types.hpp"
10+
#include "KokkosFFT_traits.hpp"
1011
#include "KokkosFFT_utils.hpp"
1112

1213
namespace KokkosFFT {
@@ -131,16 +132,6 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift,
131132
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
132133
void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
133134
axis_type<DIM> axes) {
134-
static_assert(Kokkos::is_view<ViewType>::value,
135-
"fftshift_impl: ViewType is not a Kokkos::View.");
136-
static_assert(
137-
KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
138-
"fftshift_impl: ViewType must be either LayoutLeft or LayoutRight.");
139-
static_assert(
140-
Kokkos::SpaceAccessibility<ExecutionSpace,
141-
typename ViewType::memory_space>::accessible,
142-
"fftshift_impl: execution_space cannot access data in ViewType");
143-
144135
static_assert(ViewType::rank() >= DIM,
145136
"fftshift_impl: Rank of View must be larger thane "
146137
"or equal to the Rank of shift axes.");
@@ -151,16 +142,6 @@ void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
151142
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
152143
void ifftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
153144
axis_type<DIM> axes) {
154-
static_assert(Kokkos::is_view<ViewType>::value,
155-
"ifftshift_impl: ViewType is not a Kokkos::View.");
156-
static_assert(
157-
KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
158-
"ifftshift_impl: ViewType must be either LayoutLeft or LayoutRight.");
159-
static_assert(
160-
Kokkos::SpaceAccessibility<ExecutionSpace,
161-
typename ViewType::memory_space>::accessible,
162-
"ifftshift_impl: execution_space cannot access data in ViewType");
163-
164145
static_assert(ViewType::rank() >= DIM,
165146
"ifftshift_impl: Rank of View must be larger "
166147
"thane or equal to the Rank of shift axes.");
@@ -243,6 +224,11 @@ auto rfftfreq(const ExecutionSpace&, const std::size_t n,
243224
template <typename ExecutionSpace, typename ViewType>
244225
void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
245226
std::optional<int> axes = std::nullopt) {
227+
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
228+
"fftshift: View value type must be float, double, "
229+
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
230+
"Layout must be either LayoutLeft or LayoutRight. "
231+
"ExecutionSpace must be able to access data in ViewType");
246232
if (axes) {
247233
axis_type<1> _axes{axes.value()};
248234
KokkosFFT::Impl::fftshift_impl(exec_space, inout, _axes);
@@ -262,6 +248,11 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
262248
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
263249
void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
264250
axis_type<DIM> axes) {
251+
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
252+
"fftshift: View value type must be float, double, "
253+
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
254+
"Layout must be either LayoutLeft or LayoutRight. "
255+
"ExecutionSpace must be able to access data in ViewType");
265256
KokkosFFT::Impl::fftshift_impl(exec_space, inout, axes);
266257
}
267258

@@ -273,6 +264,11 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
273264
template <typename ExecutionSpace, typename ViewType>
274265
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
275266
std::optional<int> axes = std::nullopt) {
267+
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
268+
"ifftshift: View value type must be float, double, "
269+
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
270+
"Layout must be either LayoutLeft or LayoutRight. "
271+
"ExecutionSpace must be able to access data in ViewType");
276272
if (axes) {
277273
axis_type<1> _axes{axes.value()};
278274
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, _axes);
@@ -292,6 +288,11 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
292288
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
293289
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
294290
axis_type<DIM> axes) {
291+
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
292+
"ifftshift: View value type must be float, double, "
293+
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
294+
"Layout must be either LayoutLeft or LayoutRight. "
295+
"ExecutionSpace must be able to access data in ViewType");
295296
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, axes);
296297
}
297298
} // namespace KokkosFFT

fft/src/KokkosFFT_Plans.hpp

+15-62
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include <Kokkos_Core.hpp>
1515
#include "KokkosFFT_default_types.hpp"
16+
#include "KokkosFFT_traits.hpp"
1617
#include "KokkosFFT_transpose.hpp"
1718
#include "KokkosFFT_padding.hpp"
1819
#include "KokkosFFT_utils.hpp"
@@ -158,33 +159,14 @@ class Plan {
158159
OutViewType& out, KokkosFFT::Direction direction, int axis,
159160
std::optional<std::size_t> n = std::nullopt)
160161
: m_exec_space(exec_space), m_axes({axis}), m_direction(direction) {
161-
static_assert(Kokkos::is_view<InViewType>::value,
162-
"Plan::Plan: InViewType is not a Kokkos::View.");
163-
static_assert(Kokkos::is_view<OutViewType>::value,
164-
"Plan::Plan: OutViewType is not a Kokkos::View.");
165162
static_assert(
166-
KokkosFFT::Impl::is_layout_left_or_right_v<InViewType>,
167-
"Plan::Plan: InViewType must be either LayoutLeft or LayoutRight.");
168-
static_assert(
169-
KokkosFFT::Impl::is_layout_left_or_right_v<OutViewType>,
170-
"Plan::Plan: OutViewType must be either LayoutLeft or LayoutRight.");
171-
172-
static_assert(InViewType::rank() == OutViewType::rank(),
173-
"Plan::Plan: InViewType and OutViewType must have "
174-
"the same rank.");
175-
static_assert(std::is_same_v<typename InViewType::array_layout,
176-
typename OutViewType::array_layout>,
177-
"Plan::Plan: InViewType and OutViewType must have "
178-
"the same Layout.");
179-
180-
static_assert(
181-
Kokkos::SpaceAccessibility<
182-
ExecutionSpace, typename InViewType::memory_space>::accessible,
183-
"Plan::Plan: execution_space cannot access data in InViewType");
184-
static_assert(
185-
Kokkos::SpaceAccessibility<
186-
ExecutionSpace, typename OutViewType::memory_space>::accessible,
187-
"Plan::Plan: execution_space cannot access data in OutViewType");
163+
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
164+
OutViewType>,
165+
"Plan::Plan: InViewType and OutViewType must have the same base "
166+
"floating point type (float/double), the same layout "
167+
"(LayoutLeft/LayoutRight), "
168+
"and the same rank. ExecutionSpace must be accessible to the data in "
169+
"InViewType and OutViewType.");
188170

189171
if (KokkosFFT::Impl::is_real_v<in_value_type> &&
190172
m_direction != KokkosFFT::Direction::forward) {
@@ -230,34 +212,14 @@ class Plan {
230212
OutViewType& out, KokkosFFT::Direction direction,
231213
axis_type<DIM> axes, shape_type<DIM> s = {0})
232214
: m_exec_space(exec_space), m_axes(axes), m_direction(direction) {
233-
static_assert(Kokkos::is_view<InViewType>::value,
234-
"Plan::Plan: InViewType is not a Kokkos::View.");
235-
static_assert(Kokkos::is_view<OutViewType>::value,
236-
"Plan::Plan: OutViewType is not a Kokkos::View.");
237-
static_assert(
238-
KokkosFFT::Impl::is_layout_left_or_right_v<InViewType>,
239-
"Plan::Plan: InViewType must be either LayoutLeft or LayoutRight.");
240-
static_assert(
241-
KokkosFFT::Impl::is_layout_left_or_right_v<OutViewType>,
242-
"Plan::Plan: OutViewType must be either LayoutLeft or LayoutRight.");
243-
244-
static_assert(InViewType::rank() == OutViewType::rank(),
245-
"Plan::Plan: InViewType and OutViewType must have "
246-
"the same rank.");
247-
248-
static_assert(std::is_same_v<typename InViewType::array_layout,
249-
typename OutViewType::array_layout>,
250-
"Plan::Plan: InViewType and OutViewType must have "
251-
"the same Layout.");
252-
253-
static_assert(
254-
Kokkos::SpaceAccessibility<
255-
ExecutionSpace, typename InViewType::memory_space>::accessible,
256-
"Plan::Plan: execution_space cannot access data in InViewType");
257215
static_assert(
258-
Kokkos::SpaceAccessibility<
259-
ExecutionSpace, typename OutViewType::memory_space>::accessible,
260-
"Plan::Plan: execution_space cannot access data in OutViewType");
216+
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
217+
OutViewType>,
218+
"Plan::Plan: InViewType and OutViewType must have the same base "
219+
"floating point type (float/double), the same layout "
220+
"(LayoutLeft/LayoutRight), "
221+
"and the same rank. ExecutionSpace must be accessible to the data in "
222+
"InViewType and OutViewType.");
261223

262224
if (std::is_floating_point<in_value_type>::value &&
263225
m_direction != KokkosFFT::Direction::forward) {
@@ -302,15 +264,6 @@ class Plan {
302264
/// \param out [in] Ouput data
303265
template <typename InViewType2, typename OutViewType2>
304266
void good(const InViewType2& in, const OutViewType2& out) const {
305-
static_assert(
306-
Kokkos::SpaceAccessibility<
307-
ExecutionSpace, typename InViewType2::memory_space>::accessible,
308-
"Plan::good: execution_space cannot access data in InViewType");
309-
static_assert(
310-
Kokkos::SpaceAccessibility<
311-
ExecutionSpace, typename OutViewType2::memory_space>::accessible,
312-
"Plan::good: execution_space cannot access data in OutViewType");
313-
314267
using nonConstInViewType2 = std::remove_cv_t<InViewType2>;
315268
using nonConstOutViewType2 = std::remove_cv_t<OutViewType2>;
316269
static_assert(std::is_same_v<nonConstInViewType2, nonConstInViewType>,

0 commit comments

Comments
 (0)