Skip to content

Commit 087a7ee

Browse files
committed
dot, dotc: add customization points and tests kokkos#96
1 parent 7d3d04a commit 087a7ee

File tree

5 files changed: +79 additions, -37 deletions

examples/kokkos-based/CMakeLists.txt

+2 -1
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11

2-
linalg_add_example(dot_kokkos)
32
linalg_add_example(add_kokkos)
3+
linalg_add_example(dot_kokkos)
4+
linalg_add_example(dotc_kokkos)
45
linalg_add_example(simple_scale_kokkos)
56
linalg_add_example(matrix_vector_product_kokkos)

examples/kokkos-based/add_kokkos.cpp

+16 -13
Original file line number | Diff line number | Diff line change
@@ -15,34 +15,37 @@ void print_elements(const T1 & v, const std::vector<ScalarType> & gold)
1515

1616
int main(int argc, char* argv[])
1717
{
18-
std::cout << "running add example calling custom kokkos" << std::endl;
19-
int N = 50;
18+
std::cout << "add example: calling kokkos-kernels" << std::endl;
19+
20+
std::size_t N = 50;
2021
Kokkos::initialize(argc,argv);
2122
{
22-
Kokkos::View<double*> x_view("x",N);
23-
Kokkos::View<double*> y_view("y",N);
24-
Kokkos::View<double*> z_view("z",N);
23+
using value_type = double;
24+
25+
Kokkos::View<value_type*> x_view("x",N);
26+
Kokkos::View<value_type*> y_view("y",N);
27+
Kokkos::View<value_type*> z_view("z",N);
2528

26-
double* x_ptr = x_view.data();
27-
double* y_ptr = y_view.data();
28-
double* z_ptr = z_view.data();
29+
value_type* x_ptr = x_view.data();
30+
value_type* y_ptr = y_view.data();
31+
value_type* z_ptr = z_view.data();
2932

3033
using dyn_1d_ext_type = std::experimental::extents<std::experimental::dynamic_extent>;
31-
using mdspan_type = std::experimental::mdspan<double, dyn_1d_ext_type>;
34+
using mdspan_type = std::experimental::mdspan<value_type, dyn_1d_ext_type>;
3235
mdspan_type x(x_ptr,N);
3336
mdspan_type y(y_ptr,N);
3437
mdspan_type z(z_ptr,N);
3538

36-
std::vector<double> gold(N);
39+
std::vector<value_type> gold(N);
3740
for(int i=0; i<x.extent(0); i++){
3841
x(i) = i;
39-
y(i) = i + (double)10;
42+
y(i) = i + (value_type)10;
4043
z(i) = 0;
4144
gold[i] = x(i) + y(i);
4245
}
4346

4447
namespace stdla = std::experimental::linalg;
45-
const double init_value = 2.0;
48+
const value_type init_value = 2.0;
4649

4750
{
4851
// This goes to the base implementation
@@ -51,7 +54,7 @@ int main(int argc, char* argv[])
5154

5255
{
5356
// reset z since it is modified above
54-
for(int i=0; i<z.extent(0); i++){ z(i) = 0; }
57+
for(std::size_t i=0; i<z.extent(0); i++){ z(i) = 0; }
5558

5659
// This forwards to KokkosKernels
5760
stdla::add(KokkosKernelsSTD::kokkos_exec<>(), x,y,z);

examples/kokkos-based/dot_kokkos.cpp

+13 -11
Original file line number | Diff line number | Diff line change
@@ -4,35 +4,37 @@
44

55
int main(int argc, char* argv[])
66
{
7-
std::cout << "running dot example calling custom kokkos" << std::endl;
8-
int N = 50;
7+
std::cout << "dot example: calling kokkos-kernels" << std::endl;
8+
9+
std::size_t N = 50;
910
Kokkos::initialize(argc,argv);
1011
{
11-
Kokkos::View<double*> a_view("A",N);
12-
Kokkos::View<double*> b_view("B",N);
13-
double* a_ptr = a_view.data();
14-
double* b_ptr = b_view.data();
12+
using value_type = double;
13+
14+
Kokkos::View<value_type*> a_view("A",N);
15+
Kokkos::View<value_type*> b_view("B",N);
16+
value_type* a_ptr = a_view.data();
17+
value_type* b_ptr = b_view.data();
1518

1619
using dyn_1d_ext_type = std::experimental::extents<std::experimental::dynamic_extent>;
17-
using mdspan_type = std::experimental::mdspan<double, dyn_1d_ext_type>;
20+
using mdspan_type = std::experimental::mdspan<value_type, dyn_1d_ext_type>;
1821
mdspan_type a(a_ptr,N);
1922
mdspan_type b(b_ptr,N);
20-
for(int i=0; i<a.extent(0); i++){
23+
for(std::size_t i=0; i<a.extent(0); i++){
2124
a(i) = i;
2225
b(i) = i;
2326
}
2427

2528
namespace stdla = std::experimental::linalg;
26-
const double init_value = 2.0;
29+
const value_type init_value(2.0);
2730

2831
// This goes to the base implementation
2932
const auto res_seq = stdla::dot(std::execution::seq, a, b, init_value);
33+
printf("Seq result = %lf\n", res_seq);
3034

3135
// This forwards to KokkosKernels
3236
auto res_kk = stdla::dot(KokkosKernelsSTD::kokkos_exec<>(), a, b, init_value);
33-
3437
printf("Kokkos result = %lf\n", res_kk);
35-
printf("Seq result = %lf\n", res_seq);
3638
}
3739
Kokkos::finalize();
3840
}

examples/kokkos-based/dotc_kokkos.cpp

+42
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,42 @@
1+
2+
#include <experimental/linalg>
3+
#include <iostream>
4+
5+
int main(int argc, char* argv[])
6+
{
7+
std::cout << "dotc example: calling kokkos-kernels" << std::endl;
8+
9+
std::size_t N = 10;
10+
Kokkos::initialize(argc,argv);
11+
{
12+
using value_type = std::complex<double>;
13+
using view_t = Kokkos::View<value_type*>;
14+
view_t a_view("A",N);
15+
view_t b_view("B",N);
16+
value_type* a_ptr = a_view.data();
17+
value_type* b_ptr = b_view.data();
18+
19+
using dyn_1d_ext_type = std::experimental::extents<std::experimental::dynamic_extent>;
20+
using mdspan_type = std::experimental::mdspan<value_type, dyn_1d_ext_type>;
21+
mdspan_type a(a_ptr,N);
22+
mdspan_type b(b_ptr,N);
23+
for(std::size_t i=0; i<a.extent(0); i++){
24+
const value_type a_i(double(i) + 1.0, double(i) + 1.0);
25+
const value_type b_i(double(i) - 2.0, double(i) - 2.0);
26+
a(i) = a_i;
27+
b(i) = b_i;
28+
}
29+
30+
namespace stdla = std::experimental::linalg;
31+
const value_type init_value(2., 3.);
32+
33+
// This goes to the base implementation
34+
const auto res_seq = stdla::dotc(std::execution::seq, a, b, init_value);
35+
std::cout << "Seq result = " << res_seq << "\n";
36+
37+
// This forwards to KokkosKernels
38+
auto res_kk = stdla::dotc(KokkosKernelsSTD::kokkos_exec<>(), a, b, init_value);
39+
std::cout << "Kokkos result = " << res_kk << "\n";
40+
}
41+
Kokkos::finalize();
42+
}

include/experimental/__p1673_bits/blas1_dot.hpp

+6 -12
Original file line number | Diff line number | Diff line change
@@ -78,12 +78,6 @@ struct is_custom_dot_avail<
7878

7979
} // end anonymous namespace
8080

81-
82-
// ------------
83-
// PUBLIC API:
84-
// ------------
85-
86-
// dot, with init value
8781
template<class ExecutionPolicy,
8882
class ElementType1,
8983
extents<>::size_type ext1,
@@ -100,6 +94,9 @@ Scalar dot(
10094
std::experimental::mdspan<ElementType2, std::experimental::extents<ext2>, Layout2, Accessor2> v2,
10195
Scalar init)
10296
{
97+
static_assert(v1.static_extent(0) == dynamic_extent ||
98+
v2.static_extent(0) == dynamic_extent ||
99+
v1.static_extent(0) == v2.static_extent(0));
103100

104101
constexpr bool use_custom = is_custom_dot_avail<
105102
decltype(execpolicy_mapper(exec)), decltype(v1), decltype(v2), Scalar
@@ -132,7 +129,6 @@ Scalar dot(std::experimental::mdspan<ElementType1, std::experimental::extents<ex
132129
return dot(std::experimental::linalg::impl::default_exec_t(), v1, v2, init);
133130
}
134131

135-
// Conjugated dot, with init value
136132
template<class ElementType1,
137133
extents<>::size_type ext1,
138134
class Layout1,
@@ -150,7 +146,6 @@ Scalar dotc(
150146
return dot(conjugated(v1), v2, init);
151147
}
152148

153-
// conjugated dot: with policy, with init value
154149
template<class ExecutionPolicy,
155150
class ElementType1,
156151
extents<>::size_type ext1,
@@ -162,12 +157,12 @@ template<class ExecutionPolicy,
162157
class Accessor2,
163158
class Scalar>
164159
Scalar dotc(
165-
ExecutionPolicy&& /* exec */,
160+
ExecutionPolicy&& exec,
166161
std::experimental::mdspan<ElementType1, std::experimental::extents<ext1>, Layout1, Accessor1> v1,
167162
std::experimental::mdspan<ElementType2, std::experimental::extents<ext2>, Layout2, Accessor2> v2,
168163
Scalar init)
169164
{
170-
return dotc(v1, v2, init);
165+
return dot(exec, conjugated(v1), v2, init);
171166
}
172167

173168
namespace dot_detail {
@@ -190,7 +185,6 @@ namespace dot_detail {
190185
-> decltype(x(0) * y(0));
191186
} // namespace dot_detail
192187

193-
// dot, without init value
194188
template<class ElementType1,
195189
extents<>::size_type ext1,
196190
class Layout1,
@@ -240,7 +234,7 @@ auto dotc(
240234
std::experimental::mdspan<ElementType2, std::experimental::extents<ext2>, Layout2, Accessor2> v2)
241235
-> decltype(dot_detail::dot_return_type_deducer(conjugated(v1), v2))
242236
{
243-
using return_t = decltype(dot_detail::dot_return_type_deducer(v1, v2));
237+
using return_t = decltype(dot_detail::dot_return_type_deducer(conjugated(v1), v2));
244238
return dotc(v1, v2, return_t{});
245239
}
246240

0 commit comments

Comments
 (0)