Skip to content

Commit

Permalink
use ref hash to avoid key copy
Browse files Browse the repository at this point in the history
Signed-off-by: hjiang <[email protected]>
  • Loading branch information
dentiny committed Dec 3, 2024
1 parent 24acab2 commit 8894533
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 8 deletions.
6 changes: 6 additions & 0 deletions src/ray/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,17 @@ cc_library(
visibility = ["//visibility:public"],
)

# Header-only library exposing map_utils.h; `:shared_lru` below lists it as a dependency.
cc_library(
name = "map_utils",
hdrs = ["map_utils.h"],
)

cc_library(
name = "shared_lru",
hdrs = ["shared_lru.h"],
visibility = ["//visibility:public"],
deps = [
":map_utils",
":util",
"@com_google_absl//absl/container:flat_hash_map",
],
Expand Down
63 changes: 63 additions & 0 deletions src/ray/util/map_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
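// Hash-map utilities: hash and equality functor adaptors that accept keys either directly
// or wrapped in `std::reference_wrapper`.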

#pragma once

#include <cstddef>
#include <functional>
#include <type_traits>
#include <utility>

// Hash functor adaptor which accepts a key either directly or wrapped in
// `std::reference_wrapper`. A wrapped key is unwrapped before being handed to the
// underlying `Hash`, so a key and a reference to it always hash to the same value.
//
// `RefHash` publicly derives from `Hash`, so nested types declared by `Hash` stay
// reachable through the adaptor.
template <typename Hash>
struct RefHash : Hash {
  RefHash() = default;

  // Implicitly converts from anything the underlying `Hash` can be constructed from.
  // The constraint keeps this template from (a) out-competing the copy constructor when
  // copying a non-const `RefHash` lvalue, and (b) making `RefHash` claim to be
  // constructible from arbitrary, unrelated types.
  template <typename H,
            typename = std::enable_if_t<!std::is_same_v<std::decay_t<H>, RefHash> &&
                                        std::is_constructible_v<Hash, H &&>>>
  RefHash(H &&h) : Hash(std::forward<H>(h)) {}  // NOLINT

  RefHash(const RefHash &) = default;
  RefHash(RefHash &&) noexcept = default;
  RefHash &operator=(const RefHash &) = default;
  RefHash &operator=(RefHash &&) noexcept = default;

  // Hashes the object referred to by `val`. `T` may be const or non-const, so both
  // `std::cref(x)` and `std::ref(x)` are unwrapped.
  template <typename T>
  std::size_t operator()(std::reference_wrapper<T> val) const {
    return Hash::operator()(val.get());
  }

  // Hashes `val` directly with the underlying `Hash`.
  template <typename T>
  std::size_t operator()(const T &val) const {
    return Hash::operator()(val);
  }
};

// Deduces `RefHash<H>` from a hasher of type `H`. Both the reference and any
// cv-qualifier are stripped so that a const hasher still yields `RefHash<H>` rather than
// the distinct type `RefHash<const H>`.
template <typename Hash>
RefHash(Hash &&) -> RefHash<std::remove_cv_t<std::remove_reference_t<Hash>>>;

// Equality functor adaptor which accepts each operand either directly or wrapped in
// `std::reference_wrapper`. Wrapped operands are unwrapped before being handed to the
// underlying `Equal`, so a key and a reference to it always compare the same way.
//
// `RefEq` publicly derives from `Equal`, so nested types declared by `Equal` stay
// reachable through the adaptor.
template <typename Equal>
struct RefEq : Equal {
  RefEq() = default;

  // Implicitly converts from anything the underlying `Equal` can be constructed from.
  // The constraint keeps this template from (a) out-competing the copy constructor when
  // copying a non-const `RefEq` lvalue, and (b) making `RefEq` claim to be constructible
  // from arbitrary, unrelated types.
  template <typename Eq,
            typename = std::enable_if_t<!std::is_same_v<std::decay_t<Eq>, RefEq> &&
                                        std::is_constructible_v<Equal, Eq &&>>>
  RefEq(Eq &&eq) : Equal(std::forward<Eq>(eq)) {}  // NOLINT

  RefEq(const RefEq &) = default;
  RefEq(RefEq &&) noexcept = default;
  RefEq &operator=(const RefEq &) = default;
  RefEq &operator=(RefEq &&) noexcept = default;

  // The overloads below cover every wrapped/unwrapped combination of the two operands.
  // `T1`/`T2` may be const or non-const, so both `std::cref(x)` and `std::ref(x)` are
  // unwrapped before the comparison.

  // Both operands wrapped.
  template <typename T1, typename T2>
  bool operator()(std::reference_wrapper<T1> lhs, std::reference_wrapper<T2> rhs) const {
    return Equal::operator()(lhs.get(), rhs.get());
  }

  // Only the right-hand operand wrapped.
  template <typename T1, typename T2>
  bool operator()(const T1 &lhs, std::reference_wrapper<T2> rhs) const {
    return Equal::operator()(lhs, rhs.get());
  }

  // Only the left-hand operand wrapped.
  template <typename T1, typename T2>
  bool operator()(std::reference_wrapper<T1> lhs, const T2 &rhs) const {
    return Equal::operator()(lhs.get(), rhs);
  }

  // Neither operand wrapped.
  template <typename T1, typename T2>
  bool operator()(const T1 &lhs, const T2 &rhs) const {
    return Equal::operator()(lhs, rhs);
  }
};

// Deduces `RefEq<E>` from an equality functor of type `E`. Both the reference and any
// cv-qualifier are stripped so that a const functor still yields `RefEq<E>` rather than
// the distinct type `RefEq<const E>`.
template <typename Equal>
RefEq(Equal &&) -> RefEq<std::remove_cv_t<std::remove_reference_t<Equal>>>;
22 changes: 14 additions & 8 deletions src/ray/util/shared_lru.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@
// // Check and consume `val`.
//
// TODO(hjiang):
// 1. Write a wrapper around KeyHash and KeyEq, which takes std::reference_wrapper<Key>,
// so we could store keys only in std::list, and reference in absl::flat_hash_map.
// 2. Add a `GetOrCreate` interface, which takes factory function to creation value.
// 3. For thread-safe cache, add a sharded container wrapper to reduce lock contention.
// 1. Add a `GetOrCreate` interface, which takes a factory function to create the value.
// 2. For thread-safe cache, add a sharded container wrapper to reduce lock contention.

#pragma once

Expand All @@ -43,6 +41,7 @@

#include "absl/container/flat_hash_map.h"
#include "src/ray/util/logging.h"
#include "src/ray/util/map_utils.h"

namespace ray::utils::container {

Expand All @@ -65,7 +64,7 @@ class SharedLruCache final {
// the same key.
void Put(Key key, std::shared_ptr<Val> value) {
RAY_CHECK(value != nullptr);
auto iter = cache_.find(key);
auto iter = cache_.find(std::cref(key));
if (iter != cache_.end()) {
lru_list_.splice(lru_list_.begin(), lru_list_, iter->second.lru_iterator);
iter->second.value = std::move(value);
Expand All @@ -74,7 +73,7 @@ class SharedLruCache final {

lru_list_.emplace_front(key);
Entry new_entry{std::move(value), lru_list_.begin()};
cache_[std::move(key)] = std::move(new_entry);
cache_[std::cref(lru_list_.front())] = std::move(new_entry);

if (max_entries_ > 0 && lru_list_.size() > max_entries_) {
const auto &stale_key = lru_list_.back();
Expand All @@ -90,7 +89,7 @@ class SharedLruCache final {
// with key `key` existed after the call.
template <typename KeyLike>
bool Delete(KeyLike &&key) {
auto it = cache_.find(key);
auto it = cache_.find(std::cref(key));
if (it == cache_.end()) {
return false;
}
Expand Down Expand Up @@ -129,7 +128,14 @@ class SharedLruCache final {
typename std::list<Key>::iterator lru_iterator;
};

using EntryMap = absl::flat_hash_map<Key, Entry>;
// TODO(hjiang): These two internal type aliases have been consolidated into a stable
// header in later versions; update after we bump up abseil.
using KeyHash = absl::container_internal::hash_default_hash<Key>;
using KeyEqual = absl::container_internal::hash_default_eq<Key>;

using KeyConstRef = std::reference_wrapper<const Key>;
using EntryMap =
absl::flat_hash_map<KeyConstRef, Entry, RefHash<KeyHash>, RefEq<KeyEqual>>;

// The maximum number of entries in the cache. A value of 0 means there is no
// limit on entry count.
Expand Down

0 comments on commit 8894533

Please sign in to comment.