diff --git a/btas/generic/gemm_impl.h b/btas/generic/gemm_impl.h index bbd3a357..b17f3b21 100644 --- a/btas/generic/gemm_impl.h +++ b/btas/generic/gemm_impl.h @@ -508,7 +508,7 @@ void gemm ( scal(beta, C); return; } - + typedef typename _TensorA::value_type value_type; assert(not ((transA == CblasConjTrans || transB == CblasConjTrans) && std::is_fundamental::value)); diff --git a/btas/generic/permute.h b/btas/generic/permute.h index 55ede85c..f50aadb3 100644 --- a/btas/generic/permute.h +++ b/btas/generic/permute.h @@ -21,8 +21,8 @@ namespace btas { is_boxtensor<_TensorY>::value >::type > - void - permute(const _TensorX& X, const _Permutation& p, _TensorY& Y) + void + permute(const _TensorX& X, const _Permutation& p, _TensorY& Y) { const auto pr = permute(X.range(),p); Y.resize(pr); diff --git a/btas/storage_traits.h b/btas/storage_traits.h index ebf0ea30..0c10ae6f 100644 --- a/btas/storage_traits.h +++ b/btas/storage_traits.h @@ -8,6 +8,8 @@ #ifndef BTAS_STORAGE_TRAITS_H_ #define BTAS_STORAGE_TRAITS_H_ +#include + #include namespace btas { @@ -43,6 +45,18 @@ namespace btas { typedef const_pointer const_iterator; }; + template + struct storage_traits> { + typedef _T value_type; + typedef _T& reference; + typedef const _T& const_reference; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + typedef _T* iterator; + typedef typename std::add_const<_T*>::type const_iterator; + }; + template struct storage_traits { typedef typename _Storage::value_type value_type; diff --git a/btas/tensor.h b/btas/tensor.h index 35e68976..d86ab811 100644 --- a/btas/tensor.h +++ b/btas/tensor.h @@ -54,13 +54,13 @@ namespace btas { typedef const value_type& const_reference; /// element iterator - typedef typename storage_type::iterator iterator; + typedef typename storage_traits::iterator iterator; /// constant element iterator - typedef typename storage_type::const_iterator const_iterator; + typedef typename storage_traits::const_iterator const_iterator; /// size type - typedef typename storage_type::size_type size_type; + typedef typename storage_traits::size_type size_type; ///@} diff --git a/unittest/tensor_test.cc b/unittest/tensor_test.cc index cf118e35..3d0961f1 100644 --- a/unittest/tensor_test.cc +++ b/unittest/tensor_test.cc @@ -160,7 +160,35 @@ TEST_CASE("Tensor Constructors") } } -TEST_CASE("Tensor") +TEST_CASE("Custom Tensor") + { + + SECTION("Storage") + { + { + typedef Tensor > Tensor; + Tensor T0; + Tensor T1(2,3,4); + } + { + typedef Tensor > Tensor; + Tensor T0; + Tensor T1(2,3,4); + } + { + typedef Tensor > Tensor; + Tensor T0; + Tensor T1(2,3,4); + } + { + typedef Tensor > Tensor; + Tensor T0; + Tensor T1(2,3,4); + } + } + } + +TEST_CASE("Tensor Operations") { DTensor T2(3,2); fillEls(T2);