From 336840215052826b75ad9b152183084dd35bbcd0 Mon Sep 17 00:00:00 2001 From: Cliff Burdick <30670611+cliffburdick@users.noreply.github.com> Date: Fri, 22 Nov 2024 19:54:07 -0800 Subject: [PATCH] Fixing broadcasting in all operator() (#795) --- include/matx/core/tensor_utils.h | 60 +++++++++++++++++---- include/matx/operators/at.h | 2 +- include/matx/operators/cast.h | 4 +- include/matx/operators/clone.h | 20 ++++--- include/matx/operators/collapse.h | 40 +++++++++----- include/matx/operators/comma.h | 8 +-- include/matx/operators/concat.h | 75 +++++++++++++------------- include/matx/operators/diag.h | 18 ++++--- include/matx/operators/fftshift.h | 18 ++++--- include/matx/operators/hermitian.h | 12 ++--- include/matx/operators/interleaved.h | 15 ++---- include/matx/operators/isclose.h | 4 +- include/matx/operators/kronecker.h | 28 +++++----- include/matx/operators/legendre.h | 78 +++++++++++++-------------- include/matx/operators/overlap.h | 10 +++- include/matx/operators/permute.h | 23 ++++---- include/matx/operators/planar.h | 17 +++--- include/matx/operators/polyval.h | 2 +- include/matx/operators/r2c.h | 22 +++----- include/matx/operators/remap.h | 27 ++++++---- include/matx/operators/repmat.h | 64 +++++++++++----------- include/matx/operators/reshape.h | 24 +++++---- include/matx/operators/reverse.h | 23 ++++---- include/matx/operators/select.h | 12 +++-- include/matx/operators/shift.h | 56 ++++++++++++------- include/matx/operators/slice.h | 29 ++++++---- include/matx/operators/stack.h | 50 ++++++++--------- include/matx/operators/toeplitz.h | 4 +- include/matx/operators/updownsample.h | 2 +- 29 files changed, 427 insertions(+), 320 deletions(-) diff --git a/include/matx/core/tensor_utils.h b/include/matx/core/tensor_utils.h index 6398930a..3d534188 100644 --- a/include/matx/core/tensor_utils.h +++ b/include/matx/core/tensor_utils.h @@ -33,6 +33,7 @@ #pragma once #include +#include #include #include "matx/core/nvtx.h" @@ -245,31 +246,56 @@ namespace matx * @param indices indices * @return Value after broadcasting */ - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto get_matx_value(const T &i, Is... indices) + template ...>, bool> = true> + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_matx_value(T &&i, Is... indices) { - if constexpr (T::Rank() == int(sizeof...(Is)) || T::Rank() == matxNoRank) { - return i(indices...); + constexpr int RANK = remove_cvref_t::Rank(); + if constexpr (RANK == int(sizeof...(Is)) || RANK == matxNoRank) { + // If we're only indexing with the same number of arguments as the rank of the operator, just return operator() + return cuda::std::forward(i)(indices...); } else { - // Construct an integer sequence of the length of the tuple, but only using the last indices - using seq = offset_sequence_t>; + // Otherwise we need to broadcast by constructing a large set of indices + // Construct an integer sequence of the length of the tuple, but only using the last indices. We construct an offset sequence + // to index into the broadcasted dimensions. For example, if T is a 3D tensor and we want to index as a 5D, we take the indices + // {0, 1, 2} we'd normally index with, and add the difference in rank (2), to get {2, 3, 4}. Another way to think of this is it + // simply chops off the first sizeof...(Is) - RANK indices since they're not used for operator(). + using seq = offset_sequence_t>; auto tup = cuda::std::make_tuple(indices...); auto sliced_tup = select_tuple(std::forward(tup), seq{}); return cuda::std::apply([&](auto... 
args) { - return i(args...); + return cuda::std::forward(i)(args...); }, sliced_tup); } } + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_matx_value(T &&i, const cuda::std::array idx) + { + constexpr int RANK = remove_cvref_t::Rank(); + if constexpr (RANK == N || RANK == matxNoRank) { + // If we're only indexing with the same number of arguments as the rank of the operator, just return operator() + return cuda::std::apply(cuda::std::forward(i), idx); + //return i(indices...); + } + else + { + cuda::std::array nbc_idx; // non-broadcast indices + cuda::std::copy(idx.begin() + (N - RANK), idx.end(), nbc_idx.begin()); + return cuda::std::apply([&](auto... args) { + return cuda::std::forward(i)(args...); + }, nbc_idx); + } + } + - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto get_value(const T &i, Is... indices) + template ...>, bool> = true> + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_value(T &&i, Is... indices) { if constexpr (is_matx_op()) { - return get_matx_value(i, indices...); + return get_matx_value(cuda::std::forward(i), indices...); } else { @@ -277,6 +303,20 @@ namespace matx } } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_value(T &&i, const cuda::std::array idx) + { + if constexpr (is_matx_op()) + { + return get_matx_value(cuda::std::forward(i), idx); + } + else + { + return i; + } + } + template __MATX_INLINE__ std::string to_short_str() { if constexpr (!is_complex_v) { if constexpr (std::is_same_v) diff --git a/include/matx/operators/at.h b/include/matx/operators/at.h index 910c0def..e88cc78a 100644 --- a/include/matx/operators/at.h +++ b/include/matx/operators/at.h @@ -58,7 +58,7 @@ namespace matx template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()([[maybe_unused]] Is2... indices) const { - return cuda::std::apply(op_, idx_); + return get_value(op_, idx_); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/cast.h b/include/matx/operators/cast.h index 83a1aa4d..c1af0708 100644 --- a/include/matx/operators/cast.h +++ b/include/matx/operators/cast.h @@ -74,13 +74,13 @@ namespace matx template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const { - return static_cast(op_(indices...)); + return static_cast(get_value(op_, indices...)); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) { - return static_cast(op_(indices...)); + return static_cast(get_value(op_, indices...)); } template diff --git a/include/matx/operators/clone.h b/include/matx/operators/clone.h index 47b77cdf..85c614d8 100644 --- a/include/matx/operators/clone.h +++ b/include/matx/operators/clone.h @@ -85,11 +85,9 @@ IGNORE_WARNING_POP_GCC } - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const - { - - // convert variadic type to tuple so we can read/update + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, const Dims &dims, Is... 
indices) + { IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized") cuda::std::array sind{indices...}; cuda::std::array gind; @@ -97,17 +95,23 @@ IGNORE_WARNING_POP_GCC // gather indices for(int i = 0; i < T::Rank(); i++) { - auto idx = dims_[i]; + auto idx = dims[i]; gind[i] = sind[idx]; } - return cuda::std::apply(op_, gind); + return get_value(cuda::std::forward(op), gind); + } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + { + return get_impl(cuda::std::as_const(op_), dims_, indices...); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) { - return cuda::std::as_const(*this).template operator()(indices...); + return get_impl(cuda::std::forward(op_), dims_, indices...); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/collapse.h b/include/matx/operators/collapse.h index e83f9e1a..4330a549 100644 --- a/include/matx/operators/collapse.h +++ b/include/matx/operators/collapse.h @@ -70,8 +70,8 @@ namespace matx } } - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, Is... indices) { // indices coming in cuda::std::array in{indices...}; // index coming in @@ -88,17 +88,23 @@ namespace matx #pragma unroll for(int i = 0; i < DIM; i++) { int d = DIM - i - 1; - out[d] = ind % op_.Size(d); - ind /= op_.Size(d); + out[d] = ind % op.Size(d); + ind /= op.Size(d); } - return cuda::std::apply(op_, out); + return get_value(cuda::std::forward(op), out); + } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + { + return get_impl(cuda::std::as_const(op_), indices...); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) { - return cuda::std::as_const(*this).template operator()(indices...); + return get_impl(cuda::std::forward(op_), indices...); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() @@ -207,9 +213,9 @@ namespace matx } } - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const - { + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, Is... indices) + { // indices coming in cuda::std::array in{indices...}; // index coming in cuda::std::array out; // index going out @@ -225,18 +231,24 @@ namespace matx #pragma unroll for(int i = 0; i < DIM; i++) { int d = T1::Rank() - 1 - i; - out[d] = ind % op_.Size(d); - ind /= op_.Size(d); + out[d] = ind % op.Size(d); + ind /= op.Size(d); } - return cuda::std::apply(op_, out); + return get_value(cuda::std::forward(op), out); + } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + { + return get_impl(cuda::std::as_const(op_), indices...); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... 
indices) { - return cuda::std::as_const(*this).template operator()(indices...); - } + return get_impl(cuda::std::forward(op_), indices...); + } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() { diff --git a/include/matx/operators/comma.h b/include/matx/operators/comma.h index 519abf45..4896e0f4 100644 --- a/include/matx/operators/comma.h +++ b/include/matx/operators/comma.h @@ -61,10 +61,10 @@ namespace matx __MATX_INLINE__ std::string str() const { return op1_.str() + ", " + op2_.str(); } template - auto __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ operator()(Is... indices) const { - op1_(indices...); - return op2_(indices...); - } + auto __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ operator()(Is... indices) const { + get_value(op1_, indices...); + return get_value(op2_, indices...); + } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() noexcept { diff --git a/include/matx/operators/concat.h b/include/matx/operators/concat.h index 73bd364b..50cd4636 100644 --- a/include/matx/operators/concat.h +++ b/include/matx/operators/concat.h @@ -91,52 +91,55 @@ namespace matx } } + + + template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto GetVal(cuda::std::array &indices) const { + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto GetVal(cuda::std::array &indices) const { - if constexpr ( I == N ) { - // This should never happen - return value_type{}; - // returning this to satisfy lvalue requirements + if constexpr ( I == N ) { + // This should never happen + return value_type{}; + // returning this to satisfy lvalue requirements + } else { + const auto &op = cuda::std::get(ops_); + auto idx = indices[axis_]; + auto size = op.Size(axis_); + // If in range of this operator + if(idx < size) { + // evaluate operator + return get_value(cuda::std::forward(op), indices); } else { - const auto &op = cuda::std::get(ops_); - auto idx = indices[axis_]; - auto size = op.Size(axis_); - // If in range of this operator - if(idx < size) { - // evaluate operator - return cuda::std::apply(op, indices); - } else { - // otherwise remove this operator and recurse - indices[axis_] -= size; - return GetVal(indices); - } + // otherwise remove this operator and recurse + indices[axis_] -= size; + return GetVal(indices); } } - + } + template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) GetVal(cuda::std::array &indices) { + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) GetVal(cuda::std::array &indices) { - if constexpr ( I == N ) { - // This should never happen - // returning this to satisfy lvalue requirements - auto &op = cuda::std::get(ops_); - return cuda::std::apply(op, indices); + if constexpr ( I == N ) { + // This should never happen + // returning this to satisfy lvalue requirements + auto &op = cuda::std::get(ops_); + return get_value(cuda::std::forward(op), indices); + } else { + auto &op = cuda::std::get(ops_); + auto idx = indices[axis_]; + auto size = op.Size(axis_); + // If in range of this operator + if(idx < size) { + // evaluate operator + return get_value(cuda::std::forward(op), indices); } else { - auto &op = cuda::std::get(ops_); - auto idx = indices[axis_]; - auto size = op.Size(axis_); - // If in range of this operator - if(idx < size) { - // evaluate operator - return cuda::std::apply(op, indices); - } else { - // otherwise remove this operator and recurse - indices[axis_] -= size; - return GetVal(indices); - } + // otherwise remove this operator and recurse + indices[axis_] -= size; + return 
GetVal(indices); } } + } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... is) const diff --git a/include/matx/operators/diag.h b/include/matx/operators/diag.h index 2f08e04a..08c73764 100644 --- a/include/matx/operators/diag.h +++ b/include/matx/operators/diag.h @@ -82,19 +82,23 @@ namespace matx // Offset either the rows or columns by k_, depending on if it's negative if (k_ < 0) { - auto tup = cuda::std::make_tuple(indices..., static_cast(0)); - cuda::std::get(tup) = pp_get(indices...) ; + cuda::std::array tmp{indices...}; + tmp[RANK - 1] = pp_get(indices...); + //cuda::std::get(tup) = pp_get(indices...) ; IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized") - cuda::std::get(tup) = cuda::std::get(tup) - k_; + tmp[RANK - 2] -= k_; + //cuda::std::get(tup) = cuda::std::get(tup) - k_; IGNORE_WARNING_POP_GCC - return cuda::std::apply(op_, tup); + return get_value(op_, tmp); } else { - auto tup = cuda::std::make_tuple(indices..., static_cast(0)); + cuda::std::array tmp{indices...}; + //auto tup = cuda::std::make_tuple(indices..., static_cast(0)); IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized") - cuda::std::get(tup) = pp_get(indices...) + k_; + tmp[RANK - 1] = pp_get(indices...) + k_; + //cuda::std::get(tup) = pp_get(indices...) + k_; IGNORE_WARNING_POP_GCC - return cuda::std::apply(op_, tup); + return get_value(op_, tmp); } } } diff --git a/include/matx/operators/fftshift.h b/include/matx/operators/fftshift.h index e089f740..df475d6a 100644 --- a/include/matx/operators/fftshift.h +++ b/include/matx/operators/fftshift.h @@ -55,19 +55,25 @@ namespace matx static_assert(Rank() >= 1, "1D FFT shift must have a rank 1 operator or higher"); }; + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, Is... indices) + { + cuda::std::array idx{indices...}; + idx[Rank() - 1] = (idx[Rank() - 1] + (op.Size(Rank()-1) + 1) / 2) % op.Size(Rank()-1); + return get_value(cuda::std::forward(op), idx); + } + template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const { - auto tup = cuda::std::make_tuple(indices...); - cuda::std::get(tup) = (cuda::std::get(tup) + (Size(Rank()-1) + 1) / 2) % Size(Rank()-1); - return cuda::std::apply(op_, tup); - } + return get_impl(cuda::std::as_const(op_), indices...); + } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) { - return cuda::std::as_const(*this).template operator()(indices...); - } + return get_impl(cuda::std::forward(op_), indices...); + } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() { diff --git a/include/matx/operators/hermitian.h b/include/matx/operators/hermitian.h index 85df11cf..cc3eba3e 100644 --- a/include/matx/operators/hermitian.h +++ b/include/matx/operators/hermitian.h @@ -32,7 +32,7 @@ #pragma once - +#include #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" @@ -55,7 +55,7 @@ namespace matx using matxop = bool; using value_type = typename T1::value_type; - __MATX_INLINE__ std::string str() const { return "hermitian(" + op_.str() + ")"; } + __MATX_INLINE__ std::string str() const { return "hermitian(" + op_.str() + ")"; } __MATX_INLINE__ HermitianTransOp(const T1 &op) : op_(op) { static_assert(Rank() >= 2, "Hermitian operation needs input with rank >= 2"); } @@ -63,11 +63,9 @@ namespace matx template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... 
indices) const { - auto tup = cuda::std::make_tuple(indices...); - auto stl = cuda::std::get(tup); - cuda::std::get(tup) = cuda::std::get(tup); - cuda::std::get(tup) = stl; - return conj(cuda::std::apply(op_, tup)); + cuda::std::array idx{indices...}; + cuda::std::swap(idx[Rank() - 2], idx[Rank() - 1]); + return conj(get_value(op_, idx)); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/interleaved.h b/include/matx/operators/interleaved.h index 921a3925..ce293615 100644 --- a/include/matx/operators/interleaved.h +++ b/include/matx/operators/interleaved.h @@ -58,26 +58,21 @@ namespace matx static_assert(!is_complex_v>, "Complex interleaved op only works on scalar input types"); static_assert(Rank() > 0); }; + template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ complex_type operator()(Is... indices) const { - auto real = op_(indices...); + auto real = get_value(op_, indices...); constexpr size_t rank_idx = (Rank() == 1) ? 0 : (Rank() - 2); - auto tup = cuda::std::make_tuple(indices...); - cuda::std::get(tup) += op_.Size(rank_idx) / 2; + cuda::std::array idx{indices...}; + idx[rank_idx] += op_.Size(rank_idx) / 2; - auto imag = cuda::std::apply(op_, tup); + auto imag = get_value(op_, idx); return complex_type{real, imag}; } - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) - { - return cuda::std::as_const(*this).template operator()(indices...); - } - static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() { return detail::get_rank(); diff --git a/include/matx/operators/isclose.h b/include/matx/operators/isclose.h index 2ed404e7..fc3f6917 100644 --- a/include/matx/operators/isclose.h +++ b/include/matx/operators/isclose.h @@ -64,8 +64,8 @@ namespace matx __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ int operator()([[maybe_unused]] Is... indices) const { - return static_cast(detail::_internal_abs(op1_(indices...) - op2_(indices...)) <= - static_cast(atol_) + static_cast(rtol_) * detail::_internal_abs(op2_(indices...))); + return static_cast(detail::_internal_abs(get_value(op1_, indices...) - get_value(op2_, indices...)) <= + static_cast(atol_) + static_cast(rtol_) * detail::_internal_abs(get_value(op2_, indices...))); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/kronecker.h b/include/matx/operators/kronecker.h index bb03c44d..58b1a9f1 100644 --- a/include/matx/operators/kronecker.h +++ b/include/matx/operators/kronecker.h @@ -57,33 +57,29 @@ namespace matx using matxop = bool; using value_type = typename T1::value_type; - __MATX_INLINE__ std::string str() const { return "kron(" + op1_.str() + "," + op2_.str() + ")"; } + __MATX_INLINE__ std::string str() const { return "kron(" + op1_.str() + "," + op2_.str() + ")"; } __MATX_INLINE__ KronOp(const T1 &op1, const T2 &op2) : op1_(op1), op2_(op2) - { - static_assert(RankGTE(Rank(), 2), "Kronecker product must be used on tensors with rank 2 or higher"); - } + { + static_assert(RankGTE(Rank(), 2), "Kronecker product must be used on tensors with rank 2 or higher"); + } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const { - auto tup1 = cuda::std::make_tuple(indices...); - auto tup2 = cuda::std::make_tuple(indices...); - cuda::std::get(tup2) = pp_get(indices...) % op2_.Size(Rank() - 2); - cuda::std::get(tup2) = pp_get(indices...) 
% op2_.Size(Rank() - 1); + cuda::std::array idx1{indices...}; + cuda::std::array idx2{indices...}; - cuda::std::get(tup1) = pp_get(indices...) / op2_.Size(Rank() - 2); - cuda::std::get(tup1) = pp_get(indices...) / op2_.Size(Rank() - 1); + idx2[Rank() - 2] = pp_get(indices...) % op2_.Size(Rank() - 2); + idx2[Rank() - 1] = pp_get(indices...) % op2_.Size(Rank() - 1); - return cuda::std::apply(op2_, tup2) * cuda::std::apply(op1_, tup1); - } + idx1[Rank() - 2] = pp_get(indices...) / op2_.Size(Rank() - 2); + idx1[Rank() - 1] = pp_get(indices...) / op2_.Size(Rank() - 1); - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) - { - return cuda::std::as_const(*this).template operator()(indices...); + return get_value(op2_, idx2) * get_value(op1_, idx1); } + template __MATX_INLINE__ void PreRun(ShapeType &&shape, Executor &&ex) const noexcept { diff --git a/include/matx/operators/legendre.h b/include/matx/operators/legendre.h index 73486c78..abfd614f 100644 --- a/include/matx/operators/legendre.h +++ b/include/matx/operators/legendre.h @@ -100,49 +100,49 @@ namespace matx } template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ value_type operator()(Is... indices) const - { - cuda::std::array inds{indices...}; - cuda::std::array xinds{}; - - int axis1 = axis_[0]; - int axis2 = axis_[1]; - - // compute n - index_t nind = inds[axis1]; - int n = get_value(n_, nind); - - // compute m - index_t mind = inds[axis2]; - int m = get_value(m_, mind); - - if(axis1>axis2) - cuda::std::swap(axis1, axis2); - - // compute indices for x - int idx = 0; - for(int i = 0; i < Rank(); i++) { - index_t ind = inds[i]; - if(i != axis1 && i != axis2) { - xinds[idx++] = ind; - } + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ value_type operator()(Is... 
indices) const + { + cuda::std::array inds{indices...}; + cuda::std::array xinds{}; + + int axis1 = axis_[0]; + int axis2 = axis_[1]; + + // compute n + index_t nind = inds[axis1]; + int n = get_value(n_, nind); + + // compute m + index_t mind = inds[axis2]; + int m = get_value(m_, mind); + + if(axis1>axis2) + cuda::std::swap(axis1, axis2); + + // compute indices for x + int idx = 0; + for(int i = 0; i < Rank(); i++) { + index_t ind = inds[i]; + if(i != axis1 && i != axis2) { + xinds[idx++] = ind; } + } - auto x = cuda::std::apply(in_, xinds); + auto x = get_value(in_, xinds); - value_type ret; + value_type ret; - // if we are half precision up cast to float - if constexpr (is_complex_half_v) { - ret = static_cast(legendre(n, m, cuda::std::complex(x))); - } else if constexpr (is_matx_half_v) { - ret = static_cast(legendre(n, m, float(x))); - } else { - ret = legendre(n, m, x); - } - - return ret; - } + // if we are half precision up cast to float + if constexpr (is_complex_half_v) { + ret = static_cast(legendre(n, m, cuda::std::complex(x))); + } else if constexpr (is_matx_half_v) { + ret = static_cast(legendre(n, m, float(x))); + } else { + ret = legendre(n, m, x); + } + + return ret; + } template __MATX_INLINE__ void PreRun(ShapeType &&shape, Executor &&ex) const noexcept diff --git a/include/matx/operators/overlap.h b/include/matx/operators/overlap.h index 5717db22..9face054 100644 --- a/include/matx/operators/overlap.h +++ b/include/matx/operators/overlap.h @@ -86,14 +86,20 @@ namespace matx s_[0] = stride_size; }; + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, index_t i0) + { + return get_value(cuda::std::forward(op), i0); + } + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(index_t i0, index_t i1) const { - return op_(i0*s_[0] + i1); + return get_impl(cuda::std::as_const(op_), i0*s_[0] + i1); } __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(index_t i0, index_t i1) { - return op_(i0*s_[0] + i1); + return get_impl(cuda::std::forward(op_), i0*s_[0] + i1); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/permute.h b/include/matx/operators/permute.h index bf366e24..cebb62dc 100644 --- a/include/matx/operators/permute.h +++ b/include/matx/operators/permute.h @@ -74,11 +74,10 @@ namespace matx dims_[i] = dims[i]; } - }; - + } - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, const Dims &dims, Is... indices) { static_assert(sizeof...(Is)==Rank()); static_assert((std::is_convertible_v && ... )); @@ -88,7 +87,6 @@ namespace matx IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized") cuda::std::array ind; IGNORE_WARNING_POP_GCC - //cuda::std::array ind{indices...}; #if 0 //This causes register spills but might be faster if Rank is large @@ -102,20 +100,25 @@ IGNORE_WARNING_POP_GCC for(int32_t i = 0; i < Rank(); i++) { #pragma unroll for(int32_t j = 0; j < Rank(); j++) { - if(dims_[j] == i) { + if(dims[j] == i) { ind[i] = inds[j]; } } - } -#endif + } +#endif + return get_value(cuda::std::forward(op), ind); + } - return cuda::std::apply(op_, ind); + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... 
indices) const + { + return get_impl(cuda::std::as_const(op_), dims_, indices...); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) { - return cuda::std::as_const(*this).template operator()(indices...); + return get_impl(cuda::std::forward(op_), dims_, indices...); } constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int32_t dim) const diff --git a/include/matx/operators/planar.h b/include/matx/operators/planar.h index 4bb16cd4..34dbd486 100644 --- a/include/matx/operators/planar.h +++ b/include/matx/operators/planar.h @@ -60,19 +60,14 @@ namespace matx __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const { constexpr size_t rank_idx = (Rank() == 1) ? 0 : (Rank() - 2); - auto tup = cuda::std::make_tuple(indices...); - if (cuda::std::get(tup) >= op_.Size(rank_idx)) { - cuda::std::get(tup) -= op_.Size(rank_idx); - return cuda::std::apply(op_, tup).imag(); - } + cuda::std::array idx{indices...}; - return op_(indices...).real(); - } + if (idx[rank_idx] >= op_.Size(rank_idx)) { + idx[rank_idx] -= op_.Size(rank_idx); + return get_value(op_, idx).imag(); + } - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) - { - return cuda::std::as_const(*this).template operator()(indices...); + return get_value(op_, indices...).real(); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/polyval.h b/include/matx/operators/polyval.h index ef10b806..4769ca65 100644 --- a/include/matx/operators/polyval.h +++ b/include/matx/operators/polyval.h @@ -65,7 +65,7 @@ namespace matx // Horner's method for computing polynomial value_type ttl{coeffs_(0)}; for(int i = 1; i < coeffs_.Size(0); i++) { - ttl = ttl * op_(idx) + coeffs_(i); + ttl = ttl * get_value(op_, idx) + coeffs_(i); } return ttl; diff --git a/include/matx/operators/r2c.h b/include/matx/operators/r2c.h index 59d268b6..dce5e339 100644 --- a/include/matx/operators/r2c.h +++ b/include/matx/operators/r2c.h @@ -56,27 +56,19 @@ namespace matx static_assert(Rank() >= 1, "R2COp must have a rank 1 operator or higher"); }; - // This version of the operator returns auto rather than decltype(auto) because we need to force the - // return type to be by value and not pass through references template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const - { - auto tup = cuda::std::make_tuple(indices...); + { + cuda::std::array idx{indices...}; // If we're on the upper part of the spectrum, return the conjugate of the first half - if (cuda::std::get(tup) >= op_.Size(Rank()-1)) { - cuda::std::get(tup) = orig_size_ - cuda::std::get(tup); - return conj(cuda::std::apply(op_, tup)); + if (idx[Rank() - 1] >= op_.Size(Rank()-1)) { + idx[Rank() - 1] = orig_size_ - idx[Rank() - 1]; + return conj(get_value(op_, idx)); } - return cuda::std::apply(op_, tup); - } - - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... 
indices) - { - return cuda::std::as_const(*this).template operator()(indices...); - } + return get_value(op_, idx); + } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() { diff --git a/include/matx/operators/remap.h b/include/matx/operators/remap.h index 5be669df..7b37a053 100644 --- a/include/matx/operators/remap.h +++ b/include/matx/operators/remap.h @@ -64,31 +64,36 @@ namespace matx __MATX_INLINE__ std::string str() const { return "remap(" + op_.str() + ")"; } - __MATX_INLINE__ RemapOp(const T &op, IdxType idx) : op_(op), idx_(idx) {}; + __MATX_INLINE__ RemapOp(const T &op, IdxType idx) : op_(op), idx_(idx) {}; - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, const Idx &idx, Is... indices) { - static_assert(sizeof...(Is)==Rank()); + static_assert(sizeof...(Is) == Rank()); static_assert((std::is_convertible_v && ... )); - // convert variadic type to tuple so we can read/update - cuda::std::array ind{indices...}; + cuda::std::array ind{indices...}; // remap current index for dim if constexpr (IdxType::Rank() == 0) { - ind[DIM] = idx_(); + ind[DIM] = idx(); } else { - ind[DIM] = idx_(ind[DIM]); + ind[DIM] = idx(ind[DIM]); } - //return op_(ind); - return cuda::std::apply(op_, ind); + + return get_value(cuda::std::forward(op), ind); + } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + { + return get_impl(cuda::std::as_const(op_), idx_, indices...); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) { - return cuda::std::as_const(*this).template operator()(indices...); + return get_impl(cuda::std::forward(op_), idx_, indices...); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/repmat.h b/include/matx/operators/repmat.h index 87366a93..84109f2d 100644 --- a/include/matx/operators/repmat.h +++ b/include/matx/operators/repmat.h @@ -59,58 +59,60 @@ namespace matx using matxop = bool; using value_type = typename T1::value_type; - __MATX_INLINE__ std::string str() const { return "repmat(" + op_.str() + ")"; } + __MATX_INLINE__ std::string str() const { return "repmat(" + op_.str() + ")"; } - __MATX_INLINE__ RepMatOp(const T1 &op, index_t reps) : op_(op) - { - for (int dim = 0; dim < DIM; dim++) + __MATX_INLINE__ RepMatOp(const T1 &op, index_t reps) : op_(op) { - reps_[dim] = reps; + for (int dim = 0; dim < DIM; dim++) + { + reps_[dim] = reps; + } } - } - __MATX_INLINE__ RepMatOp(const T1 &op, const cuda::std::array reps) : op_(op) - { - for (int dim = 0; dim < DIM; dim++) + __MATX_INLINE__ RepMatOp(const T1 &op, const cuda::std::array reps) : op_(op) { - reps_[dim] = reps[dim]; + for (int dim = 0; dim < DIM; dim++) + { + reps_[dim] = reps[dim]; + } } - } - __MATX_INLINE__ RepMatOp(const T1 &op, const index_t *reps) : op_(op) - { - for (int dim = 0; dim < DIM; dim++) + __MATX_INLINE__ RepMatOp(const T1 &op, const index_t *reps) : op_(op) { - reps_[dim] = reps[dim]; + for (int dim = 0; dim < DIM; dim++) + { + reps_[dim] = reps[dim]; + } } - } + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, Is... 
indices) + { + if constexpr (Rank() == 0) { + return op(); + } + else { + cuda::std::array idx{indices...}; - template - __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ void UpdateIndex(cuda::std::tuple &tup) const { - if constexpr (I != sizeof...(Is)) { - cuda::std::get(tup) %= op_.Size(I); - UpdateIndex(tup); + #pragma unroll + for (int i = 0; i < static_cast(idx.size()); i++) { + idx[i] %= op.Size(i); } - } + return get_value(cuda::std::forward(op), idx); + } + } + template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const { - if constexpr (Rank() == 0) { - return op_(); - } - else { - auto tup = cuda::std::make_tuple(indices...); - UpdateIndex(tup); - return cuda::std::apply(op_, tup); - } + return get_impl(cuda::std::as_const(op_), indices...); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) { - return cuda::std::as_const(*this).template operator()(indices...); + return get_impl(cuda::std::forward(op_), indices...); } template diff --git a/include/matx/operators/reshape.h b/include/matx/operators/reshape.h index 8df543f2..0ff2de6d 100644 --- a/include/matx/operators/reshape.h +++ b/include/matx/operators/reshape.h @@ -79,9 +79,9 @@ namespace matx MATX_ASSERT_STR(size == TotalSize(op_), matxInvalidSize, "ReshapeOp: TotalSize of reshape must match"); }; - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const - { + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, const decltype(sizes_) &sizes, Is... indices) + { cuda::std::array inds{indices...}; cuda::std::array ninds; @@ -92,23 +92,29 @@ namespace matx #pragma unroll for(int i = Rank() - 1 ; i >= 0 ; i--) { idx += stride * inds[i]; - stride *= Size(i); + stride *= sizes[i]; } // extract new indices #pragma unroll for(int i = T::Rank() - 1; i >= 0; i--) { - ninds[i] = idx % op_.Size(i); - idx /= op_.Size(i); - } + ninds[i] = idx % op.Size(i); + idx /= op.Size(i); + } - return cuda::std::apply(op_, ninds); + return get_value(cuda::std::forward(op), ninds); + } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + { + return get_impl(cuda::std::as_const(op_), sizes_, indices...); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) { - return cuda::std::as_const(*this).template operator()(indices...); + return get_impl(cuda::std::forward(op_), sizes_, indices...); } constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int32_t dim) const diff --git a/include/matx/operators/reverse.h b/include/matx/operators/reverse.h index 0964de47..cf7dd4fa 100644 --- a/include/matx/operators/reverse.h +++ b/include/matx/operators/reverse.h @@ -63,24 +63,29 @@ namespace matx __MATX_INLINE__ ReverseOp(const T1 &op) : op_(op){}; - - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, Is... 
indices) { if constexpr (Rank() == 0) { - return op_(); + return op(); } else { - auto tup = cuda::std::make_tuple(indices...); - cuda::std::get(tup) = Size(DIM) - cuda::std::get(tup) - 1; - return cuda::std::apply(op_, tup); - } + cuda::std::array idx{indices...}; + idx[DIM] = op.Size(DIM) - idx[DIM] - 1; + return get_value(cuda::std::forward(op), idx); + } + } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + { + return get_impl(cuda::std::as_const(op_), indices...); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) { - return cuda::std::as_const(*this).template operator()(indices...); + return get_impl(cuda::std::forward(op_), indices...); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/select.h b/include/matx/operators/select.h index 2ce10a1d..5b85ef05 100644 --- a/include/matx/operators/select.h +++ b/include/matx/operators/select.h @@ -59,17 +59,23 @@ namespace matx __MATX_INLINE__ SelectOp(const T &op, IdxType idx) : op_(op), idx_(idx) {}; + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, const Idx &idx, index_t i) + { + auto arrs = detail::GetIdxFromAbs(op, idx(i)); + return get_value(op, arrs); + } + template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(index_t i) const { - auto arrs = detail::GetIdxFromAbs(op_, idx_(i)); - return op_(arrs); + return get_impl(cuda::std::as_const(op_), idx_, i); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(index_t i) { - return cuda::std::as_const(*this).template operator()(i); + return get_impl(cuda::std::forward(op_), idx_, i); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/shift.h b/include/matx/operators/shift.h index a8ccaac5..4685e9e0 100644 --- a/include/matx/operators/shift.h +++ b/include/matx/operators/shift.h @@ -53,10 +53,6 @@ namespace matx template class ShiftOp : public BaseOp> { - private: - typename detail::base_type_t op_; - T2 shift_; - public: using matxop = bool; using matxoplvalue = bool; @@ -68,30 +64,49 @@ namespace matx __MATX_INLINE__ ShiftOp(const T1 &op, T2 shift) : op_(op), shift_(shift) { static_assert(DIM < Rank(), "Dimension to shift must be less than rank of tensor"); + + #pragma unroll + for (int i = 0; i < Rank(); i++) { + index_t size1 = detail::get_expanded_size(op_, i); + index_t size2 = detail::get_expanded_size(shift_, i); + sizes_[i] = detail::matx_max(size1,size2); + } + ASSERT_COMPATIBLE_OP_SIZES(shift_); - ASSERT_COMPATIBLE_OP_SIZES(op_); + ASSERT_COMPATIBLE_OP_SIZES(op_); } - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const - { - auto tup = cuda::std::make_tuple(indices...); - index_t shift = -get_value(shift_, indices...); - + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl( + Op&& op, + const Sizes &sizes, + ShiftType shiftin, + Is... 
indices) + { + cuda::std::array idx{indices...}; + index_t shift = -get_value(shiftin, indices...); - shift = (shift + cuda::std::get(tup)) % Size(DIM); + shift = (shift + idx[DIM]) % sizes[DIM]; - if(shift<0) shift += Size(DIM); + if (shift < 0) { + shift += sizes[DIM]; + } - cuda::std::get(tup) = shift; + idx[DIM] = shift; - return cuda::std::apply(op_, tup); + return get_value(cuda::std::forward(op), idx); + } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + { + return get_impl(cuda::std::as_const(op_), sizes_, shift_, indices...); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) { - return cuda::std::as_const(*this).template operator()(indices...); + return get_impl(cuda::std::forward(op_), sizes_, shift_, indices...); } template @@ -117,9 +132,7 @@ namespace matx constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto Size(int dim) const noexcept { - index_t size1 = detail::get_expanded_size(op_, dim); - index_t size2 = detail::get_expanded_size(shift_, dim); - return detail::matx_max(size1,size2); + return sizes_[dim]; } ~ShiftOp() = default; @@ -137,6 +150,11 @@ namespace matx return set(*this, rhs); } } + + private: + typename detail::base_type_t op_; + cuda::std::array sizes_; + typename detail::base_type_t shift_; }; } /** diff --git a/include/matx/operators/slice.h b/include/matx/operators/slice.h index 4453608c..9c4402a5 100644 --- a/include/matx/operators/slice.h +++ b/include/matx/operators/slice.h @@ -109,38 +109,49 @@ namespace matx MATX_ASSERT_STR(d==Rank(), matxInvalidDim, "SliceOp: Number of dimensions without matxDropDim must equal new rank."); }; - template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const - { + template + static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl( + Op&& op, + const decltype(starts_) &starts, + const decltype(strides_) &strides, + const decltype(dims_) &dims, + Is... indices) + { static_assert(sizeof...(Is)==Rank()); static_assert((std::is_convertible_v && ... )); // convert variadic type to tuple so we can read/update - cuda::std::array ind = starts_; + cuda::std::array ind = starts; cuda::std::array inds{indices...}; #pragma unroll for (int32_t i = 0; i < T::Rank(); i++) { #pragma unroll for(int32_t j = 0; j < Rank(); j++) { - if(dims_[j] == i) { + if(dims[j] == i) { if constexpr (!std::is_same_v) { - ind[i] = starts_[j] + inds[j] * strides_[i]; + ind[i] = starts[j] + inds[j] * strides[i]; } else { - ind[i] = starts_[j] + inds[j]; + ind[i] = starts[j] + inds[j]; } } } } - return cuda::std::apply(op_, ind); + return get_value(cuda::std::forward(op), ind); + } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const + { + return get_impl(cuda::std::as_const(op_), starts_, strides_, dims_, indices...); } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... 
indices) { - return cuda::std::as_const(*this).template operator()(indices...); + return get_impl(cuda::std::forward(op_), starts_, strides_, dims_, indices...); } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/stack.h b/include/matx/operators/stack.h index 2d723975..3353aeb3 100644 --- a/include/matx/operators/stack.h +++ b/include/matx/operators/stack.h @@ -86,42 +86,42 @@ namespace matx } template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto GetVal(index_t oidx, cuda::std::array &indices) const { + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto GetVal(index_t oidx, cuda::std::array &indices) const { - if constexpr ( I == N ) { - // This should never happen - return value_type(-9999); + if constexpr ( I == N ) { + // This should never happen + return value_type(-9999); + } else { + if ( I < oidx ) { + // this is not the correct operator, recurse + return GetVal(oidx, indices); } else { - if ( I < oidx ) { - // this is not the correct operator, recurse - return GetVal(oidx, indices); - } else { - // this is the correct operator, return it's value - auto &op = cuda::std::get(ops_); - return cuda::std::apply(op, indices); - } + // this is the correct operator, return it's value + auto &op = cuda::std::get(ops_); + return get_value(op, indices); } } + } template - __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& GetVal(index_t oidx, cuda::std::array &indices) { + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& GetVal(index_t oidx, cuda::std::array &indices) { - if constexpr ( I == N ) { - // This should never happen - auto &op = cuda::std::get(ops_); - return cuda::std::apply(op, indices); + if constexpr ( I == N ) { + // This should never happen + auto &op = cuda::std::get(ops_); + return get_value(op, indices); + } else { + if ( I < oidx ) { + // this is not the correct operator, recurse + return GetVal(oidx, indices); } else { - if ( I < oidx ) { - // this is not the correct operator, recurse - return GetVal(oidx, indices); - } else { - // this is the correct operator, return it's value - auto &op = cuda::std::get(ops_); - return cuda::std::apply(op, indices); - } + // this is the correct operator, return it's value + auto &op = cuda::std::get(ops_); + return get_value(op, indices); } } + } template __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... is) const diff --git a/include/matx/operators/toeplitz.h b/include/matx/operators/toeplitz.h index 5527be87..342e75f2 100644 --- a/include/matx/operators/toeplitz.h +++ b/include/matx/operators/toeplitz.h @@ -94,7 +94,7 @@ namespace matx { if (j > i) { if constexpr (is_matx_op()) { - auto val = op2_(j - i); + auto val = get_value(op2_, j - i); return val; } else { @@ -104,7 +104,7 @@ namespace matx } else { if constexpr (is_matx_op()) { - auto val = op1_(i - j); + auto val = get_value(op1_, i - j); return val; } else { diff --git a/include/matx/operators/updownsample.h b/include/matx/operators/updownsample.h index 4194a9e0..26d041b4 100644 --- a/include/matx/operators/updownsample.h +++ b/include/matx/operators/updownsample.h @@ -78,7 +78,7 @@ namespace matx cuda::std::array ind{indices...}; if ((ind[dim_] % n_) == 0) { ind[dim_] /= n_; - return cuda::std::apply(op_, ind); + return get_value(op_, ind); } return static_cast(0);
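
Note on the indexing rule introduced above: get_matx_value()/get_value() now accept
more indices than the rank of the operator being read. When an operator of rank R is
indexed with N > R indices, the leading N - R (broadcast) indices are dropped and only
the trailing R are forwarded to operator(), either through the offset_sequence_t /
select_tuple path for variadic indices or by copying the tail of the index array in the
array-based overload. What follows is a minimal, host-only sketch of that rule in
standard C++ (no MatX or CUDA headers; the names ToyOp2D and get_value_sketch are
illustrative only and not part of MatX):

#include <algorithm>
#include <array>
#include <cstdio>
#include <tuple>
#include <type_traits>
#include <utility>

// A toy rank-2 "operator": its value depends only on (row, col).
struct ToyOp2D {
  static constexpr int Rank() { return 2; }
  int operator()(int row, int col) const { return 10 * row + col; }
};

// Forward only the last Rank() indices of idx to op, mirroring the
// array-based get_matx_value() overload added in this patch.
template <typename Op, std::size_t N>
decltype(auto) get_value_sketch(Op &&op, const std::array<int, N> &idx) {
  constexpr int RANK = std::remove_reference_t<Op>::Rank();
  static_assert(static_cast<int>(N) >= RANK, "not enough indices for this operator");
  std::array<int, RANK> nbc_idx{};  // non-broadcast (trailing) indices
  std::copy(idx.begin() + (N - RANK), idx.end(), nbc_idx.begin());
  return std::apply([&](auto... args) -> decltype(auto) {
    return std::forward<Op>(op)(args...);
  }, nbc_idx);
}

int main() {
  ToyOp2D op;
  // Index the rank-2 operator as if it were rank-4: the two leading batch
  // indices (7 and 3) are dropped, so this evaluates op(1, 2).
  const std::array<int, 4> idx{7, 3, 1, 2};
  std::printf("%d\n", get_value_sketch(op, idx));  // prints 12
  return 0;
}

Passing the indices as an array rather than a parameter pack is what lets operators such
as clone, slice, remap, and reshape build their remapped index once and then read the
wrapped operator through get_value(), which is why the patch replaces the direct
cuda::std::apply(op_, ind) calls with get_value(op_, ind) throughout.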