diff --git a/src/ray/util/BUILD b/src/ray/util/BUILD index 87f8a57e8dea..ea91192f4ff6 100644 --- a/src/ray/util/BUILD +++ b/src/ray/util/BUILD @@ -56,11 +56,17 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "map_utils", + hdrs = ["map_utils.h"], +) + cc_library( name = "shared_lru", hdrs = ["shared_lru.h"], visibility = ["//visibility:public"], deps = [ + ":map_utils", ":util", "@com_google_absl//absl/container:flat_hash_map", ], diff --git a/src/ray/util/map_utils.h b/src/ray/util/map_utils.h new file mode 100644 index 000000000000..ddf14c045431 --- /dev/null +++ b/src/ray/util/map_utils.h @@ -0,0 +1,63 @@ +// Hash utils. + +#pragma once + +#include +#include + +template +struct RefHash : Hash { + RefHash() = default; + template + RefHash(H &&h) : Hash(std::forward(h)) {} // NOLINT + + RefHash(const RefHash &) = default; + RefHash(RefHash &&) noexcept = default; + RefHash &operator=(const RefHash &) = default; + RefHash &operator=(RefHash &&) noexcept = default; + + template + size_t operator()(std::reference_wrapper val) const { + return Hash::operator()(val.get()); + } + template + size_t operator()(const T &val) const { + return Hash::operator()(val); + } +}; + +template +RefHash(Hash &&) -> RefHash>; + +template +struct RefEq : Equal { + RefEq() = default; + template + RefEq(Eq &&eq) : Equal(std::forward(eq)) {} // NOLINT + + RefEq(const RefEq &) = default; + RefEq(RefEq &&) noexcept = default; + RefEq &operator=(const RefEq &) = default; + RefEq &operator=(RefEq &&) noexcept = default; + + template + bool operator()(std::reference_wrapper lhs, + std::reference_wrapper rhs) const { + return Equal::operator()(lhs.get(), rhs.get()); + } + template + bool operator()(const T1 &lhs, std::reference_wrapper rhs) const { + return Equal::operator()(lhs, rhs.get()); + } + template + bool operator()(std::reference_wrapper lhs, const T2 &rhs) const { + return Equal::operator()(lhs.get(), rhs); + } + template + bool operator()(const T1 &lhs, const T2 &rhs) const { + return Equal::operator()(lhs, rhs); + } +}; + +template +RefEq(Equal &&) -> RefEq>; diff --git a/src/ray/util/shared_lru.h b/src/ray/util/shared_lru.h index 39c43231ac4d..1fef1345f5c8 100644 --- a/src/ray/util/shared_lru.h +++ b/src/ray/util/shared_lru.h @@ -43,6 +43,7 @@ #include "absl/container/flat_hash_map.h" #include "src/ray/util/logging.h" +#include "src/ray/util/map_utils.h" namespace ray::utils::container { @@ -65,7 +66,7 @@ class SharedLruCache final { // the same key. void Put(Key key, std::shared_ptr value) { RAY_CHECK(value != nullptr); - auto iter = cache_.find(key); + auto iter = cache_.find(std::cref(key)); if (iter != cache_.end()) { lru_list_.splice(lru_list_.begin(), lru_list_, iter->second.lru_iterator); iter->second.value = std::move(value); @@ -74,7 +75,7 @@ class SharedLruCache final { lru_list_.emplace_front(key); Entry new_entry{std::move(value), lru_list_.begin()}; - cache_[std::move(key)] = std::move(new_entry); + cache_[std::cref(lru_list_.front())] = std::move(new_entry); if (max_entries_ > 0 && lru_list_.size() > max_entries_) { const auto &stale_key = lru_list_.back(); @@ -90,7 +91,7 @@ class SharedLruCache final { // with key `key` existed after the call. template bool Delete(KeyLike &&key) { - auto it = cache_.find(key); + auto it = cache_.find(std::cref(key)); if (it == cache_.end()) { return false; } @@ -129,7 +130,14 @@ class SharedLruCache final { typename std::list::iterator lru_iterator; }; - using EntryMap = absl::flat_hash_map; + // TODO(hjiang): These two internal type alias has been consolidated into stable header + // in later versions, update after we bump up abseil. + using KeyHash = absl::container_internal::hash_default_hash; + using KeyEqual = absl::container_internal::hash_default_eq; + + using KeyConstRef = std::reference_wrapper; + using EntryMap = + absl::flat_hash_map, RefEq>; // The maximum number of entries in the cache. A value of 0 means there is no // limit on entry count.