-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTypeTensor.hpp
134 lines (111 loc) · 4.8 KB
/
TypeTensor.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#pragma once
#include <array>
#include <cassert>
#include <cstddef>
/// <summary>
/// TypeTensor with K dimensions.
///
/// If K = 2 and the stored data is a function, it is an alternative double-dispatch implementation (the "original" one uses the visitor pattern).
///
/// If Data is a function, it provides K-parameter dispatch (triple dispatch and beyond).
/// It is an N1 x N2 x ... x NK matrix created at compile time.
/// The matrix's axes are types, so every element can be indexed by types.
/// An element of the matrix can be anything.
/// For example, if it is a function, it is called when the dynamic types of the argument objects (one per axis) select that element.
/// Author: Tamas Orosz 2020
/// TODO:
/// - generalize the indexer virtual function with a functor or traits (currently its name and signature are fixed: TensorIdx)
/// - possibly add constexpr in more places
/// - hide internal templates from the "interface" (move them inside TypeTensor?)
/// - parameter checks
/// </summary>
namespace TTDimK {
// ---------------------------------------------------------- TypeList, helper --------------
/// Compile-time list of types. The primary template is the empty list and
/// terminates the recursion of the non-empty specialization below.
template <typename ... T>
struct TypeList {
    /// Terminal overload of the RevIdx lookup chain. It takes no argument, so
    /// looking up a type that is not in the list does not compile.
    constexpr static std::size_t RevIdx() { return 0; }
    constexpr static std::size_t Len() { return 0; }
    static void Call() { }
};
/// Non-empty list. It derives from the list of its tail, so every element
/// contributes one RevIdx(Element*) overload to a single overload set.
template <typename H, typename... T >
struct TypeList<H, T...> : TypeList<T...>
{
    using Head = H;
    using Tail = TypeList<T...>;
    /// Number of types in this list.
    constexpr static std::size_t Len() { return 1 + Tail::Len(); }
    // Re-expose the tail's overloads; otherwise RevIdx(Head*) below would hide them.
    using Tail::RevIdx;
    /// Reverse position of Head: Len() of the sub-list starting at Head
    /// (the last element yields 1). If a type is listed more than once, this
    /// declaration hides the identical one brought in from the tail, so the
    /// first occurrence wins.
    constexpr static std::size_t RevIdx(Head*) {
        return Len();
    }
    /// Calls Head::ClassName() for every type, in list order.
    /// Every listed type must provide a static ClassName().
    static void Call() {
        Head::ClassName();
        Tail::Call();
    }
};
// ----------------------------------------------------------- Get index of a class in typelist ---------
/// Zero-based position of Type in List, computed at compile time through
/// overload resolution on List::RevIdx(Type*). An exact match is preferred.
/// Otherwise a pointer-to-base conversion may select a listed base class.
/// A type with no matching overload makes the call ill-formed.
template <typename List, typename Type>
constexpr std::size_t TLIdx(void) { return List::Len() - List::RevIdx(static_cast<Type*>(nullptr)); }
// ---------------------------------------------------------- builder of typetensor's raw array data --------------
/// Builds the nested std::array type for axes A1..AK:
/// std::array<... std::array<Data, AK::Len()> ..., A1::Len()>.
/// The first axis is the outermost index. With no axes left, the element type is Data.
template <typename Data, typename ... AxisTail >
struct Builder {
    using ArrayT = Data;
};
template <typename Data, typename AxisHead, typename ... AxisTail >
struct Builder<Data, AxisHead, AxisTail...>
{
    using ArrayT = std::array< typename Builder<Data, AxisTail...>::ArrayT, AxisHead::Len() >;
};
// ---------------------------------------------------------- Digger by type --------------
/// Compile-time element lookup. ParamList holds one index type per remaining axis.
/// Each step finds the next type's position on its axis and descends one array level.
template < typename Data, typename ParamList, typename ... AxisTail >
struct TypeDigger {
    // All axes consumed: arr is the element itself.
    template <typename ArrayT>
    static Data& Dig(ArrayT& arr) {
        static_assert(ParamList::Len() == 0, "TypeDigger: more index types than axes");
        return arr;
    }
};
template <typename Data, typename ParamList, typename AxisHead, typename ... AxisTail >
struct TypeDigger<Data, ParamList, AxisHead, AxisTail ... > {
    template <typename ArrayT>
    static Data& Dig(ArrayT& arr) {
        constexpr std::size_t idx = TLIdx<AxisHead, typename ParamList::Head>();
        return TypeDigger<Data, typename ParamList::Tail, AxisTail...>::Dig(arr[idx]);
    }
};
// ---------------------------------------------------------- Digger by objects --------------
/// Maps a run-time argument to its position on an axis. The primary template
/// is intentionally empty, so only pointer arguments are supported. For those,
/// the pointee's TensorIdx() member supplies the index.
template <typename CLS>
struct TensorIndexerTrait { };
template <typename CLS>
struct TensorIndexerTrait<CLS*> {
    static std::size_t GetIndex(CLS* obj) {
        assert(obj != nullptr && "TypeTensor: null object used as an index");
        return obj->TensorIdx();
    }
};
/// Run-time element lookup: one argument per remaining axis, each resolved
/// through TensorIndexerTrait. With no arguments left, arr is the element.
template < typename Data, typename ... ParamTail >
struct ObjDigger {
    template <typename ArrayT>
    static Data& Dig(ArrayT& arr) { return arr; }
};
template <typename Data, typename ParamHead, typename ... ParamTail>
struct ObjDigger<Data, ParamHead, ParamTail... > {
    template <typename ArrayT>
    static Data& Dig(ArrayT& arr, ParamHead ph, ParamTail... rest) {
        // The index source is currently fixed to TensorIdx() (see TensorIndexerTrait).
        const std::size_t idx = TensorIndexerTrait<ParamHead>::GetIndex(ph);
        // std::array::operator[] is unchecked; an out-of-range TensorIdx() would be UB.
        assert(idx < arr.size() && "TypeTensor: TensorIdx() out of range for this axis");
        return ObjDigger<Data, ParamTail...>::Dig(arr[idx], rest...);
    }
};
// ---------------------------------------------------------- The tensor --------------
/// The tensor. Data is the element type. AllAxes are TypeList<...> axes, and the
/// first axis is the outermost array index. Storage is one static array per
/// TypeTensor specialization, so every use of the same specialization shares it.
/// Having static storage duration, it is zero-initialized before any dynamic
/// initialization; for example, function-pointer elements are null until assigned.
template <typename Data, typename ... AllAxes >
struct TypeTensor
{
    using ArrayT = typename Builder<Data, AllAxes...>::ArrayT;
    inline static ArrayT arr;
    using MyData = Data;
    /// Element addressed at compile time by exactly one type per axis: at<A, B>().
    template <typename ... Params>
    static Data& at() {
        static_assert(sizeof...(Params) == sizeof...(AllAxes),
                      "TypeTensor::at<...>(): exactly one index type per axis is required");
        return TypeDigger<Data, TypeList<Params...>, AllAxes...>::Dig(arr);
    }
    /// Element addressed at run time by exactly one object pointer per axis.
    /// Each object's TensorIdx() selects its position on the corresponding axis.
    template <typename ... Args>
    static Data& at(Args ... params) {
        static_assert(sizeof...(Args) == sizeof...(AllAxes),
                      "TypeTensor::at(...): exactly one index object per axis is required");
        return ObjDigger<Data, Args...>::Dig(arr, params...);
    }
    /// Looks up the element selected by params (as at(params...) does) and calls
    /// it with the same arguments. decltype(auto) forwards a reference return
    /// of the stored callable unchanged instead of copying it.
    template <typename ... Args>
    static decltype(auto) Call(Args ... params) {
        static_assert(sizeof...(Args) == sizeof...(AllAxes),
                      "TypeTensor::Call(...): exactly one index object per axis is required");
        return (ObjDigger<Data, Args...>::Dig(arr, params...))(params...);
    }
};
// -----------------------------------------------------------------
} // namespace TTDimK