Skip to content

Commit

Permalink
[core] Customized hash and eq for LRU cache (#48954)
Browse files Browse the repository at this point in the history
As titled, allow heterogeneous lookup and insertion; also add customized
hash and eq.

---------

Signed-off-by: hjiang <[email protected]>
  • Loading branch information
dentiny authored Dec 3, 2024
1 parent c339795 commit 24acab2
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 19 deletions.
29 changes: 16 additions & 13 deletions src/ray/util/shared_lru.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,17 @@
// // Check and consume `val`.
//
// TODO(hjiang):
// 1. Add template arguments for key hash and key equal, to pass into absl::flat_hash_map.
// 2. Provide a key hash wrapper to save a copy.
// 3. flat hash map supports heterogeneous lookup, expose `KeyLike` templated interface.
// 4. Add a `GetOrCreate` interface, which takes factory function to creation value.
// 5. For thread-safe cache, add a sharded container wrapper to reduce lock contention.
// 1. Write a wrapper around KeyHash and KeyEq, which takes std::reference_wrapper<Key>,
// so we could store keys only in std::list, and reference in absl::flat_hash_map.
// 2. Add a `GetOrCreate` interface, which takes factory function to creation value.
// 3. For thread-safe cache, add a sharded container wrapper to reduce lock contention.

#pragma once

#include <cstdint>
#include <list>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <utility>

Expand Down Expand Up @@ -90,7 +88,8 @@ class SharedLruCache final {
// Delete the entry with key `key`. Return true if the entry was found for
// `key`, false if the entry was not found. In both cases, there is no entry
// with key `key` existed after the call.
bool Delete(const Key &key) {
template <typename KeyLike>
bool Delete(KeyLike &&key) {
auto it = cache_.find(key);
if (it == cache_.end()) {
return false;
Expand All @@ -100,8 +99,10 @@ class SharedLruCache final {
return true;
}

// Look up the entry with key `key`. Return nullptr if key doesn't exist.
std::shared_ptr<Val> Get(const Key &key) {
// Look up the entry with key `key`. Return std::nullopt if key doesn't exist.
// If found, return a copy for the value.
template <typename KeyLike>
std::shared_ptr<Val> Get(KeyLike &&key) {
const auto cache_iter = cache_.find(key);
if (cache_iter == cache_.end()) {
return nullptr;
Expand Down Expand Up @@ -173,16 +174,18 @@ class ThreadSafeSharedLruCache final {
// Delete the entry with key `key`. Return true if the entry was found for
// `key`, false if the entry was not found. In both cases, there is no entry
// with key `key` existed after the call.
bool Delete(const Key &key) {
template <typename KeyLike>
bool Delete(KeyLike &&key) {
std::lock_guard lck(mu_);
return cache_.Delete(key);
return cache_.Delete(std::forward<KeyLike>(key));
}

// Look up the entry with key `key`. Return std::nullopt if key doesn't exist.
// If found, return a copy for the value.
std::shared_ptr<Val> Get(const Key &key) {
template <typename KeyLike>
std::shared_ptr<Val> Get(KeyLike &&key) {
std::lock_guard lck(mu_);
return cache_.Get(key);
return cache_.Get(std::forward<KeyLike>(key));
}

// Clear the cache.
Expand Down
33 changes: 27 additions & 6 deletions src/ray/util/tests/shared_lru_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ namespace ray::utils::container {

namespace {
constexpr size_t kTestCacheSz = 1;

class TestClassWithHashAndEq {
public:
TestClassWithHashAndEq(std::string data) : data_(std::move(data)) {}
bool operator==(const TestClassWithHashAndEq &rhs) const { return data_ == rhs.data_; }
template <typename H>
friend H AbslHashValue(H h, const TestClassWithHashAndEq &obj) {
return H::combine(std::move(h), obj.data_);
}

private:
std::string data_;
};
} // namespace

TEST(SharedLruCache, PutAndGet) {
Expand All @@ -34,22 +47,21 @@ TEST(SharedLruCache, PutAndGet) {

// Check put and get.
cache.Put("1", std::make_shared<std::string>("1"));
val = cache.Get("1");
val = cache.Get(std::string_view{"1"});
EXPECT_NE(val, nullptr);
EXPECT_EQ(*val, "1");

// Check key eviction.
cache.Put("2", std::make_shared<std::string>("2"));
val = cache.Get("1");
val = cache.Get(std::string_view{"1"});
EXPECT_EQ(val, nullptr);
val = cache.Get("2");
val = cache.Get(std::string_view{"2"});
EXPECT_NE(val, nullptr);
EXPECT_EQ(*val, "2");

// Check deletion.
EXPECT_FALSE(cache.Delete("1"));
EXPECT_TRUE(cache.Delete("2"));
val = cache.Get("2");
EXPECT_FALSE(cache.Delete(std::string_view{"1"}));
val = cache.Get(std::string_view{"1"});
EXPECT_EQ(val, nullptr);
}

Expand All @@ -73,4 +85,13 @@ TEST(SharedLruConstCache, TypeAliasAssertion) {
std::is_same_v<SharedLruConstCache<int, int>, SharedLruCache<int, const int>>);
}

TEST(SharedLruConstCache, CustomizedKey) {
TestClassWithHashAndEq obj1{"hello"};
TestClassWithHashAndEq obj2{"hello"};
SharedLruCache<TestClassWithHashAndEq, std::string> cache{2};
cache.Put(obj1, std::make_shared<std::string>("val"));
auto val = cache.Get(obj2);
EXPECT_EQ(*val, "val");
}

} // namespace ray::utils::container

0 comments on commit 24acab2

Please sign in to comment.