-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy patharray.cpp
113 lines (71 loc) · 2.86 KB
/
array.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause
#include "ginkgo/core/base/array.hpp"
#include <type_traits>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/math.hpp>
#include "core/base/array_access.hpp"
#include "core/components/fill_array_kernels.hpp"
#include "core/components/precision_conversion_kernels.hpp"
#include "core/components/reduce_array_kernels.hpp"
namespace gko {
namespace conversion {
namespace {
// Registers components::convert_precision as an executor operation. The
// generated factory make_convert(...) is what detail::convert_data runs.
// Kept in an anonymous namespace so the registration stays local to this
// translation unit.
GKO_REGISTER_OPERATION(convert, components::convert_precision);
} // anonymous namespace
} // namespace conversion
namespace array_kernels {
namespace {
// Executor operations used by array<T>::fill (make_fill_array) and by both
// reduce_add overloads (make_reduce_add_array) in this file. Local to this
// translation unit via the anonymous namespace.
GKO_REGISTER_OPERATION(fill_array, components::fill_array);
GKO_REGISTER_OPERATION(reduce_add_array, components::reduce_add_array);
} // anonymous namespace
} // namespace array_kernels
namespace detail {


/**
 * Converts `size` consecutive values from `src` (SourceType) into `dst`
 * (TargetType) by launching the precision-conversion operation on `exec`.
 *
 * @param exec  executor on which the conversion operation is run
 * @param size  number of elements to convert
 * @param src  pointer to the first source element
 * @param dst  pointer to the first destination element
 */
template <typename SourceType, typename TargetType>
void convert_data(std::shared_ptr<const Executor> exec, size_type size,
                  const SourceType* src, TargetType* dst)
{
    const Executor& runner = *exec;
    runner.run(conversion::make_convert(size, src, dst));
}


// Explicitly instantiate convert_data for every (From, To) pair supplied by
// GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION, so the template definition can
// stay in this source file.
#define GKO_DECLARE_ARRAY_CONVERSION(From, To) \
void convert_data<From, To>(std::shared_ptr<const Executor>, size_type, \
const From*, To*)
GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION(GKO_DECLARE_ARRAY_CONVERSION);


} // namespace detail
/**
 * Overwrites every element of this array with `value`.
 *
 * The fill_array operation is launched on the executor returned by
 * get_executor(), over the range [get_data(), get_data() + get_size()).
 *
 * @param value  the value written to each element
 */
template <typename ValueType>
void array<ValueType>::fill(const ValueType value)
{
    auto owner_exec = this->get_executor();
    auto* const first = this->get_data();
    const auto count = this->get_size();
    owner_exec->run(array_kernels::make_fill_array(first, count, value));
}
template <typename ValueType>
void reduce_add(const array<ValueType>& input_arr, array<ValueType>& result)
{
GKO_ASSERT(result.get_size() == 1);
auto exec = input_arr.get_executor();
exec->run(array_kernels::make_reduce_add_array(input_arr, result));
}
template <typename ValueType>
ValueType reduce_add(const array<ValueType>& input_arr,
const ValueType init_value)
{
auto exec = input_arr.get_executor();
auto value = array<ValueType>(exec, 1);
value.fill(ValueType{0});
exec->run(array_kernels::make_reduce_add_array(input_arr, value));
return init_value + get_element(value, 0);
}
// Explicit instantiations of array<T>::fill: every type from
// GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE, plus uint16 and uint32.
#define GKO_DECLARE_ARRAY_FILL(_type) void array<_type>::fill(const _type value)
GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_ARRAY_FILL);
template GKO_DECLARE_ARRAY_FILL(uint16);
template GKO_DECLARE_ARRAY_FILL(uint32);
// uint64 is only instantiated when size_t is not uint64_t. Presumably this
// avoids a duplicate explicit instantiation when size_type (likely already in
// the template-type list) and uint64 are the same type. TODO confirm.
#ifndef GKO_SIZE_T_IS_UINT64_T
template GKO_DECLARE_ARRAY_FILL(uint64);
#endif
// Explicit instantiations of the in-place overload
// reduce_add(const array<T>&, array<T>&) for every type from
// GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE.
#define GKO_DECLARE_ARRAY_REDUCE_ADD(_type) \
void reduce_add(const array<_type>& arr, array<_type>& value)
GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_ARRAY_REDUCE_ADD);
// Explicit instantiations of the value-returning overload
// reduce_add(const array<T>&, T) for every type from
// GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE.
#define GKO_DECLARE_ARRAY_REDUCE_ADD2(_type) \
_type reduce_add(const array<_type>& arr, const _type val)
GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_ARRAY_REDUCE_ADD2);
} // namespace gko