diff --git a/src/ray/util/shared_lru.h b/src/ray/util/shared_lru.h index 8132e38b6f12..39c43231ac4d 100644 --- a/src/ray/util/shared_lru.h +++ b/src/ray/util/shared_lru.h @@ -27,11 +27,10 @@ // // Check and consume `val`. // // TODO(hjiang): -// 1. Add template arguments for key hash and key equal, to pass into absl::flat_hash_map. -// 2. Provide a key hash wrapper to save a copy. -// 3. flat hash map supports heterogeneous lookup, expose `KeyLike` templated interface. -// 4. Add a `GetOrCreate` interface, which takes factory function to creation value. -// 5. For thread-safe cache, add a sharded container wrapper to reduce lock contention. +// 1. Write a wrapper around KeyHash and KeyEq, which takes std::reference_wrapper, +// so we could store keys only in std::list, and reference in absl::flat_hash_map. +// 2. Add a `GetOrCreate` interface, which takes factory function to creation value. +// 3. For thread-safe cache, add a sharded container wrapper to reduce lock contention. #pragma once @@ -39,7 +38,6 @@ #include #include #include -#include #include #include @@ -90,7 +88,8 @@ class SharedLruCache final { // Delete the entry with key `key`. Return true if the entry was found for // `key`, false if the entry was not found. In both cases, there is no entry // with key `key` existed after the call. - bool Delete(const Key &key) { + template + bool Delete(KeyLike &&key) { auto it = cache_.find(key); if (it == cache_.end()) { return false; @@ -100,8 +99,10 @@ class SharedLruCache final { return true; } - // Look up the entry with key `key`. Return nullptr if key doesn't exist. - std::shared_ptr Get(const Key &key) { + // Look up the entry with key `key`. Return std::nullopt if key doesn't exist. + // If found, return a copy for the value. + template + std::shared_ptr Get(KeyLike &&key) { const auto cache_iter = cache_.find(key); if (cache_iter == cache_.end()) { return nullptr; @@ -173,16 +174,18 @@ class ThreadSafeSharedLruCache final { // Delete the entry with key `key`. Return true if the entry was found for // `key`, false if the entry was not found. In both cases, there is no entry // with key `key` existed after the call. - bool Delete(const Key &key) { + template + bool Delete(KeyLike &&key) { std::lock_guard lck(mu_); - return cache_.Delete(key); + return cache_.Delete(std::forward(key)); } // Look up the entry with key `key`. Return std::nullopt if key doesn't exist. // If found, return a copy for the value. - std::shared_ptr Get(const Key &key) { + template + std::shared_ptr Get(KeyLike &&key) { std::lock_guard lck(mu_); - return cache_.Get(key); + return cache_.Get(std::forward(key)); } // Clear the cache. diff --git a/src/ray/util/tests/shared_lru_test.cc b/src/ray/util/tests/shared_lru_test.cc index 7c47f4d1daf0..673395a82f3b 100644 --- a/src/ray/util/tests/shared_lru_test.cc +++ b/src/ray/util/tests/shared_lru_test.cc @@ -23,6 +23,19 @@ namespace ray::utils::container { namespace { constexpr size_t kTestCacheSz = 1; + +class TestClassWithHashAndEq { + public: + TestClassWithHashAndEq(std::string data) : data_(std::move(data)) {} + bool operator==(const TestClassWithHashAndEq &rhs) const { return data_ == rhs.data_; } + template + friend H AbslHashValue(H h, const TestClassWithHashAndEq &obj) { + return H::combine(std::move(h), obj.data_); + } + + private: + std::string data_; +}; } // namespace TEST(SharedLruCache, PutAndGet) { @@ -34,22 +47,21 @@ TEST(SharedLruCache, PutAndGet) { // Check put and get. cache.Put("1", std::make_shared("1")); - val = cache.Get("1"); + val = cache.Get(std::string_view{"1"}); EXPECT_NE(val, nullptr); EXPECT_EQ(*val, "1"); // Check key eviction. cache.Put("2", std::make_shared("2")); - val = cache.Get("1"); + val = cache.Get(std::string_view{"1"}); EXPECT_EQ(val, nullptr); - val = cache.Get("2"); + val = cache.Get(std::string_view{"2"}); EXPECT_NE(val, nullptr); EXPECT_EQ(*val, "2"); // Check deletion. - EXPECT_FALSE(cache.Delete("1")); - EXPECT_TRUE(cache.Delete("2")); - val = cache.Get("2"); + EXPECT_FALSE(cache.Delete(std::string_view{"1"})); + val = cache.Get(std::string_view{"1"}); EXPECT_EQ(val, nullptr); } @@ -73,4 +85,13 @@ TEST(SharedLruConstCache, TypeAliasAssertion) { std::is_same_v, SharedLruCache>); } +TEST(SharedLruConstCache, CustomizedKey) { + TestClassWithHashAndEq obj1{"hello"}; + TestClassWithHashAndEq obj2{"hello"}; + SharedLruCache cache{2}; + cache.Put(obj1, std::make_shared("val")); + auto val = cache.Get(obj2); + EXPECT_EQ(*val, "val"); +} + } // namespace ray::utils::container