diff --git a/clang/runtime/dpct-rt/include/dpct/dpl_extras/functional.h b/clang/runtime/dpct-rt/include/dpct/dpl_extras/functional.h index bab82814c210..72799cc2bdad 100644 --- a/clang/runtime/dpct-rt/include/dpct/dpl_extras/functional.h +++ b/clang/runtime/dpct-rt/include/dpct/dpl_extras/functional.h @@ -58,6 +58,40 @@ template struct mark_functor_const { } }; +// Forward declare key_value_pair to avoid creating cyclic dependency between +// iterators.h and functional.h. +template class key_value_pair; + +// Returns the smaller of two key_value_pair objects based on their value +// member. If value elements compare equal, then the pair with the lower key is +// returned. +struct argmin { + template + key_value_pair<_KeyTp, _ValueTp> + operator()(const key_value_pair<_KeyTp, _ValueTp> &lhs, + const key_value_pair<_KeyTp, _ValueTp> &rhs) const { + return (lhs.value < rhs.value) || + (lhs.value == rhs.value && lhs.key < rhs.key) + ? lhs + : rhs; + } +}; + +// Returns the larger of two key_value_pair objects based on their value member. +// If value elements compare equal, then the pair with the lower key is +// returned. +struct argmax { + template + key_value_pair<_KeyTp, _ValueTp> + operator()(const key_value_pair<_KeyTp, _ValueTp> &lhs, + const key_value_pair<_KeyTp, _ValueTp> &rhs) const { + return (lhs.value > rhs.value) || + (lhs.value == rhs.value && lhs.key < rhs.key) + ? lhs + : rhs; + } +}; + namespace internal { template