Skip to content

Commit

Permalink
Refactor cub argmax to generic cub reduce, use for argmin. Fixes #774.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmartin-gh authored Oct 22, 2024
1 parent 0422b26 commit 57ef8f0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 28 deletions.
58 changes: 35 additions & 23 deletions include/matx/transforms/cub.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ typedef enum {
CUB_OP_SELECT,
CUB_OP_SELECT_IDX,
CUB_OP_UNIQUE,
CUB_OP_REDUCE_ARGMAX
CUB_OP_ARG_REDUCE,
} CUBOperation_t;

struct CubParams_t {
Expand Down Expand Up @@ -1103,6 +1103,14 @@ struct CustomArgMaxCmp
}
};

// Comparator for an arg-min reduction over zipped (index, value) tuples.
// Element 0 of each tuple is the element's linear index and element 1 is its
// value (matching the counting-iterator/value-iterator zip built in
// ExecArgReduce), so the comparator orders by value only and the winning
// index rides along. Returns whichever operand holds the smaller value.
// NOTE(review): on a value tie (get<1>(a) == get<1>(b)) this returns b, so
// which index wins depends on the reduction's association order — confirm
// this matches the tie-breaking convention of CustomArgMaxCmp and the
// documented argmin semantics.
struct CustomArgMinCmp
{
template <typename T>
// Host/device callable so the same functor works in CUB device reductions
// and any host-side fallback paths.
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ T operator()(const T &a, const T &b) const {
return thrust::get<1>(a) >= thrust::get<1>(b) ? b : a;
}
};

template <typename OutputTensor, typename TensorIndexType, typename InputOperator, typename CParams = EmptyParams_t>
class matxCubTwoOutputPlan_t {
using T1 = typename InputOperator::value_type;
Expand All @@ -1114,8 +1122,8 @@ class matxCubTwoOutputPlan_t {
#ifdef __CUDACC__
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

if (op == CUB_OP_REDUCE_ARGMAX) {
ExecArgMax(a_out, aidx_out, a, stream);
if (op == CUB_OP_ARG_REDUCE) {
ExecArgReduce(a_out, aidx_out, a, stream);
}
else {
MATX_THROW(matxNotSupported, "Invalid CUB operation");
Expand Down Expand Up @@ -1155,7 +1163,7 @@ class matxCubTwoOutputPlan_t {
}

/**
* Execute an argmax on a tensor
* Execute an arg reduce on a tensor
*
* @note Views being passed must be in row-major order
*
Expand All @@ -1175,16 +1183,14 @@ class matxCubTwoOutputPlan_t {
* CUDA stream
*
*/
inline void ExecArgMax(OutputTensor &a_out,
TensorIndexType &aidx_out,
const InputOperator &a,
const cudaStream_t stream)
inline void ExecArgReduce(OutputTensor &a_out,
TensorIndexType &aidx_out,
const InputOperator &a,
const cudaStream_t stream)
{
#ifdef __CUDACC__
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

CustomArgMaxCmp cmp_op;
const auto initial_value = cuda::std::make_tuple(static_cast<matx::index_t>(-1), std::numeric_limits<typename InputOperator::value_type>::lowest());
const auto a_iter = matx::RandomOperatorThrustIterator{a};
const auto zipped_input = thrust::make_zip_iterator(thrust::make_counting_iterator<matx::index_t>(0), a_iter);
const auto zipped_output = thrust::make_zip_iterator(aidx_out.Data(), a_out.Data());
Expand All @@ -1206,8 +1212,8 @@ class matxCubTwoOutputPlan_t {
BATCHES,
r0_iter,
r1_iter,
cmp_op,
initial_value,
cparams_.reduce_op,
cparams_.init,
stream);
}
else {
Expand All @@ -1219,8 +1225,8 @@ class matxCubTwoOutputPlan_t {
zipped_input,
zipped_output,
N,
cmp_op,
initial_value,
cparams_.reduce_op,
cparams_.init,
stream);
}
#endif
Expand Down Expand Up @@ -1539,51 +1545,57 @@ void cub_max(OutputTensor &a_out, const InputOperator &a,
}

/**
* Find argmax of an operator using CUB
* Find index and value of custom reduction of an operator using CUB
*
* @tparam OutputTensor
* Output tensor type
* @tparam TensorIndexType
* Output tensor index type
* @tparam InputOperator
* Input operator type
* @tparam CParams
* Custom reduction parameters type
* @param a_out
* Output maximum value tensor
* @param aidx_out
* Output maximum value index tensor
* @param a
* Input tensor
* @param reduce_params
* Reduction configuration parameters
* @param stream
* CUDA stream
*/
template <typename OutputTensor, typename TensorIndexType, typename InputOperator>
void cub_argmax(OutputTensor &a_out, TensorIndexType &aidx_out, const InputOperator &a,
template <typename OutputTensor, typename TensorIndexType, typename InputOperator, typename CParams>
void cub_argreduce(OutputTensor &a_out, TensorIndexType &aidx_out, const InputOperator &a, const CParams& reduce_params,
const cudaStream_t stream = 0)
{
#ifdef __CUDACC__
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
using cache_val_type = detail::matxCubTwoOutputPlan_t<OutputTensor, TensorIndexType, InputOperator>;

using cache_val_type = detail::matxCubTwoOutputPlan_t<OutputTensor, TensorIndexType, InputOperator, CParams>;

#ifndef MATX_DISABLE_CUB_CACHE
auto params = cache_val_type::GetCubParams(a_out, aidx_out, a, detail::CUB_OP_REDUCE_ARGMAX, stream);
auto params = cache_val_type::GetCubParams(a_out, aidx_out, a, detail::CUB_OP_ARG_REDUCE, stream);

detail::GetCache().LookupAndExec<detail::cub_cache_t>(
detail::GetCacheIdFromType<detail::cub_cache_t>(),
params,
[&]() {
return std::make_shared<cache_val_type>(a_out, aidx_out, a, detail::EmptyParams_t{}, stream);
return std::make_shared<cache_val_type>(a_out, aidx_out, a, reduce_params, stream);
},
[&](std::shared_ptr<cache_val_type> ctype) {
ctype->ExecArgMax(a_out, aidx_out, a, stream);
ctype->ExecArgReduce(a_out, aidx_out, a, stream);
}
);
#else
auto tmp = cache_val_type{a_out, aidx_out, a, detail::CUB_OP_REDUCE_ARGMAX, {}, stream};
tmp.ExecArgMax(a_out, aidx_out, a, stream);
auto tmp = cache_val_type{a_out, aidx_out, a, detail::CUB_OP_ARG_REDUCE, reduce_params, stream};
tmp.ExecArgReduce(a_out, aidx_out, a, stream);
#endif
#endif
}


/**
* Sort rows of a tensor
*
Expand Down
12 changes: 10 additions & 2 deletions include/matx/transforms/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -2081,8 +2081,12 @@ void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InT
#ifdef __CUDACC__
MATX_NVTX_START("argmax_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API)

const auto initial_value = cuda::std::make_tuple(static_cast<matx::index_t>(-1), std::numeric_limits<typename InType::value_type>::lowest());
using reduce_param_type = typename detail::ReduceParams_t<typename detail::CustomArgMaxCmp, decltype(initial_value)>;
auto reduce_params = reduce_param_type{detail::CustomArgMaxCmp{}, initial_value};

cudaStream_t stream = exec.getStream();
cub_argmax(dest, idest, in, stream);
cub_argreduce(dest, idest, in, reduce_params, stream);
#endif
}

Expand Down Expand Up @@ -2228,8 +2232,12 @@ void __MATX_INLINE__ argmin_impl(OutType dest, TensorIndexType &idest, const InT
#ifdef __CUDACC__
MATX_NVTX_START("argmin_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API)

const auto initial_value = cuda::std::make_tuple(static_cast<matx::index_t>(-1), std::numeric_limits<typename InType::value_type>::max());
using reduce_param_type = typename detail::ReduceParams_t<typename detail::CustomArgMinCmp, decltype(initial_value)>;
auto reduce_params = reduce_param_type{detail::CustomArgMinCmp{}, initial_value};

cudaStream_t stream = exec.getStream();
reduce(dest, idest, in, detail::reduceOpMin<typename OutType::value_type>(), stream, true);
cub_argreduce(dest, idest, in, reduce_params, stream);
#endif
}

Expand Down
5 changes: 2 additions & 3 deletions test/00_operators/ReductionTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,7 @@ TYPED_TEST(ReductionTestsFloatNonComplexNonHalfAllExecs, ArgMin)
EXPECT_TRUE(MatXUtils::MatXTypeCompare(t2o(rel), (TestType)(1)));
}

if (0) // disable for now since it doesn't pass
if (std::is_same_v<ExecType, matx::cudaExecutor>)
{
ExecType exec{};
const int BATCHES = 6;
Expand All @@ -1037,13 +1037,12 @@ TYPED_TEST(ReductionTestsFloatNonComplexNonHalfAllExecs, ArgMin)
expected_abs[n] += n*BATCH_STRIDE;
}

(matx::mtie(t_b, t_bi) = matx::argmax(t_a, {1,2})).run(exec);
(matx::mtie(t_b, t_bi) = matx::argmin(t_a, {1,2})).run(exec);
exec.sync();

for (int n=0; n<BATCHES; n++)
{
EXPECT_TRUE(t_bi(n) == expected_abs[n]);
std::cout << "[" << n << "] " << t_bi(n) << " ?= " << expected_abs[n] << std::endl;
}
}

Expand Down

0 comments on commit 57ef8f0

Please sign in to comment.