Skip to content

Commit

Permalink
Fixing broadcasting in all operator() (#795)
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffburdick authored Nov 23, 2024
1 parent c33d749 commit 3368402
Show file tree
Hide file tree
Showing 29 changed files with 427 additions and 320 deletions.
60 changes: 50 additions & 10 deletions include/matx/core/tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#pragma once

#include <cuda/std/tuple>
#include <cuda/std/__algorithm/copy.h>
#include <functional>

#include "matx/core/nvtx.h"
Expand Down Expand Up @@ -245,38 +246,77 @@ namespace matx
* @param indices indices
* @return Value after broadcasting
*/
template <class T, typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto get_matx_value(const T &i, Is... indices)
template <typename T, typename... Is, std::enable_if_t<std::conjunction_v<std::is_integral<Is>...>, bool> = true>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_matx_value(T &&i, Is... indices)
{
if constexpr (T::Rank() == int(sizeof...(Is)) || T::Rank() == matxNoRank) {
return i(indices...);
constexpr int RANK = remove_cvref_t<T>::Rank();
if constexpr (RANK == int(sizeof...(Is)) || RANK == matxNoRank) {
// If we're only indexing with the same number of arguments as the rank of the operator, just return operator()
return cuda::std::forward<T>(i)(indices...);
}
else
{
// Construct an integer sequence of the length of the tuple, but only using the last indices
using seq = offset_sequence_t<sizeof...(Is) - T::Rank(), std::make_index_sequence<T::Rank()>>;
// Otherwise we need to broadcast by constructing a large set of indices
// Construct an integer sequence of the length of the tuple, but only using the last indices. We construct an offset sequence
// to index into the broadcasted dimensions. For example, if T is a 3D tensor and we want to index as a 5D, we take the indices
// {0, 1, 2} we'd normally index with, and add the difference in rank (2), to get {2, 3, 4}. Another way to think of this is it
// simply chops off the first sizeof...(Is) - RANK indices since they're not used for operator().
using seq = offset_sequence_t<sizeof...(Is) - RANK, std::make_index_sequence<RANK>>;
auto tup = cuda::std::make_tuple(indices...);
auto sliced_tup = select_tuple(std::forward<decltype(tup)>(tup), seq{});
return cuda::std::apply([&](auto... args) {
return i(args...);
return cuda::std::forward<T>(i)(args...);
}, sliced_tup);
}
}

/**
 * Array-index overload of get_matx_value: evaluate a MatX operator at indices
 * held in a cuda::std::array, broadcasting when N exceeds the operator's rank.
 *
 * @param i   Operator expression (perfect-forwarded)
 * @param idx N indices; when N > Rank() the leading N - Rank() entries are
 *            broadcast dimensions and are dropped
 * @return Result of invoking i's operator() at the trailing Rank() indices
 */
template <typename T, typename IdxType, size_t N>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_matx_value(T &&i, const cuda::std::array<IdxType, N> idx)
{
  constexpr int RANK = remove_cvref_t<T>::Rank();
  // Cast N to int to avoid a signed/unsigned comparison with RANK
  if constexpr (RANK == static_cast<int>(N) || RANK == matxNoRank) {
    // Index count matches the operator's rank (or rank is dynamic): expand
    // the array directly into operator()
    return cuda::std::apply(cuda::std::forward<T>(i), idx);
  }
  else {
    // Broadcasting: copy only the trailing RANK indices, dropping the leading
    // broadcast dimensions
    cuda::std::array<index_t, RANK> nbc_idx; // non-broadcast indices
    cuda::std::copy(idx.begin() + (N - RANK), idx.end(), nbc_idx.begin());
    return cuda::std::apply([&](auto... args) {
      return cuda::std::forward<T>(i)(args...);
    }, nbc_idx);
  }
}


template <class T, typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto get_value(const T &i, Is... indices)
template <typename T, typename... Is, std::enable_if_t<std::conjunction_v<std::is_integral<Is>...>, bool> = true>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_value(T &&i, Is... indices)
{
if constexpr (is_matx_op<T>())
{
return get_matx_value(i, indices...);
return get_matx_value(cuda::std::forward<T>(i), indices...);
}
else
{
return i;
}
}


/**
 * Array-index overload of get_value: evaluate a value at indices held in a
 * cuda::std::array. MatX operators are indexed (with broadcasting) through
 * get_matx_value(); any other type is returned unchanged.
 */
template <typename T, typename IdxType, size_t N>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_value(T &&i, const cuda::std::array<IdxType, N> idx)
{
  if constexpr (!is_matx_op<T>())
  {
    // Non-operator values (scalars) ignore the indices entirely
    return i;
  }
  else
  {
    // Forward so the operator's const/non-const operator() is chosen correctly
    return get_matx_value(cuda::std::forward<T>(i), idx);
  }
}

template <typename T> __MATX_INLINE__ std::string to_short_str() {
if constexpr (!is_complex_v<T>) {
if constexpr (std::is_same_v<T, bool>)
Expand Down
2 changes: 1 addition & 1 deletion include/matx/operators/at.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace matx
// Evaluate the wrapped operator at the stored idx_ array, ignoring any
// indices supplied by the caller. Routed through get_value() so operators
// whose rank differs from the index count are broadcast correctly.
template <typename... Is2>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()([[maybe_unused]] Is2... indices) const
{
  return get_value(op_, idx_);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
Expand Down
4 changes: 2 additions & 2 deletions include/matx/operators/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ namespace matx
// Const indexing: evaluate the wrapped operator (via get_value() for
// broadcasting support) and convert the result to NewType
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const
{
  return static_cast<NewType>(get_value(op_, indices...));
}

// Non-const indexing: evaluate the wrapped operator (via get_value() for
// broadcasting support) and convert the result to NewType
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices)
{
  return static_cast<NewType>(get_value(op_, indices...));
}

template <typename ShapeType, typename Executor>
Expand Down
20 changes: 12 additions & 8 deletions include/matx/operators/clone.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,33 @@ IGNORE_WARNING_POP_GCC

}

// Shared implementation for both const and non-const operator(). Maps the
// clone's Rank() incoming indices onto the underlying operator's T::Rank()
// indices using the dims mapping, then evaluates the operator. Static so the
// same code serves both const and mutable access; Op's value category is
// forwarded through to get_value().
template <typename Op, typename Dims, typename... Is>
static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, const Dims &dims, Is... indices)
{
  IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
  cuda::std::array<index_t, Rank()> sind{indices...};
  cuda::std::array<index_t, T::Rank()> gind;
  IGNORE_WARNING_POP_GCC

  // Gather indices: dims[i] selects which incoming index feeds dimension i of
  // the underlying operator
  for(int i = 0; i < T::Rank(); i++) {
    auto idx = dims[i];
    gind[i] = sind[idx];
  }

  return get_value(cuda::std::forward<Op>(op), gind);
}

// Const indexing: delegate to get_impl() with a const reference to op_ so the
// underlying operator's const operator() overload is selected
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const
{
return get_impl(cuda::std::as_const(op_), dims_, indices...);
}

// Non-const indexing: forward op_ with its native value category so writable
// (lvalue) access is preserved through get_impl()
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices)
{
  return get_impl(cuda::std::forward<decltype(op_)>(op_), dims_, indices...);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
Expand Down
40 changes: 26 additions & 14 deletions include/matx/operators/collapse.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ namespace matx
}
}

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const
template <typename Op, typename... Is>
static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, Is... indices)
{
// indices coming in
cuda::std::array<index_t, Rank()> in{indices...}; // index coming in
Expand All @@ -88,17 +88,23 @@ namespace matx
#pragma unroll
for(int i = 0; i < DIM; i++) {
int d = DIM - i - 1;
out[d] = ind % op_.Size(d);
ind /= op_.Size(d);
out[d] = ind % op.Size(d);
ind /= op.Size(d);
}

return cuda::std::apply(op_, out);
return get_value(cuda::std::forward<Op>(op), out);
}

// Const indexing: delegate to get_impl() with a const reference to op_ so the
// underlying operator's const operator() overload is selected
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const
{
return get_impl(cuda::std::as_const(op_), indices...);
}

// Non-const indexing: forward op_ with its native value category so writable
// (lvalue) access is preserved through get_impl()
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices)
{
  return get_impl(cuda::std::forward<decltype(op_)>(op_), indices...);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
Expand Down Expand Up @@ -207,9 +213,9 @@ namespace matx
}
}

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const
{
template <typename Op, typename... Is>
static __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) get_impl(Op&& op, Is... indices)
{
// indices coming in
cuda::std::array<index_t, Rank()> in{indices...}; // index coming in
cuda::std::array<index_t, T1::Rank()> out; // index going out
Expand All @@ -225,18 +231,24 @@ namespace matx
#pragma unroll
for(int i = 0; i < DIM; i++) {
int d = T1::Rank() - 1 - i;
out[d] = ind % op_.Size(d);
ind /= op_.Size(d);
out[d] = ind % op.Size(d);
ind /= op.Size(d);
}

return cuda::std::apply(op_, out);
return get_value(cuda::std::forward<Op>(op), out);
}

// Const indexing: delegate to get_impl() with a const reference to op_ so the
// underlying operator's const operator() overload is selected
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const
{
return get_impl(cuda::std::as_const(op_), indices...);
}

// Non-const indexing: forward op_ with its native value category so writable
// (lvalue) access is preserved through get_impl()
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices)
{
  return get_impl(cuda::std::forward<decltype(op_)>(op_), indices...);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
Expand Down
8 changes: 4 additions & 4 deletions include/matx/operators/comma.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ namespace matx
__MATX_INLINE__ std::string str() const { return op1_.str() + ", " + op2_.str(); }

// Comma-operator semantics: evaluate the left operand for its side effects,
// then return the right operand's value. Both evaluations go through
// get_value() so rank-mismatched (broadcast) operators are handled.
template <typename... Is>
auto __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ operator()(Is... indices) const {
  get_value(op1_, indices...);
  return get_value(op2_, indices...);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() noexcept
{
Expand Down
75 changes: 39 additions & 36 deletions include/matx/operators/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,52 +91,55 @@ namespace matx
}
}




// Const evaluation of the concatenated operators: walk the operator tuple at
// compile time, subtracting each operator's extent along axis_ from the
// incoming index until the index falls within the current operator's range,
// then evaluate that operator via get_value() (broadcast-aware).
template <int I = 0, int N>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto GetVal(cuda::std::array<index_t,RANK> &indices) const {
  if constexpr ( I == N ) {
    // This should never happen: the axis_ index was bounds-checked upstream.
    // Return a default value to terminate the compile-time recursion.
    return value_type{};
  } else {
    const auto &op = cuda::std::get<I>(ops_);
    auto idx = indices[axis_];
    auto size = op.Size(axis_);
    // If in range of this operator
    if(idx < size) {
      // evaluate operator
      return get_value(cuda::std::forward<decltype(op)>(op), indices);
    } else {
      // otherwise remove this operator and recurse
      indices[axis_] -= size;
      return GetVal<I+1, N>(indices);
    }
  }
}

// Non-const evaluation of the concatenated operators. Same compile-time walk
// as the const overload, but decltype(auto) preserves lvalue results so the
// selected element can be assigned through.
template <int I = 0, int N>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) GetVal(cuda::std::array<index_t,RANK> &indices) {
  if constexpr ( I == N ) {
    // This should never happen: the axis_ index was bounds-checked upstream.
    // Return the last operator's element to satisfy lvalue requirements.
    auto &op = cuda::std::get<I-1>(ops_);
    return get_value(cuda::std::forward<decltype(op)>(op), indices);
  } else {
    auto &op = cuda::std::get<I>(ops_);
    auto idx = indices[axis_];
    auto size = op.Size(axis_);
    // If in range of this operator
    if(idx < size) {
      // evaluate operator
      return get_value(cuda::std::forward<decltype(op)>(op), indices);
    } else {
      // otherwise remove this operator and recurse
      indices[axis_] -= size;
      return GetVal<I+1, N>(indices);
    }
  }
}

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... is) const
Expand Down
18 changes: 11 additions & 7 deletions include/matx/operators/diag.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,23 @@ namespace matx

// Offset either the rows or columns by k_, depending on if it's negative
if (k_ < 0) {
auto tup = cuda::std::make_tuple(indices..., static_cast<tt>(0));
cuda::std::get<RANK - 1>(tup) = pp_get<RANK-2>(indices...) ;
cuda::std::array<tt, sizeof...(Is) + 1> tmp{indices...};
tmp[RANK - 1] = pp_get<RANK-2>(indices...);
//cuda::std::get<RANK - 1>(tup) = pp_get<RANK-2>(indices...) ;
IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
cuda::std::get<RANK - 2>(tup) = cuda::std::get<RANK - 2>(tup) - k_;
tmp[RANK - 2] -= k_;
//cuda::std::get<RANK - 2>(tup) = cuda::std::get<RANK - 2>(tup) - k_;
IGNORE_WARNING_POP_GCC
return cuda::std::apply(op_, tup);
return get_value(op_, tmp);
}
else {
auto tup = cuda::std::make_tuple(indices..., static_cast<tt>(0));
cuda::std::array<tt, sizeof...(Is) + 1> tmp{indices...};
//auto tup = cuda::std::make_tuple(indices..., static_cast<tt>(0));
IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized")
cuda::std::get<RANK - 1>(tup) = pp_get<RANK-2>(indices...) + k_;
tmp[RANK - 1] = pp_get<RANK-2>(indices...) + k_;
//cuda::std::get<RANK - 1>(tup) = pp_get<RANK-2>(indices...) + k_;
IGNORE_WARNING_POP_GCC
return cuda::std::apply(op_, tup);
return get_value(op_, tmp);
}
}
}
Expand Down
Loading

0 comments on commit 3368402

Please sign in to comment.