Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply traits to operations #120

Merged
merged 4 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions common/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <Kokkos_Core.hpp>
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_utils.hpp"

namespace KokkosFFT {
Expand Down Expand Up @@ -131,16 +132,6 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(Kokkos::is_view<ViewType>::value,
"fftshift_impl: ViewType is not a Kokkos::View.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"fftshift_impl: ViewType must be either LayoutLeft or LayoutRight.");
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename ViewType::memory_space>::accessible,
"fftshift_impl: execution_space cannot access data in ViewType");

static_assert(ViewType::rank() >= DIM,
"fftshift_impl: Rank of View must be larger thane "
"or equal to the Rank of shift axes.");
Expand All @@ -151,16 +142,6 @@ void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void ifftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(Kokkos::is_view<ViewType>::value,
"ifftshift_impl: ViewType is not a Kokkos::View.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"ifftshift_impl: ViewType must be either LayoutLeft or LayoutRight.");
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename ViewType::memory_space>::accessible,
"ifftshift_impl: execution_space cannot access data in ViewType");

static_assert(ViewType::rank() >= DIM,
"ifftshift_impl: Rank of View must be larger "
"thane or equal to the Rank of shift axes.");
Expand Down Expand Up @@ -243,6 +224,11 @@ auto rfftfreq(const ExecutionSpace&, const std::size_t n,
template <typename ExecutionSpace, typename ViewType>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
std::optional<int> axes = std::nullopt) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"fftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
if (axes) {
axis_type<1> _axes{axes.value()};
KokkosFFT::Impl::fftshift_impl(exec_space, inout, _axes);
Expand All @@ -262,6 +248,11 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"fftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
KokkosFFT::Impl::fftshift_impl(exec_space, inout, axes);
}

Expand All @@ -273,6 +264,11 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ExecutionSpace, typename ViewType>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
std::optional<int> axes = std::nullopt) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"ifftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
if (axes) {
axis_type<1> _axes{axes.value()};
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, _axes);
Expand All @@ -292,6 +288,11 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"ifftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, axes);
}
} // namespace KokkosFFT
Expand Down
77 changes: 15 additions & 62 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <Kokkos_Core.hpp>
#include "KokkosFFT_default_types.hpp"
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_transpose.hpp"
#include "KokkosFFT_padding.hpp"
#include "KokkosFFT_utils.hpp"
Expand Down Expand Up @@ -158,33 +159,14 @@ class Plan {
OutViewType& out, KokkosFFT::Direction direction, int axis,
std::optional<std::size_t> n = std::nullopt)
: m_exec_space(exec_space), m_axes({axis}), m_direction(direction) {
static_assert(Kokkos::is_view<InViewType>::value,
"Plan::Plan: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"Plan::Plan: OutViewType is not a Kokkos::View.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<InViewType>,
"Plan::Plan: InViewType must be either LayoutLeft or LayoutRight.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<OutViewType>,
"Plan::Plan: OutViewType must be either LayoutLeft or LayoutRight.");

static_assert(InViewType::rank() == OutViewType::rank(),
"Plan::Plan: InViewType and OutViewType must have "
"the same rank.");
static_assert(std::is_same_v<typename InViewType::array_layout,
typename OutViewType::array_layout>,
"Plan::Plan: InViewType and OutViewType must have "
"the same Layout.");

static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename InViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in InViewType");
static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in OutViewType");
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
"Plan::Plan: InViewType and OutViewType must have the same base "
"floating point type (float/double), the same layout "
"(LayoutLeft/LayoutRight), "
"and the same rank. ExecutionSpace must be accessible to the data in "
"InViewType and OutViewType.");

if (KokkosFFT::Impl::is_real_v<in_value_type> &&
m_direction != KokkosFFT::Direction::forward) {
Expand Down Expand Up @@ -230,34 +212,14 @@ class Plan {
OutViewType& out, KokkosFFT::Direction direction,
axis_type<DIM> axes, shape_type<DIM> s = {0})
: m_exec_space(exec_space), m_axes(axes), m_direction(direction) {
static_assert(Kokkos::is_view<InViewType>::value,
"Plan::Plan: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"Plan::Plan: OutViewType is not a Kokkos::View.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<InViewType>,
"Plan::Plan: InViewType must be either LayoutLeft or LayoutRight.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<OutViewType>,
"Plan::Plan: OutViewType must be either LayoutLeft or LayoutRight.");

static_assert(InViewType::rank() == OutViewType::rank(),
"Plan::Plan: InViewType and OutViewType must have "
"the same rank.");

static_assert(std::is_same_v<typename InViewType::array_layout,
typename OutViewType::array_layout>,
"Plan::Plan: InViewType and OutViewType must have "
"the same Layout.");

static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename InViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in InViewType");
static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in OutViewType");
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
"Plan::Plan: InViewType and OutViewType must have the same base "
"floating point type (float/double), the same layout "
"(LayoutLeft/LayoutRight), "
"and the same rank. ExecutionSpace must be accessible to the data in "
"InViewType and OutViewType.");

if (std::is_floating_point<in_value_type>::value &&
m_direction != KokkosFFT::Direction::forward) {
Expand Down Expand Up @@ -302,15 +264,6 @@ class Plan {
/// \param out [in] Ouput data
template <typename InViewType2, typename OutViewType2>
void good(const InViewType2& in, const OutViewType2& out) const {
static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename InViewType2::memory_space>::accessible,
"Plan::good: execution_space cannot access data in InViewType");
static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename OutViewType2::memory_space>::accessible,
"Plan::good: execution_space cannot access data in OutViewType");

using nonConstInViewType2 = std::remove_cv_t<InViewType2>;
using nonConstOutViewType2 = std::remove_cv_t<OutViewType2>;
static_assert(std::is_same_v<nonConstInViewType2, nonConstInViewType>,
Expand Down
Loading
Loading