diff --git a/include/string_view b/include/string_view index 5c42b36ca..91fc8635e 100644 --- a/include/string_view +++ b/include/string_view @@ -758,6 +758,24 @@ hash >::operator()( return __do_string_hash(__val.data(), __val.data() + __val.size()); } +template< class _Tp, size_t _Np, class = typename enable_if< + is_character::type>::value>::type > +struct __char_array_hash +{ + using _CharT = typename remove_cv<_Tp>::type; + + size_t operator ()(const _CharT (&__chars)[_Np]) const _NOEXCEPT { + return hash>{}(__chars); + } +}; + +template< class _Tp, size_t _Np > +struct hash<_Tp[_Np]>: __char_array_hash<_Tp, _Np> +{}; + +template< class _Tp, size_t _Np > +struct hash<_Tp(&)[_Np]>: __char_array_hash<_Tp, _Np> +{}; #if _LIBCPP_STD_VER > 11 inline namespace literals diff --git a/include/type_traits b/include/type_traits index 277a9fb07..0d3086306 100644 --- a/include/type_traits +++ b/include/type_traits @@ -394,7 +394,7 @@ _LIBCPP_BEGIN_NAMESPACE_STD template struct _LIBCPP_TEMPLATE_VIS pair; template class _LIBCPP_TEMPLATE_VIS reference_wrapper; -template struct _LIBCPP_TEMPLATE_VIS hash; +template struct _LIBCPP_TEMPLATE_VIS hash; template struct __void_t { typedef void type; }; @@ -677,6 +677,19 @@ template _LIBCPP_CONSTEXPR bool is_null_pointer_v #endif #endif +// is_character + +template struct is_character : public false_type {}; +template<> struct is_character : public true_type {}; +template<> struct is_character : public true_type {}; +template<> struct is_character : public true_type {}; +template<> struct is_character : public true_type {}; + +#if _LIBCPP_STD_VER > 14 && !defined(_LIBCPP_HAS_NO_VARIABLE_TEMPLATES) +template +_LIBCPP_CONSTEXPR bool is_character_v = is_character<_Tp>::value; +#endif + // is_integral template struct __libcpp_is_integral : public false_type {}; diff --git a/include/unordered_set b/include/unordered_set index fc53c8271..9d02f6c0e 100644 --- a/include/unordered_set +++ b/include/unordered_set @@ -331,7 +331,8 @@ template _LIBCPP_BEGIN_NAMESPACE_STD -template , class _Pred = equal_to<_Value>, +template , + class _Pred = equal_to<_Value>, class _Alloc = allocator<_Value> > class _LIBCPP_TEMPLATE_VIS unordered_set { @@ -351,6 +352,11 @@ private: typedef __hash_table __table; __table __table_; + + template + using __enable_if_transparent = typename enable_if< + __is_transparent<_Hash, _K2>::value && + __is_transparent<_Pred, _K2>::value, _Rt>::type; public: typedef typename __table::pointer pointer; @@ -563,14 +569,41 @@ public: iterator find(const key_type& __k) {return __table_.find(__k);} _LIBCPP_INLINE_VISIBILITY const_iterator find(const key_type& __k) const {return __table_.find(__k);} + + template + _LIBCPP_INLINE_VISIBILITY + __enable_if_transparent<_K2, iterator> + find(const _K2& __k) {return __table_.find(__k);} + + template + _LIBCPP_INLINE_VISIBILITY + __enable_if_transparent<_K2, const_iterator> + find(const _K2& __k) const {return __table_.find(__k);} + _LIBCPP_INLINE_VISIBILITY size_type count(const key_type& __k) const {return __table_.__count_unique(__k);} + + template + _LIBCPP_INLINE_VISIBILITY + __enable_if_transparent<_K2, size_type> + count(const _K2& __k) const {return __table_.__count_unique(__k);} + _LIBCPP_INLINE_VISIBILITY pair equal_range(const key_type& __k) {return __table_.__equal_range_unique(__k);} _LIBCPP_INLINE_VISIBILITY pair equal_range(const key_type& __k) const {return __table_.__equal_range_unique(__k);} + + template + _LIBCPP_INLINE_VISIBILITY + __enable_if_transparent<_K2, pair> + equal_range(const _K2& __k) {return __table_.__equal_range_unique(__k);} + + template + _LIBCPP_INLINE_VISIBILITY + __enable_if_transparent<_K2, pair> + equal_range(const _K2& __k) const {return __table_.__equal_range_unique(__k);} _LIBCPP_INLINE_VISIBILITY size_type bucket_count() const _NOEXCEPT {return __table_.bucket_count();} diff --git a/include/utility b/include/utility index 1f41c0771..f8645cba8 100644 --- a/include/utility +++ b/include/utility @@ -1588,6 +1588,19 @@ using __enable_hash_helper = _Type; #endif // !_LIBCPP_CXX03_LANG +template <> +struct _LIBCPP_TEMPLATE_VIS hash +{ + template + _LIBCPP_INLINE_VISIBILITY + size_t operator()(T&& __v) const _NOEXCEPT { + return hash::type>::type>{}(__v); + } + + typedef void is_transparent; +}; + _LIBCPP_END_NAMESPACE_STD #endif // _LIBCPP_UTILITY diff --git a/test/std/containers/unord/unord.set/count.pass.cpp b/test/std/containers/unord/unord.set/count.pass.cpp index 18cac7cf9..f306f90de 100644 --- a/test/std/containers/unord/unord.set/count.pass.cpp +++ b/test/std/containers/unord/unord.set/count.pass.cpp @@ -16,6 +16,7 @@ // size_type count(const key_type& k) const; #include +#include #include #include "min_allocator.h" @@ -65,5 +66,17 @@ int main() assert(c.count(50) == 1); assert(c.count(5) == 0); } + { + std::unordered_set, std::equal_to<>> const s = {"one", "two", "three", "four"}; + std::string_view str_v = "three"; + assert(s.count(str_v) == 1); + assert(s.count("three") == 1); + char c1[] = "three"; + char c2[] = {'t','h','r','e','e','\0'}; + assert(s.count(c1) == 1); + assert(s.count(c2) == 1); + assert(s.count(std::string_view{"two"}) == 1); + assert(s.count(std::string_view{"TWO"}) == 0); + } #endif } diff --git a/test/std/containers/unord/unord.set/find_const.pass.cpp b/test/std/containers/unord/unord.set/find_const.pass.cpp index bd4542c87..af055d62f 100644 --- a/test/std/containers/unord/unord.set/find_const.pass.cpp +++ b/test/std/containers/unord/unord.set/find_const.pass.cpp @@ -16,10 +16,39 @@ // const_iterator find(const key_type& k) const; #include +#include #include #include "min_allocator.h" +struct dummy_int +{ + static size_t counter; + int value; + + dummy_int(): value{0} { ++counter; } + dummy_int(int v): value{v} { ++counter; } + dummy_int(dummy_int const& other): value(other.value) { ++counter; } + dummy_int(dummy_int&& other): value(other.value) { ++counter; } + + operator int () const { return value; } +}; + +size_t dummy_int::counter = 0; + +namespace std +{ + +template<> +struct hash +{ + size_t operator ()(dummy_int const& d) const { + return std::hash{}(d.value); + } +}; + +} + int main() { { @@ -63,5 +92,34 @@ int main() i = c.find(5); assert(i == c.cend()); } + { + dummy_int a[] = {10}; // +1 + + std::unordered_set const c1{std::begin(a), std::end(a)}; // +1 + assert(dummy_int::counter == 2); + auto i1 = c1.find(10); // +1 + assert(i1 != c1.end()); + assert(dummy_int::counter == 3); + + std::unordered_set, std::equal_to<>> const c2{std::begin(a), std::end(a)}; // +1 + assert(dummy_int::counter == 4); + auto i2 = c2.find(10); // +0 + assert(i2 != c2.end()); + assert(dummy_int::counter == 4); + } + { + std::unordered_set, std::equal_to<>> const s = {"one", "two", "three", "four"}; + std::string_view str_v = "three"; + std::string str3 = "three"; + std::string str2 = "two"; + assert(s.find(str_v) == s.find(str3)); + assert(s.find("three") == s.find(str3)); + char c1[] = "three"; + char c2[] = {'t','h','r','e','e','\0'}; + assert(s.find(c1) == s.find(str3)); + assert(s.find(c2) == s.find(str3)); + assert(s.find(std::string_view{"two"}) == s.find(str2)); + assert(s.find(std::string_view{"TWO"}) == s.end()); + } #endif }