Skip to content

Commit a282aa8

Browse files
committed
Added type trait class, for flexible indexer function
1 parent cceed55 commit a282aa8

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

TypeTensor.hpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ namespace TTDimK {
8181
}
8282
};
8383
// ---------------------------------------------------------- Digger by objects --------------
84+
template <typename CLS>
85+
struct TensorIndexerTrait { };
86+
87+
template <typename CLS>
88+
struct TensorIndexerTrait<CLS*> {
89+
static size_t GetIndex(CLS* obj) { return obj->TensorIdx(); }
90+
};
91+
8492
template < typename Data, typename ... ParamTail >
8593
struct ObjDigger {
8694
template <typename ArrayT>
@@ -90,7 +98,7 @@ namespace TTDimK {
9098
struct ObjDigger<Data, ParamHead, ParamTail... > {
9199
template <typename ArrayT>
92100
static Data& Dig(ArrayT& arr, ParamHead ph, ParamTail... rest) {
93-
size_t idx = ph->TensorIdx(); // need to be generalized !!!
101+
size_t idx = TensorIndexerTrait<ParamHead>::GetIndex(ph); // ph->TensorIdx(); // need to be generalized !!!
94102
return ObjDigger<Data, ParamTail...>::Dig(arr[idx], rest...);
95103
}
96104
};

TypeTensorDemo.cpp

+19-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
/// - store integers by static type, retrieve by dynamic type
44
/// - Double Dispatch
55
/// - Triple Dispatch
6+
/// Quick and dirty test.
67
/// </summary>
78

89
#include <iostream>
@@ -36,8 +37,7 @@ typedef TTDimK::TypeList<A0, A1, A2> ALIST;
3637
typedef TTDimK::TypeList<B0, B1, B2, B3> BLIST;
3738
typedef TTDimK::TypeList<C0, C1, C2, C3> CLIST;
3839

39-
struct A0 {
40-
40+
struct A0 {
4141
// tensor index of class, depends only typelist, need not change when inheritance changed!
4242
virtual size_t TensorIdx() { return TTDimK::TLIdx<ALIST, A0>(); }
4343

@@ -50,7 +50,7 @@ struct B0 {
5050
static string StaticTypeName() { return "B0"; }
5151
};
5252
struct C0 {
53-
virtual size_t TensorIdx() { return TTDimK::TLIdx<CLIST, C0>(); }
53+
virtual size_t TensorIdxOther() { return TTDimK::TLIdx<CLIST, C0>(); }
5454
static string StaticTypeName() { return "C0"; }
5555
};
5656

@@ -68,10 +68,23 @@ CreateInherited(B0, B1, BLIST)
6868
CreateInherited(B1, B2, BLIST)
6969
CreateInherited(B2, B3, BLIST)
7070

71-
CreateInherited(C0, C1, CLIST)
72-
CreateInherited(C1, C2, CLIST)
73-
CreateInherited(C2, C3, CLIST)
71+
// ------------- with different indexer function ------------------------------
72+
#define CreateInheritedOther(BASE, ACTUAL, LIST) \
73+
struct ACTUAL : public BASE { \
74+
virtual size_t TensorIdxOther() { return TTDimK::TLIdx<LIST, ACTUAL>(); } \
75+
static string StaticTypeName() { return #ACTUAL; } \
76+
};
7477

78+
CreateInheritedOther(C0, C1, CLIST)
79+
CreateInheritedOther(C1, C2, CLIST)
80+
CreateInheritedOther(C2, C3, CLIST)
81+
82+
namespace TTDimK {
83+
template <>
84+
struct TensorIndexerTrait<C0*> {
85+
static size_t GetIndex(C0* obj) { return obj->TensorIdxOther(); }
86+
};
87+
}
7588
// --------------------------------------------- main entry
7689
int main()
7790
{

0 commit comments

Comments
 (0)