Skip to content

Commit

Permalink
use lru cache
Browse files Browse the repository at this point in the history
Signed-off-by: hjiang <[email protected]>
  • Loading branch information
dentiny committed Nov 26, 2024
1 parent 321dc87 commit f3db2bc
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 33 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ ray_cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:node_hash_map",
"@nlohmann_json",
"@boost//:compute",
],
)

Expand Down
2 changes: 2 additions & 0 deletions release/benchmarks/distributed/test_many_tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO(hjiang): Add a release benchmark for task submission.

import click
import ray
import ray._private.test_utils as test_utils
Expand Down
58 changes: 32 additions & 26 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,9 @@ using json = nlohmann::json;

namespace ray::core {

/// Resolves the job ID of the process that owns this core worker.
///
/// Only a driver is expected to carry a job ID in its options; every other
/// worker type must be started with a nil job ID. A WORKER process reads its
/// job ID from the `JOB_ID` entry of `RayConfig`; all other worker types get
/// `options.job_id` back unchanged.
JobID GetProcessJobID(const CoreWorkerOptions &options) {
  const bool is_driver = options.worker_type == WorkerType::DRIVER;
  if (!is_driver) {
    RAY_CHECK(options.job_id.IsNil());
  } else {
    RAY_CHECK(!options.job_id.IsNil());
  }

  if (options.worker_type != WorkerType::WORKER) {
    return options.job_id;
  }

  // For workers, the job ID is assigned by Raylet via an environment variable.
  const std::string &job_id_env = RayConfig::instance().JOB_ID();
  RAY_CHECK(!job_id_env.empty());
  return JobID::FromHex(job_id_env);
}

namespace {
// Default capacity (maximum number of cached entries) used when constructing
// each of the serialization caches.
constexpr size_t kDefaultSerializationCacheCap{500};

// Implements setting the transient RUNNING_IN_RAY_GET and RUNNING_IN_RAY_WAIT states.
// These states override the RUNNING state of a task.
Expand Down Expand Up @@ -126,6 +112,22 @@ std::optional<ObjectLocation> TryGetLocalObjectLocation(

} // namespace

/// Resolves the job ID of the process that owns this core worker.
///
/// Only a driver is expected to carry a job ID in its options; every other
/// worker type must be started with a nil job ID. A WORKER process reads its
/// job ID from the `JOB_ID` entry of `RayConfig`; all other worker types get
/// `options.job_id` back unchanged.
JobID GetProcessJobID(const CoreWorkerOptions &options) {
  const bool is_driver = options.worker_type == WorkerType::DRIVER;
  if (!is_driver) {
    RAY_CHECK(options.job_id.IsNil());
  } else {
    RAY_CHECK(!options.job_id.IsNil());
  }

  if (options.worker_type != WorkerType::WORKER) {
    return options.job_id;
  }

  // For workers, the job ID is assigned by Raylet via an environment variable.
  const std::string &job_id_env = RayConfig::instance().JOB_ID();
  RAY_CHECK(!job_id_env.empty());
  return JobID::FromHex(job_id_env);
}

TaskCounter::TaskCounter() {
counter_.SetOnChangeCallback(
[this](const std::tuple<std::string, TaskStatusType, bool>
Expand Down Expand Up @@ -262,7 +264,9 @@ CoreWorker::CoreWorker(CoreWorkerOptions options, const WorkerID &worker_id)
grpc_service_(io_service_, *this),
task_execution_service_work_(task_execution_service_),
exiting_detail_(std::nullopt),
pid_(getpid()) {
pid_(getpid()),
runtime_env_pb_serialization_cache_(kDefaultSerializationCacheCap),
runtime_env_json_serialization_cache_(kDefaultSerializationCacheCap) {
// Notify that core worker is initialized.
auto initialzed_scope_guard = absl::MakeCleanup([this] {
absl::MutexLock lock(&initialize_mutex_);
Expand Down Expand Up @@ -2163,9 +2167,10 @@ std::shared_ptr<rpc::RuntimeEnvInfo> CoreWorker::GetCachedPbRuntimeEnvOrParse(
const std::string &serialized_runtime_env_info) const {
{
std::lock_guard lck(runtime_env_serialization_mutex_);
auto iter = runtime_env_pb_serialization_cache_.find(serialized_runtime_env_info);
if (iter != runtime_env_pb_serialization_cache_.end()) {
return iter->second;
auto opt_runtime_info =
runtime_env_pb_serialization_cache_.get(serialized_runtime_env_info);
if (opt_runtime_info.has_value()) {
return *opt_runtime_info;
}
}
auto pb_runtime_env_info = std::make_shared<rpc::RuntimeEnvInfo>();
Expand All @@ -2174,8 +2179,8 @@ std::shared_ptr<rpc::RuntimeEnvInfo> CoreWorker::GetCachedPbRuntimeEnvOrParse(
.ok());
{
std::lock_guard lck(runtime_env_serialization_mutex_);
runtime_env_pb_serialization_cache_.emplace(serialized_runtime_env_info,
pb_runtime_env_info);
runtime_env_pb_serialization_cache_.insert(serialized_runtime_env_info,
pb_runtime_env_info);
}
return pb_runtime_env_info;
}
Expand All @@ -2184,16 +2189,17 @@ std::shared_ptr<nlohmann::json> CoreWorker::GetCachedJsonRuntimeEnvOrParse(
const std::string &serialized_runtime_env) const {
{
std::lock_guard lck(runtime_env_serialization_mutex_);
auto iter = runtime_env_json_serialization_cache_.find(serialized_runtime_env);
if (iter != runtime_env_json_serialization_cache_.end()) {
return iter->second;
auto opt_runtime_info =
runtime_env_json_serialization_cache_.get(serialized_runtime_env);
if (opt_runtime_info.has_value()) {
return *opt_runtime_info;
}
}
auto parsed_json = std::make_shared<json>();
*parsed_json = json::parse(serialized_runtime_env);
{
std::lock_guard lck(runtime_env_serialization_mutex_);
runtime_env_json_serialization_cache_.emplace(serialized_runtime_env, parsed_json);
runtime_env_json_serialization_cache_.insert(serialized_runtime_env, parsed_json);
}
return parsed_json;
}
Expand Down
11 changes: 4 additions & 7 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "absl/base/optimization.h"
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "boost/compute/detail/lru_cache.hpp"
#include "ray/common/asio/periodical_runner.h"
#include "ray/common/buffer.h"
#include "ray/common/placement_group.h"
Expand Down Expand Up @@ -1859,17 +1860,13 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
absl::flat_hash_set<ObjectID> deleted_generator_ids_;

/// Caches of deserialized runtime env info, keyed by the serialized string.
/// TODO(hjiang):
/// 1. Better to use hash value for serialized runtime (key), or cap a max serialized
/// string length to avoid too much memory consumption.
/// 2. Implement a LRU cache, to cap max number of key-value pairs to limit max memory
/// consumption.
mutable std::mutex runtime_env_serialization_mutex_;
/// Maps serialized runtime env to **immutable** deserialized protobuf.
mutable std::unordered_map<std::string, std::shared_ptr<rpc::RuntimeEnvInfo>>
mutable boost::compute::detail::lru_cache<std::string,
std::shared_ptr<rpc::RuntimeEnvInfo>>
runtime_env_pb_serialization_cache_;
/// Maps serialized runtime env to **immutable** deserialized json.
mutable std::unordered_map<std::string, std::shared_ptr<nlohmann::json>>
mutable boost::compute::detail::lru_cache<std::string, std::shared_ptr<nlohmann::json>>
runtime_env_json_serialization_cache_;
};

Expand Down

0 comments on commit f3db2bc

Please sign in to comment.