forked from kokkos/kokkos-fft
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathKokkosFFT_Helpers.hpp
308 lines (266 loc) · 11.3 KB
/
KokkosFFT_Helpers.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
#ifndef KOKKOSFFT_HELPERS_HPP
#define KOKKOSFFT_HELPERS_HPP
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

#include <Kokkos_Core.hpp>
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_utils.hpp"
namespace KokkosFFT {
namespace Impl {
/// \brief Compute the per-axis roll amounts used by (i)fftshift.
///
/// \param inout [in] View whose extents determine the shift sizes
/// \param _axes [in] Axes to shift; negative values are normalized with
///                   convert_negative_axis
/// \param direction [in] +1 for fftshift, -1 for ifftshift
/// \return An axis_type<rank> where entry `axis` is extent(axis) / 2 times
///         `direction` for each requested axis, and 0 for all other axes
template <typename ViewType, std::size_t DIM = 1>
auto get_shift(const ViewType& inout, axis_type<DIM> _axes, int direction = 1) {
  // Convert the input axes to be in the range of [0, rank-1]
  std::vector<int> axes;
  axes.reserve(DIM);
  for (const int requested_axis : _axes) {
    axes.push_back(
        KokkosFFT::Impl::convert_negative_axis(inout, requested_axis));
  }

  // Assert that the axes are distinct and lie within the View's rank
  constexpr int rank = ViewType::rank();
  assert(!KokkosFFT::Impl::has_duplicate_values(axes));
  assert(!KokkosFFT::Impl::is_out_of_range_value_included(axes, rank));

  axis_type<rank> shift = {0};
  for (const int axis : axes) {
    // Cast to int before applying the sign so a negative direction does not
    // go through unsigned (size_t) wrap-around and a narrowing conversion.
    shift.at(axis) = static_cast<int>(inout.extent(axis) / 2) * direction;
  }
  return shift;
}
/// \brief Roll a rank-1 View along its only axis.
///
/// When a shift is performed, the rolled data is written into a freshly
/// allocated View and `inout` is rebound to it (`inout = tmp`).
///
/// \param exec_space [in] Kokkos execution space running the copy kernel
/// \param inout [in,out] Rank-1 View to roll
/// \param shift [in] Shift along axis 0
/// \param (unnamed) [in] Axes; ignored, present for interface consistency
///                       with the rank-2 overload
template <typename ExecutionSpace, typename ViewType>
void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<1> shift,
          axis_type<1>) {
  // Last parameter is ignored but present for keeping the interface consistent
  static_assert(ViewType::rank() == 1, "roll: Rank of View must be 1.");
  const std::size_t n0 = inout.extent(0);

  auto [_shift0, _shift1, _shift2] =
      KokkosFFT::Impl::convert_negative_shift(inout, shift.at(0), 0);
  // Structured bindings cannot be captured by the lambda; copy them to plain
  // ints first.
  const int shift0 = _shift0, shift1 = _shift1, shift2 = _shift2;

  // shift2 == 0 means shift; an empty View has nothing to move (and
  // (n0 - 1) would wrap around for n0 == 0).
  if (shift2 != 0 || n0 == 0) return;

  ViewType tmp("tmp", n0);
  // ceil(n0 / 2) iterations: each one moves one element of each half.
  const std::size_t len = (n0 + 1) / 2;

  Kokkos::parallel_for(
      Kokkos::RangePolicy<ExecutionSpace, Kokkos::IndexType<std::size_t>>(
          exec_space, 0, len),
      KOKKOS_LAMBDA(std::size_t i) {
        // Guard both writes: for odd n0 and a negative roll, shift0 is
        // n0 / 2 + 1 and the last iteration would otherwise write tmp(n0).
        if (i + shift0 < n0) {
          tmp(i + shift0) = inout(i);
        }
        if (i + shift1 < n0) {
          tmp(i) = inout(i + shift1);
        }
      });

  inout = tmp;
}
/// \brief Roll a rank-2 View along the given axes.
///
/// The rolled data is written into a freshly allocated View `tmp`, and
/// `inout` is then rebound to it (`inout = tmp`).
///
/// \param exec_space [in] Kokkos execution space running the copy kernel
/// \param inout [in,out] Rank-2 View to roll
/// \param shift [in] Shift per axis, indexed by normalized (non-negative) axis
///                   as produced by get_shift
/// \param axes [in] Axes along which to roll; negative values count from the
///                  last axis and are normalized with convert_negative_axis
template <typename ExecutionSpace, typename ViewType, std::size_t DIM1 = 1>
void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift,
          axis_type<DIM1> axes) {
  constexpr int DIM0 = 2;
  static_assert(ViewType::rank() == DIM0, "roll: Rank of View must be 2.");
  int n0 = inout.extent(0), n1 = inout.extent(1);

  ViewType tmp("tmp", n0, n1);
  // The kernel iterates over ceil(n / 2) indices per axis.
  int len0 = (n0 - 1) / 2 + 1;
  int len1 = (n1 - 1) / 2 + 1;

  using range_type = Kokkos::MDRangePolicy<
      ExecutionSpace,
      Kokkos::Rank<2, Kokkos::Iterate::Default, Kokkos::Iterate::Default>>;
  using tile_type  = typename range_type::tile_type;
  using point_type = typename range_type::point_type;

  range_type range(
      exec_space, point_type{{0, 0}}, point_type{{len0, len1}},
      tile_type{{4, 4}}  // [TO DO] Choose optimal tile sizes for each device
  );

  // Axes not listed in `axes` keep shift0 = shift1 = 0 and shift2 = extent/2.
  axis_type<2> shift0 = {0}, shift1 = {0}, shift2 = {n0 / 2, n1 / 2};
  for (std::size_t i = 0; i < DIM1; i++) {
    // Normalize negative axes (e.g. -1 for the last axis) before using them
    // as indices: get_shift stores its shifts at normalized positions, and
    // std::array-style .at() would reject a negative index.
    int axis = KokkosFFT::Impl::convert_negative_axis(inout, axes.at(i));

    auto [_shift0, _shift1, _shift2] =
        KokkosFFT::Impl::convert_negative_shift(inout, shift.at(axis), axis);
    shift0.at(axis) = _shift0;
    shift1.at(axis) = _shift1;
    shift2.at(axis) = _shift2;
  }

  // Naming: shift_<axis><kind>, e.g. shift_10 is shift0 along axis 1.
  int shift_00 = shift0.at(0), shift_10 = shift0.at(1);
  int shift_01 = shift1.at(0), shift_11 = shift1.at(1);
  int shift_02 = shift2.at(0), shift_12 = shift2.at(1);

  Kokkos::parallel_for(
      range, KOKKOS_LAMBDA(int i0, int i1) {
        // Element moved forward by shift0 along both axes.
        if (i0 + shift_00 < n0 && i1 + shift_10 < n1) {
          tmp(i0 + shift_00, i1 + shift_10) = inout(i0, i1);
        }
        // Element pulled back from the shift1 offset along both axes.
        if (i0 + shift_01 < n0 && i1 + shift_11 < n1) {
          tmp(i0, i1) = inout(i0 + shift_01, i1 + shift_11);
        }
        // Mixed cases: pulled back along one axis and moved forward along the
        // other. shift_02 / shift_12 offset into the second half of an axis
        // that is not being rolled (presumably 0 for rolled axes — see
        // convert_negative_shift).
        if (i0 + shift_01 < n0 && i1 + shift_10 < n1) {
          tmp(i0 + shift_02, i1 + shift_10 + shift_12) =
              inout(i0 + shift_01 + shift_02, i1 + shift_12);
        }
        if (i0 + shift_00 < n0 && i1 + shift_11 < n1) {
          tmp(i0 + shift_00 + shift_02, i1 + shift_12) =
              inout(i0 + shift_02, i1 + shift_11 + shift_12);
        }
      });

  inout = tmp;
}
/// \brief Internal fftshift: roll each listed axis forward by half its extent.
///
/// \param exec_space [in] Kokkos execution space
/// \param inout [in,out] View to shift
/// \param axes [in] Axes over which to shift
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
                   axis_type<DIM> axes) {
  // get_shift's default direction (+1) yields +extent/2 on every listed axis.
  roll(exec_space, inout, get_shift(inout, axes), axes);
}
/// \brief Internal ifftshift: roll each listed axis backward by half its extent.
///
/// \param exec_space [in] Kokkos execution space
/// \param inout [in,out] View to shift
/// \param axes [in] Axes over which to shift
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void ifftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
                    axis_type<DIM> axes) {
  // Direction -1 negates the half-extent shift, undoing fftshift_impl.
  constexpr int backward = -1;
  roll(exec_space, inout, get_shift(inout, axes, backward), axes);
}
} // namespace Impl
} // namespace KokkosFFT
namespace KokkosFFT {
/// \brief Return the DFT sample frequencies
///
/// The result holds [0, 1, ..., ceil(n/2)-1, -floor(n/2), ..., -1] / (n * d).
///
/// \param exec_space [in] Kokkos execution space (its type selects where the
///                   returned View lives; the argument value is unused)
/// \param n [in] Window length
/// \param d [in] Sample spacing
///
/// \return Sampling frequency
template <typename ExecutionSpace, typename RealType>
auto fftfreq(const ExecutionSpace&, const std::size_t n,
             const RealType d = 1.0) {
  static_assert(KokkosFFT::Impl::is_real_v<RealType>,
                "fftfreq: d must be float or double");
  using ViewType = Kokkos::View<RealType*, ExecutionSpace>;
  ViewType freq("freq", n);

  RealType val = 1.0 / (static_cast<RealType>(n) * d);
  // Number of non-negative (ceil(n/2)) and negative (floor(n/2)) bins.
  // (n + 1) / 2 avoids the size_t wrap-around of (n - 1) / 2 + 1 at n == 0.
  const std::size_t n_nonneg = (n + 1) / 2;
  const std::size_t n_neg    = n / 2;

  // Fill on the host, then copy to the View's memory space.
  auto h_freq = Kokkos::create_mirror_view(freq);
  for (std::size_t i = 0; i < n_nonneg; ++i) {
    h_freq(i) = static_cast<RealType>(i) * val;
  }
  for (std::size_t i = 0; i < n_neg; ++i) {
    // Bin n_nonneg + i holds frequency index i - n_neg (-n_neg .. -1).
    h_freq(n_nonneg + i) = -static_cast<RealType>(n_neg - i) * val;
  }
  Kokkos::deep_copy(freq, h_freq);
  return freq;
}
/// \brief Return the DFT sample frequencies for Real FFTs
///
/// The result holds [0, 1, ..., floor(n/2)] / (n * d), i.e. n / 2 + 1 values.
///
/// \param exec_space [in] Kokkos execution space (its type selects where the
///                   returned View lives; the argument value is unused)
/// \param n [in] Window length
/// \param d [in] Sample spacing
///
/// \return Sampling frequency starting from zero
///
/// NOTE(review): n == 0 divides by zero when computing the spacing and still
/// returns one element; callers presumably never pass n == 0 — confirm.
template <typename ExecutionSpace, typename RealType>
auto rfftfreq(const ExecutionSpace&, const std::size_t n,
              const RealType d = 1.0) {
  static_assert(KokkosFFT::Impl::is_real_v<RealType>,
                "rfftfreq: d must be float or double");
  using ViewType = Kokkos::View<RealType*, ExecutionSpace>;
  RealType val = 1.0 / (static_cast<RealType>(n) * d);
  const std::size_t n_freq = n / 2 + 1;
  ViewType freq("freq", n_freq);

  // Fill on the host, then copy to the View's memory space.
  auto h_freq = Kokkos::create_mirror_view(freq);
  for (std::size_t i = 0; i < n_freq; ++i) {
    h_freq(i) = static_cast<RealType>(i) * val;
  }
  Kokkos::deep_copy(freq, h_freq);
  return freq;
}
/// \brief Shift the zero-frequency component to the center of the spectrum
///
/// \param exec_space [in] Kokkos execution space
/// \param inout [in,out] Spectrum
/// \param axes [in] Axes over which to shift, optional
template <typename ExecutionSpace, typename ViewType>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
std::optional<int> axes = std::nullopt) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"fftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(ViewType::rank() >= 1,
"fftshift: View rank must be larger than or equal to 1");
if (axes) {
axis_type<1> _axes{axes.value()};
KokkosFFT::Impl::fftshift_impl(exec_space, inout, _axes);
} else {
constexpr std::size_t rank = ViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> _axes = KokkosFFT::Impl::index_sequence<rank>(start);
KokkosFFT::Impl::fftshift_impl(exec_space, inout, _axes);
}
}
/// \brief Shift the zero-frequency component to the center of the spectrum
///
/// \param exec_space [in] Kokkos execution space
/// \param inout [in,out] Spectrum
/// \param axes [in] Axes over which to shift
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"fftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"fftshift: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(ViewType::rank() >= DIM,
"fftshift: View rank must be larger than or equal to the Rank "
"of FFT axes");
KokkosFFT::Impl::fftshift_impl(exec_space, inout, axes);
}
/// \brief The inverse of fftshift
///
/// \param exec_space [in] Kokkos execution space
/// \param inout [in,out] Spectrum
/// \param axes [in] Axes over which to shift, optional
template <typename ExecutionSpace, typename ViewType>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
std::optional<int> axes = std::nullopt) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"ifftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(ViewType::rank() >= 1,
"ifftshift: View rank must be larger than or equal to 1");
if (axes) {
axis_type<1> _axes{axes.value()};
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, _axes);
} else {
constexpr std::size_t rank = ViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> _axes = KokkosFFT::Impl::index_sequence<rank>(start);
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, _axes);
}
}
/// \brief The inverse of fftshift
///
/// \param exec_space [in] Kokkos execution space
/// \param inout [in,out] Spectrum
/// \param axes [in] Axes over which to shift
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"ifftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"ifftshift: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(ViewType::rank() >= DIM,
"ifftshift: View rank must be larger than or equal to the Rank "
"of FFT axes");
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, axes);
}
} // namespace KokkosFFT
#endif