Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/unord heterog lookup #1

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/string_view
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,24 @@ hash<basic_string_view<_CharT, _Traits> >::operator()(
return __do_string_hash(__val.data(), __val.data() + __val.size());
}

template< class _Tp, size_t _Np, class = typename enable_if<
is_character<typename remove_cv<_Tp>::type>::value>::type >
struct __char_array_hash
{
using _CharT = typename remove_cv<_Tp>::type;

size_t operator ()(const _CharT (&__chars)[_Np]) const _NOEXCEPT {
return hash<basic_string_view<_CharT>>{}(__chars);
}
};

template< class _Tp, size_t _Np >
struct hash<_Tp[_Np]>: __char_array_hash<_Tp, _Np>
{};

template< class _Tp, size_t _Np >
struct hash<_Tp(&)[_Np]>: __char_array_hash<_Tp, _Np>
{};

#if _LIBCPP_STD_VER > 11
inline namespace literals
Expand Down
15 changes: 14 additions & 1 deletion include/type_traits
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ _LIBCPP_BEGIN_NAMESPACE_STD

template <class _T1, class _T2> struct _LIBCPP_TEMPLATE_VIS pair;
template <class _Tp> class _LIBCPP_TEMPLATE_VIS reference_wrapper;
template <class _Tp> struct _LIBCPP_TEMPLATE_VIS hash;
template <class _Tp = void> struct _LIBCPP_TEMPLATE_VIS hash;

template <class>
struct __void_t { typedef void type; };
Expand Down Expand Up @@ -677,6 +677,19 @@ template <class _Tp> _LIBCPP_CONSTEXPR bool is_null_pointer_v
#endif
#endif

// is_character

template <class _Tp> struct is_character : public false_type {};
template<> struct is_character<char> : public true_type {};
template<> struct is_character<char16_t> : public true_type {};
template<> struct is_character<char32_t> : public true_type {};
template<> struct is_character<wchar_t> : public true_type {};

#if _LIBCPP_STD_VER > 14 && !defined(_LIBCPP_HAS_NO_VARIABLE_TEMPLATES)
template <class _Tp>
_LIBCPP_CONSTEXPR bool is_character_v = is_character<_Tp>::value;
#endif

// is_integral

template <class _Tp> struct __libcpp_is_integral : public false_type {};
Expand Down
35 changes: 34 additions & 1 deletion include/unordered_set
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,8 @@ template <class Value, class Hash, class Pred, class Alloc>

_LIBCPP_BEGIN_NAMESPACE_STD

template <class _Value, class _Hash = hash<_Value>, class _Pred = equal_to<_Value>,
template <class _Value, class _Hash = hash<_Value>,
class _Pred = equal_to<_Value>,
class _Alloc = allocator<_Value> >
class _LIBCPP_TEMPLATE_VIS unordered_set
{
Expand All @@ -351,6 +352,11 @@ private:
typedef __hash_table<value_type, hasher, key_equal, allocator_type> __table;

__table __table_;

template <class _K2, class _Rt>
using __enable_if_transparent = typename enable_if<
__is_transparent<_Hash, _K2>::value &&
__is_transparent<_Pred, _K2>::value, _Rt>::type;

public:
typedef typename __table::pointer pointer;
Expand Down Expand Up @@ -563,14 +569,41 @@ public:
iterator find(const key_type& __k) {return __table_.find(__k);}
_LIBCPP_INLINE_VISIBILITY
const_iterator find(const key_type& __k) const {return __table_.find(__k);}

template <class _K2>
_LIBCPP_INLINE_VISIBILITY
__enable_if_transparent<_K2, iterator>
find(const _K2& __k) {return __table_.find(__k);}

template <class _K2>
_LIBCPP_INLINE_VISIBILITY
__enable_if_transparent<_K2, const_iterator>
find(const _K2& __k) const {return __table_.find(__k);}

_LIBCPP_INLINE_VISIBILITY
size_type count(const key_type& __k) const {return __table_.__count_unique(__k);}

template <class _K2>
_LIBCPP_INLINE_VISIBILITY
__enable_if_transparent<_K2, size_type>
count(const _K2& __k) const {return __table_.__count_unique(__k);}

_LIBCPP_INLINE_VISIBILITY
pair<iterator, iterator> equal_range(const key_type& __k)
{return __table_.__equal_range_unique(__k);}
_LIBCPP_INLINE_VISIBILITY
pair<const_iterator, const_iterator> equal_range(const key_type& __k) const
{return __table_.__equal_range_unique(__k);}

template <class _K2>
_LIBCPP_INLINE_VISIBILITY
__enable_if_transparent<_K2, pair<iterator, iterator>>
equal_range(const _K2& __k) {return __table_.__equal_range_unique(__k);}

template <class _K2>
_LIBCPP_INLINE_VISIBILITY
__enable_if_transparent<_K2, pair<const_iterator, const_iterator>>
equal_range(const _K2& __k) const {return __table_.__equal_range_unique(__k);}

_LIBCPP_INLINE_VISIBILITY
size_type bucket_count() const _NOEXCEPT {return __table_.bucket_count();}
Expand Down
13 changes: 13 additions & 0 deletions include/utility
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,19 @@ using __enable_hash_helper = _Type;

#endif // !_LIBCPP_CXX03_LANG

template <>
struct _LIBCPP_TEMPLATE_VIS hash<void>
{
template <typename T>
_LIBCPP_INLINE_VISIBILITY
size_t operator()(T&& __v) const _NOEXCEPT {
return hash<typename remove_cv<
typename remove_reference<T>::type>::type>{}(__v);
}

typedef void is_transparent;
};

_LIBCPP_END_NAMESPACE_STD

#endif // _LIBCPP_UTILITY
13 changes: 13 additions & 0 deletions test/std/containers/unord/unord.set/count.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// size_type count(const key_type& k) const;

#include <unordered_set>
#include <string>
#include <cassert>

#include "min_allocator.h"
Expand Down Expand Up @@ -65,5 +66,17 @@ int main()
assert(c.count(50) == 1);
assert(c.count(5) == 0);
}
{
std::unordered_set<std::string, std::hash<>, std::equal_to<>> const s = {"one", "two", "three", "four"};
std::string_view str_v = "three";
assert(s.count(str_v) == 1);
assert(s.count("three") == 1);
char c1[] = "three";
char c2[] = {'t','h','r','e','e','\0'};
assert(s.count(c1) == 1);
assert(s.count(c2) == 1);
assert(s.count(std::string_view{"two"}) == 1);
assert(s.count(std::string_view{"TWO"}) == 0);
}
#endif
}
58 changes: 58 additions & 0 deletions test/std/containers/unord/unord.set/find_const.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,39 @@
// const_iterator find(const key_type& k) const;

#include <unordered_set>
#include <string>
#include <cassert>

#include "min_allocator.h"

struct dummy_int
{
static size_t counter;
int value;

dummy_int(): value{0} { ++counter; }
dummy_int(int v): value{v} { ++counter; }
dummy_int(dummy_int const& other): value(other.value) { ++counter; }
dummy_int(dummy_int&& other): value(other.value) { ++counter; }

operator int () const { return value; }
};

size_t dummy_int::counter = 0;

namespace std
{

template<>
struct hash<dummy_int>
{
size_t operator ()(dummy_int const& d) const {
return std::hash<int>{}(d.value);
}
};

}

int main()
{
{
Expand Down Expand Up @@ -63,5 +92,34 @@ int main()
i = c.find(5);
assert(i == c.cend());
}
{
dummy_int a[] = {10}; // +1

std::unordered_set<dummy_int> const c1{std::begin(a), std::end(a)}; // +1
assert(dummy_int::counter == 2);
auto i1 = c1.find(10); // +1
assert(i1 != c1.end());
assert(dummy_int::counter == 3);

std::unordered_set<dummy_int, std::hash<>, std::equal_to<>> const c2{std::begin(a), std::end(a)}; // +1
assert(dummy_int::counter == 4);
auto i2 = c2.find(10); // +0
assert(i2 != c2.end());
assert(dummy_int::counter == 4);
}
{
std::unordered_set<std::string, std::hash<>, std::equal_to<>> const s = {"one", "two", "three", "four"};
std::string_view str_v = "three";
std::string str3 = "three";
std::string str2 = "two";
assert(s.find(str_v) == s.find(str3));
assert(s.find("three") == s.find(str3));
char c1[] = "three";
char c2[] = {'t','h','r','e','e','\0'};
assert(s.find(c1) == s.find(str3));
assert(s.find(c2) == s.find(str3));
assert(s.find(std::string_view{"two"}) == s.find(str2));
assert(s.find(std::string_view{"TWO"}) == s.end());
}
#endif
}