From bcf81c85000c9739123338bc46607163ea9e7847 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 18 Sep 2024 13:53:21 -0700 Subject: [PATCH 01/26] dedicated kv ioctx Signed-off-by: Ruiyang Wang --- src/ray/common/asio/asio_util.h | 46 ++++++++++++ src/ray/gcs/gcs_client/gcs_client.cc | 34 +-------- src/ray/gcs/gcs_server/gcs_server.cc | 75 +++++++++++-------- src/ray/gcs/gcs_server/gcs_server.h | 13 +++- src/ray/gcs/gcs_server/gcs_task_manager.cc | 7 -- src/ray/gcs/gcs_server/gcs_task_manager.h | 31 ++------ src/ray/gcs/gcs_server/pubsub_handler.cc | 13 ---- src/ray/gcs/gcs_server/pubsub_handler.h | 4 - src/ray/gcs/redis_context.cc | 6 +- .../store_client/observable_store_client.cc | 40 +++++----- .../gcs/store_client/redis_store_client.cc | 2 +- 11 files changed, 130 insertions(+), 141 deletions(-) diff --git a/src/ray/common/asio/asio_util.h b/src/ray/common/asio/asio_util.h index 232e397e5ce7..6dc8bbe8cd8b 100644 --- a/src/ray/common/asio/asio_util.h +++ b/src/ray/common/asio/asio_util.h @@ -16,8 +16,10 @@ #include #include +#include #include "ray/common/asio/instrumented_io_context.h" +#include "ray/util/util.h" template std::shared_ptr execute_after( @@ -37,3 +39,47 @@ std::shared_ptr execute_after( return timer; } + +/** + * A class that manages an instrumented_io_context and a std::thread. + * The constructor takes a thread name and starts the thread. + * The destructor stops the io_service and joins the thread. + */ +class InstrumentedIoContextWithThread { + public: + /** + * Constructor. + * @param thread_name The name of the thread. + */ + explicit InstrumentedIoContextWithThread(const std::string &thread_name) + : io_service_(), work_(io_service_) { + io_thread_ = std::thread([this, thread_name] { + SetThreadName(thread_name); + io_service_.run(); + }); + } + + ~InstrumentedIoContextWithThread() { Stop(); } + + // Non-movable and non-copyable. + InstrumentedIoContextWithThread(const InstrumentedIoContextWithThread &) = delete; + InstrumentedIoContextWithThread &operator=(const InstrumentedIoContextWithThread &) = + delete; + InstrumentedIoContextWithThread(InstrumentedIoContextWithThread &&) = delete; + InstrumentedIoContextWithThread &operator=(InstrumentedIoContextWithThread &&) = delete; + + instrumented_io_context &GetIoService() { return io_service_; } + + // Idempotent. Once it's stopped you can't restart it. + void Stop() { + io_service_.stop(); + if (io_thread_.joinable()) { + io_thread_.join(); + } + } + + private: + instrumented_io_context io_service_; + boost::asio::io_service::work work_; // to keep io_service_ running + std::thread io_thread_; +}; diff --git a/src/ray/gcs/gcs_client/gcs_client.cc b/src/ray/gcs/gcs_client/gcs_client.cc index cb734bc1dca1..8825dd0ab73f 100644 --- a/src/ray/gcs/gcs_client/gcs_client.cc +++ b/src/ray/gcs/gcs_client/gcs_client.cc @@ -18,6 +18,7 @@ #include #include +#include "ray/common/asio/asio_util.h" #include "ray/common/ray_config.h" #include "ray/gcs/gcs_client/accessor.h" #include "ray/pubsub/subscriber.h" @@ -717,38 +718,9 @@ std::unordered_map PythonGetNodeLabels( node_info.labels().end()); } -/// Creates a singleton thread that runs an io_service. -/// All ConnectToGcsStandalone calls will share this io_service. -class SingletonIoContext { - public: - static SingletonIoContext &Instance() { - static SingletonIoContext instance; - return instance; - } - - instrumented_io_context &GetIoService() { return io_service_; } - - private: - SingletonIoContext() : work_(io_service_) { - io_thread_ = std::thread([this] { - SetThreadName("singleton_io_context.gcs_client"); - io_service_.run(); - }); - } - ~SingletonIoContext() { - io_service_.stop(); - if (io_thread_.joinable()) { - io_thread_.join(); - } - } - - instrumented_io_context io_service_; - boost::asio::io_service::work work_; // to keep io_service_ running - std::thread io_thread_; -}; - Status ConnectOnSingletonIoContext(GcsClient &gcs_client, int64_t timeout_ms) { - instrumented_io_context &io_service = SingletonIoContext::Instance().GetIoService(); + static InstrumentedIoContextWithThread io_context("gcs_client_io_service"); + instrumented_io_context &io_service = io_context.GetIoService(); return gcs_client.Connect(io_service, timeout_ms); } diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 88708b005e6a..f9601fcae7f8 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -54,6 +54,10 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, : config_(config), storage_type_(GetStorageType()), main_service_(main_service), + pubsub_io_context_("pubsub_io_context"), + kv_io_context_("kv_io_context"), + task_io_context_("task_io_context"), + ray_syncer_io_context_("ray_syncer_io_context"), rpc_server_(config.grpc_server_name, config.grpc_server_port, config.node_ip_address == "127.0.0.1", @@ -65,7 +69,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, RayConfig::instance().gcs_server_rpc_client_thread_num()), raylet_client_pool_( std::make_shared(client_call_manager_)), - pubsub_periodical_runner_(pubsub_io_service_), + pubsub_periodical_runner_(pubsub_io_context_.GetIoService()), periodical_runner_(main_service), is_started_(false), is_stopped_(false) { @@ -73,10 +77,12 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, RAY_LOG(INFO) << "GCS storage type is " << storage_type_; switch (storage_type_) { case StorageType::IN_MEMORY: - gcs_table_storage_ = std::make_shared(main_service_); + gcs_table_storage_ = + std::make_shared(kv_io_context_.GetIoService()); break; case StorageType::REDIS_PERSIST: - gcs_table_storage_ = std::make_shared(GetOrConnectRedis()); + gcs_table_storage_ = std::make_shared( + GetOrConnectRedis(kv_io_context_.GetIoService())); break; default: RAY_LOG(FATAL) << "Unexpected storage type: " << storage_type_; @@ -264,14 +270,11 @@ void GcsServer::DoStart(const GcsInitData &gcs_init_data) { void GcsServer::Stop() { if (!is_stopped_) { RAY_LOG(INFO) << "Stopping GCS server."; - ray_syncer_io_context_.stop(); - ray_syncer_thread_->join(); - ray_syncer_.reset(); - gcs_task_manager_->Stop(); - - pubsub_handler_->Stop(); - pubsub_handler_.reset(); + ray_syncer_io_context_.Stop(); + task_io_context_.Stop(); + kv_io_context_.Stop(); + pubsub_io_context_.Stop(); // Shutdown the rpc server rpc_server_.Shutdown(); @@ -531,16 +534,12 @@ GcsServer::StorageType GcsServer::GetStorageType() const { } void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) { - ray_syncer_ = - std::make_unique(ray_syncer_io_context_, kGCSNodeID.Binary()); + ray_syncer_ = std::make_unique(ray_syncer_io_context_.GetIoService(), + kGCSNodeID.Binary()); ray_syncer_->Register( syncer::MessageType::RESOURCE_VIEW, nullptr, gcs_resource_manager_.get()); ray_syncer_->Register( syncer::MessageType::COMMANDS, nullptr, gcs_resource_manager_.get()); - ray_syncer_thread_ = std::make_unique([this]() { - boost::asio::io_service::work work(ray_syncer_io_context_); - ray_syncer_io_context_.run(); - }); ray_syncer_service_ = std::make_unique(*ray_syncer_); rpc_server_.RegisterService(*ray_syncer_service_); } @@ -563,13 +562,13 @@ void GcsServer::InitKVManager() { std::unique_ptr instance; switch (storage_type_) { case (StorageType::REDIS_PERSIST): - instance = std::make_unique( - std::make_unique(GetOrConnectRedis())); + instance = std::make_unique(std::make_unique( + GetOrConnectRedis(kv_io_context_.GetIoService()))); break; case (StorageType::IN_MEMORY): instance = std::make_unique(std::make_unique( - std::make_unique(main_service_))); + std::make_unique(kv_io_context_.GetIoService()))); break; default: RAY_LOG(FATAL) << "Unexpected storage type! " << storage_type_; @@ -580,16 +579,17 @@ void GcsServer::InitKVManager() { void GcsServer::InitKVService() { RAY_CHECK(kv_manager_); - kv_service_ = std::make_unique(main_service_, *kv_manager_); + kv_service_ = std::make_unique( + kv_io_context_.GetIoService(), *kv_manager_); // Register service. rpc_server_.RegisterService(*kv_service_, false /* token_auth */); } void GcsServer::InitPubSubHandler() { - pubsub_handler_ = - std::make_unique(pubsub_io_service_, gcs_publisher_); - pubsub_service_ = std::make_unique(pubsub_io_service_, - *pubsub_handler_); + pubsub_handler_ = std::make_unique( + pubsub_io_context_.GetIoService(), gcs_publisher_); + pubsub_service_ = std::make_unique( + pubsub_io_context_.GetIoService(), *pubsub_handler_); // Register service. rpc_server_.RegisterService(*pubsub_service_); } @@ -683,10 +683,10 @@ void GcsServer::InitGcsAutoscalerStateManager(const GcsInitData &gcs_init_data) } void GcsServer::InitGcsTaskManager() { - gcs_task_manager_ = std::make_unique(); + gcs_task_manager_ = std::make_unique(task_io_context_.GetIoService()); // Register service. - task_info_service_.reset(new rpc::TaskInfoGrpcService(gcs_task_manager_->GetIoContext(), - *gcs_task_manager_)); + task_info_service_.reset( + new rpc::TaskInfoGrpcService(task_io_context_.GetIoService(), *gcs_task_manager_)); rpc_server_.RegisterService(*task_info_service_); } @@ -819,15 +819,16 @@ std::string GcsServer::GetDebugState() const { return stream.str(); } -std::shared_ptr GcsServer::GetOrConnectRedis() { +std::shared_ptr GcsServer::GetOrConnectRedis( + instrumented_io_context &io_service) { if (redis_client_ == nullptr) { redis_client_ = std::make_shared(GetRedisClientOptions()); - auto status = redis_client_->Connect(main_service_); + auto status = redis_client_->Connect(io_service); RAY_CHECK(status.ok()) << "Failed to init redis gcs client as " << status; // Init redis failure detector. gcs_redis_failure_detector_ = - std::make_shared(main_service_, redis_client_, []() { + std::make_shared(io_service, redis_client_, []() { RAY_LOG(FATAL) << "Redis connection failed. Shutdown GCS."; }); gcs_redis_failure_detector_->Start(); @@ -840,9 +841,17 @@ void GcsServer::PrintAsioStats() { const auto event_stats_print_interval_ms = RayConfig::instance().event_stats_print_interval_ms(); if (event_stats_print_interval_ms != -1 && RayConfig::instance().event_stats()) { - RAY_LOG(INFO) << "Event stats:\n\n" << main_service_.stats().StatsString() << "\n\n"; - RAY_LOG(INFO) << "GcsTaskManager Event stats:\n\n" - << gcs_task_manager_->GetIoContext().stats().StatsString() << "\n\n"; + RAY_LOG(INFO) << "main_service_ Event stats:\n\n" + << main_service_.stats().StatsString() << "\n\n"; + RAY_LOG(INFO) << "pubsub_io_context_ Event stats:\n\n" + << pubsub_io_context_.GetIoService().stats().StatsString() << "\n\n"; + RAY_LOG(INFO) << "kv_io_context_ Event stats:\n\n" + << kv_io_context_.GetIoService().stats().StatsString() << "\n\n"; + RAY_LOG(INFO) << "task_io_context_ Event stats:\n\n" + << task_io_context_.GetIoService().stats().StatsString() << "\n\n"; + RAY_LOG(INFO) << "ray_syncer_io_context_ Event stats:\n\n" + << ray_syncer_io_context_.GetIoService().stats().StatsString() + << "\n\n"; } } diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index 99296ee6d84b..2cf994cddb2f 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -14,6 +14,7 @@ #pragma once +#include "ray/common/asio/asio_util.h" #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/ray_syncer/ray_syncer.h" #include "ray/common/runtime_env_manager.h" @@ -201,7 +202,7 @@ class GcsServer { void PrintAsioStats(); /// Get or connect to a redis server - std::shared_ptr GetOrConnectRedis(); + std::shared_ptr GetOrConnectRedis(instrumented_io_context &io_service); void TryGlobalGC(); @@ -212,7 +213,13 @@ class GcsServer { /// The main io service to drive event posted from grpc threads. instrumented_io_context &main_service_; /// The io service used by Pubsub, for isolation from other workload. - instrumented_io_context pubsub_io_service_; + InstrumentedIoContextWithThread pubsub_io_context_; + // The io service used by internal KV service, table storage and the StoreClient. + InstrumentedIoContextWithThread kv_io_context_; + // The io service used by task manager. + InstrumentedIoContextWithThread task_io_context_; + // The io service used by ray syncer. + InstrumentedIoContextWithThread ray_syncer_io_context_; /// The grpc server rpc::GrpcServer rpc_server_; /// The `ClientCallManager` object that is shared by all `NodeManagerWorkerClient`s. @@ -254,8 +261,6 @@ class GcsServer { /// Ray Syncer related fields. std::unique_ptr ray_syncer_; std::unique_ptr ray_syncer_service_; - std::unique_ptr ray_syncer_thread_; - instrumented_io_context ray_syncer_io_context_; /// The node id of GCS. NodeID gcs_node_id_; diff --git a/src/ray/gcs/gcs_server/gcs_task_manager.cc b/src/ray/gcs/gcs_server/gcs_task_manager.cc index 33a8bb6ded86..38b631a78545 100644 --- a/src/ray/gcs/gcs_server/gcs_task_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_task_manager.cc @@ -21,13 +21,6 @@ namespace ray { namespace gcs { -void GcsTaskManager::Stop() { - io_service_.stop(); - if (io_service_thread_->joinable()) { - io_service_thread_->join(); - } -} - std::vector GcsTaskManager::GcsTaskManagerStorage::GetTaskEvents() const { std::vector ret; diff --git a/src/ray/gcs/gcs_server/gcs_task_manager.h b/src/ray/gcs/gcs_server/gcs_task_manager.h index 9c8e8c215d8d..1e87baec43b3 100644 --- a/src/ray/gcs/gcs_server/gcs_task_manager.h +++ b/src/ray/gcs/gcs_server/gcs_task_manager.h @@ -86,18 +86,13 @@ class FinishedTaskActorTaskGcPolicy : public TaskEventsGcPolicyInterface { class GcsTaskManager : public rpc::TaskInfoHandler { public: /// Create a GcsTaskManager. - GcsTaskManager() - : stats_counter_(), + explicit GcsTaskManager(instrumented_io_context &io_service) + : io_service_(io_service), + stats_counter_(), task_event_storage_(std::make_unique( RayConfig::instance().task_events_max_num_task_in_gcs(), stats_counter_, std::make_unique())), - io_service_thread_(std::make_unique([this] { - SetThreadName("task_events"); - // Keep io_service_ alive. - boost::asio::io_service::work io_service_work_(io_service_); - io_service_.run(); - })), periodical_runner_(io_service_) { periodical_runner_.RunFnPeriodically([this] { task_event_storage_->GcJobSummary(); }, 5 * 1000, @@ -122,12 +117,6 @@ class GcsTaskManager : public rpc::TaskInfoHandler { rpc::GetTaskEventsReply *reply, rpc::SendReplyCallback send_reply_callback) override; - /// Stops the event loop and the thread of the task event handler. - /// - /// After this is called, no more requests will be handled. - /// This function returns when the io thread is joined. - void Stop(); - /// Handler to be called when a job finishes. This marks all non-terminated tasks /// of the job as failed. /// @@ -143,11 +132,6 @@ class GcsTaskManager : public rpc::TaskInfoHandler { void OnWorkerDead(const WorkerID &worker_id, const std::shared_ptr &worker_failure_data); - /// Returns the io_service. - /// - /// \return Reference to its io_service. - instrumented_io_context &GetIoContext() { return io_service_; } - /// Return string of debug state. /// /// \return Debug string @@ -514,6 +498,9 @@ class GcsTaskManager : public rpc::TaskInfoHandler { /// Test only size_t GetNumTaskEventsStored() { return stats_counter_.Get(kNumTaskEventsStored); } + /// Dedicated IO service separated from the main service. + instrumented_io_context &io_service_; + // Mutex guarding the usage stats client absl::Mutex mutex_; @@ -526,12 +513,6 @@ class GcsTaskManager : public rpc::TaskInfoHandler { // the io_service_thread_. Access to it is *not* thread safe. std::unique_ptr task_event_storage_; - /// Its own separate IO service separated from the main service. - instrumented_io_context io_service_; - - /// Its own IO thread from the main thread. - std::unique_ptr io_service_thread_; - /// The runner to run function periodically. PeriodicalRunner periodical_runner_; diff --git a/src/ray/gcs/gcs_server/pubsub_handler.cc b/src/ray/gcs/gcs_server/pubsub_handler.cc index a110ce49b956..d926b051102c 100644 --- a/src/ray/gcs/gcs_server/pubsub_handler.cc +++ b/src/ray/gcs/gcs_server/pubsub_handler.cc @@ -22,12 +22,6 @@ InternalPubSubHandler::InternalPubSubHandler( const std::shared_ptr &gcs_publisher) : io_service_(io_service), gcs_publisher_(gcs_publisher) { RAY_CHECK(gcs_publisher_); - io_service_thread_ = std::make_unique([this] { - SetThreadName("pubsub"); - // Keep io_service_ alive. - boost::asio::io_service::work io_service_work_(io_service_); - io_service_.run(); - }); } void InternalPubSubHandler::HandleGcsPublish(rpc::GcsPublishRequest request, @@ -129,12 +123,5 @@ void InternalPubSubHandler::RemoveSubscriberFrom(const std::string &sender_id) { sender_to_subscribers_.erase(iter); } -void InternalPubSubHandler::Stop() { - io_service_.stop(); - if (io_service_thread_->joinable()) { - io_service_thread_->join(); - } -} - } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_server/pubsub_handler.h b/src/ray/gcs/gcs_server/pubsub_handler.h index 71dded16967a..a92209a6954c 100644 --- a/src/ray/gcs/gcs_server/pubsub_handler.h +++ b/src/ray/gcs/gcs_server/pubsub_handler.h @@ -47,9 +47,6 @@ class InternalPubSubHandler : public rpc::InternalPubSubHandler { rpc::GcsUnregisterSubscriberReply *reply, rpc::SendReplyCallback send_reply_callback) final; - // Stops the event loop and the thread of the pubsub handler. - void Stop(); - std::string DebugString() const; void RemoveSubscriberFrom(const std::string &sender_id); @@ -57,7 +54,6 @@ class InternalPubSubHandler : public rpc::InternalPubSubHandler { private: /// Not owning the io service, to allow sharing it with pubsub::Publisher. instrumented_io_context &io_service_; - std::unique_ptr io_service_thread_; std::shared_ptr gcs_publisher_; absl::flat_hash_map> sender_to_subscribers_; }; diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index efe775eb6bad..6de20bfe34af 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -199,8 +199,8 @@ void RedisRequestContext::RedisResponseFn(struct redisAsyncContext *async_contex }, "RedisRequestContext.Callback"); auto end_time = absl::Now(); - ray::stats::GcsLatency().Record((end_time - request_cxt->start_time_) / - absl::Milliseconds(1)); + ray::stats::GcsLatency().Record( + absl::ToDoubleMilliseconds(end_time - request_cxt->start_time_)); delete request_cxt; } } @@ -215,7 +215,7 @@ void RedisRequestContext::Run() { --pending_retries_; Status status = redis_context_->RedisAsyncCommandArgv( - *(RedisResponseFn), this, argv_.size(), argv_.data(), argc_.data()); + RedisResponseFn, this, argv_.size(), argv_.data(), argc_.data()); if (!status.ok()) { RedisResponseFn(redis_context_->GetRawRedisAsyncContext(), nullptr, this); diff --git a/src/ray/gcs/store_client/observable_store_client.cc b/src/ray/gcs/store_client/observable_store_client.cc index c188e27e57b7..147c9191a824 100644 --- a/src/ray/gcs/store_client/observable_store_client.cc +++ b/src/ray/gcs/store_client/observable_store_client.cc @@ -29,19 +29,19 @@ Status ObservableStoreClient::AsyncPut(const std::string &table_name, std::function callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "Put"); - return delegate_->AsyncPut(table_name, - key, - data, - overwrite, - [start, callback = std::move(callback)](auto result) { - auto end = absl::GetCurrentTimeNanos(); - STATS_gcs_storage_operation_latency_ms.Record( - absl::Nanoseconds(end - start) / absl::Milliseconds(1), - "Put"); - if (callback) { - callback(std::move(result)); - } - }); + return delegate_->AsyncPut( + table_name, + key, + data, + overwrite, + [start, callback = std::move(callback)](auto result) { + auto end = absl::GetCurrentTimeNanos(); + STATS_gcs_storage_operation_latency_ms.Record( + absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Put"); + if (callback) { + callback(std::move(result)); + } + }); } Status ObservableStoreClient::AsyncGet( @@ -54,7 +54,7 @@ Status ObservableStoreClient::AsyncGet( table_name, key, [start, callback](auto status, auto result) { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( - absl::Nanoseconds(end - start) / absl::Milliseconds(1), "Get"); + absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Get"); if (callback) { callback(status, std::move(result)); } @@ -69,7 +69,7 @@ Status ObservableStoreClient::AsyncGetAll( return delegate_->AsyncGetAll(table_name, [start, callback](auto result) { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( - absl::Nanoseconds(end - start) / absl::Milliseconds(1), "GetAll"); + absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "GetAll"); if (callback) { callback(std::move(result)); } @@ -84,7 +84,7 @@ Status ObservableStoreClient::AsyncMultiGet( return delegate_->AsyncMultiGet(table_name, keys, [start, callback](auto result) { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( - absl::Nanoseconds(end - start) / absl::Milliseconds(1), "MultiGet"); + absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "MultiGet"); if (callback) { callback(std::move(result)); } @@ -100,7 +100,7 @@ Status ObservableStoreClient::AsyncDelete(const std::string &table_name, table_name, key, [start, callback = std::move(callback)](auto result) { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( - absl::Nanoseconds(end - start) / absl::Milliseconds(1), "Delete"); + absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Delete"); if (callback) { callback(std::move(result)); } @@ -116,7 +116,7 @@ Status ObservableStoreClient::AsyncBatchDelete(const std::string &table_name, table_name, keys, [start, callback = std::move(callback)](auto result) { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( - absl::Nanoseconds(end - start) / absl::Milliseconds(1), "BatchDelete"); + absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "BatchDelete"); if (callback) { callback(std::move(result)); } @@ -135,7 +135,7 @@ Status ObservableStoreClient::AsyncGetKeys( table_name, prefix, [start, callback = std::move(callback)](auto result) { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( - absl::Nanoseconds(end - start) / absl::Milliseconds(1), "GetKeys"); + absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "GetKeys"); if (callback) { callback(std::move(result)); } @@ -151,7 +151,7 @@ Status ObservableStoreClient::AsyncExists(const std::string &table_name, table_name, key, [start, callback = std::move(callback)](auto result) { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( - absl::Nanoseconds(end - start) / absl::Milliseconds(1), "Exists"); + absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Exists"); if (callback) { callback(std::move(result)); } diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index 81d1ce292aa7..8a158e136376 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -19,9 +19,9 @@ #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "ray/common/asio/asio_util.h" #include "ray/gcs/redis_context.h" #include "ray/util/logging.h" - namespace ray { namespace gcs { From c84fd80399c8023a1063898a6ea316e0059acd83 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 18 Sep 2024 14:30:59 -0700 Subject: [PATCH 02/26] move gcs_table_storage_ back to main service. Signed-off-by: Ruiyang Wang --- src/ray/gcs/gcs_server/gcs_server.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index f9601fcae7f8..08de01b8862a 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -73,16 +73,16 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, periodical_runner_(main_service), is_started_(false), is_stopped_(false) { - // Init GCS table storage. + // Init GCS table storage. Note this is on main_service_, not kv_io_context_, to avoid + // congestion on the kv_io_context_. RAY_LOG(INFO) << "GCS storage type is " << storage_type_; switch (storage_type_) { case StorageType::IN_MEMORY: - gcs_table_storage_ = - std::make_shared(kv_io_context_.GetIoService()); + gcs_table_storage_ = std::make_shared(main_service_); break; case StorageType::REDIS_PERSIST: - gcs_table_storage_ = std::make_shared( - GetOrConnectRedis(kv_io_context_.GetIoService())); + gcs_table_storage_ = + std::make_shared(GetOrConnectRedis(main_service_)); break; default: RAY_LOG(FATAL) << "Unexpected storage type: " << storage_type_; From 36fc80862de1ba255b40387fd9d8e6e04e5d3603 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Mon, 23 Sep 2024 16:31:17 -0700 Subject: [PATCH 03/26] fix cpp test Signed-off-by: Ruiyang Wang --- .../gcs/gcs_server/test/gcs_task_manager_test.cc | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc index 07d4ee662966..939ae13be6fa 100644 --- a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc @@ -18,6 +18,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "ray/common/asio/asio_util.h" #include "ray/gcs/pb_util.h" #include "ray/gcs/test/gcs_test_util.h" @@ -36,9 +37,15 @@ class GcsTaskManagerTest : public ::testing::Test { )"); } - virtual void SetUp() { task_manager.reset(new GcsTaskManager()); } + virtual void SetUp() { + io_context_ = std::make_unique("GcsTaskManagerTest"); + task_manager = std::make_unique(io_context_->GetIoService()); + } - virtual void TearDown() { task_manager->Stop(); } + virtual void TearDown() { + task_manager.reset(); + io_context_.reset(); + } std::vector GenTaskIDs(size_t num_tasks) { std::vector task_ids; @@ -104,7 +111,7 @@ class GcsTaskManagerTest : public ::testing::Test { request.mutable_data()->CopyFrom(events_data); // Dispatch so that it runs in GcsTaskManager's io service. - task_manager->GetIoContext().dispatch( + io_context_->GetIoService().dispatch( [this, &promise, &request, &reply]() { task_manager->HandleAddTaskEventData( request, @@ -161,7 +168,7 @@ class GcsTaskManagerTest : public ::testing::Test { request.mutable_filters()->set_exclude_driver(exclude_driver); - task_manager->GetIoContext().dispatch( + io_context_->GetIoService().dispatch( [this, &promise, &request, &reply]() { task_manager->HandleGetTaskEvents( request, @@ -275,6 +282,7 @@ class GcsTaskManagerTest : public ::testing::Test { } std::unique_ptr task_manager = nullptr; + std::unique_ptr io_context_ = nullptr; }; class GcsTaskManagerMemoryLimitedTest : public GcsTaskManagerTest { From a87c39d2810c8468c8680bfa9422b43c212384ca Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 24 Sep 2024 11:40:12 -0700 Subject: [PATCH 04/26] fix atomics now that we have multiple thread reads... Signed-off-by: Ruiyang Wang --- src/ray/gcs/gcs_server/gcs_job_manager.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index b42de3f95533..4be8809b6ea6 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -247,11 +247,14 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, // entrypoint script calls ray.init() multiple times). std::unordered_map> job_data_key_to_indices; - // Create a shared counter for the number of jobs processed - std::shared_ptr num_processed_jobs = std::make_shared(0); + // Create a shared counter for the number of jobs processed. + // This is written in internal_kv_'s thread and read in the main thread. + std::shared_ptr> num_processed_jobs = + std::make_shared>(0); // Create a shared boolean flag for the internal KV callback completion - std::shared_ptr kv_callback_done = std::make_shared(false); + std::shared_ptr> kv_callback_done = + std::make_shared>(false); // Function to send the reply once all jobs have been processed and KV callback // completed From 7cd77057e697b83fc6b0d2a74856b0c7492039b6 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 24 Sep 2024 12:17:50 -0700 Subject: [PATCH 05/26] atomics Signed-off-by: Ruiyang Wang --- src/ray/gcs/gcs_server/gcs_job_manager.cc | 93 +++++++++++++---------- 1 file changed, 51 insertions(+), 42 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index 4be8809b6ea6..ccc7df3ce6ac 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -247,25 +247,6 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, // entrypoint script calls ray.init() multiple times). std::unordered_map> job_data_key_to_indices; - // Create a shared counter for the number of jobs processed. - // This is written in internal_kv_'s thread and read in the main thread. - std::shared_ptr> num_processed_jobs = - std::make_shared>(0); - - // Create a shared boolean flag for the internal KV callback completion - std::shared_ptr> kv_callback_done = - std::make_shared>(false); - - // Function to send the reply once all jobs have been processed and KV callback - // completed - auto try_send_reply = - [num_processed_jobs, kv_callback_done, reply, send_reply_callback]() { - if (*num_processed_jobs == reply->job_info_list_size() && *kv_callback_done) { - RAY_LOG(DEBUG) << "Finished getting all job info."; - GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); - } - }; - // Load the job table data into the reply. int i = 0; for (auto &data : result) { @@ -289,28 +270,59 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, job_api_data_keys.push_back(job_data_key); job_data_key_to_indices[job_data_key].push_back(i); } + i++; + } + + // Jobs are filtered. Now, optionally populate is_running_tasks and job_info. A + // `asyncio.gather` is needed but we are in C++; so we use atomic counters. - if (!request.skip_is_running_tasks_field()) { - JobID job_id = data.first; - WorkerID worker_id = - WorkerID::FromBinary(data.second.driver_address().worker_id()); + // Atomic counter of pending async tasks before sending the reply. + // Once it reaches total_tasks, the reply is sent. + std::shared_ptr> num_finished_tasks = + std::make_shared>(0); - // If job is not dead, get is_running_tasks from the core worker for the driver. - if (data.second.is_dead()) { + // N tasks for N jobs; and 1 task for the MultiKVGet. If either is skipped the counter + // still increments. + const size_t total_tasks = reply->job_info_list_size() + 1; + + // Those async tasks need to atomically read-and-increment the counter, so this + // callback can't capture the atomic variable directly. Instead, it asks for an + // regular variable argument coming from the read-and-increment caller. + auto try_send_reply = + [reply, send_reply_callback, total_tasks](size_t finished_tasks) { + if (finished_tasks == total_tasks) { + RAY_LOG(DEBUG) << "Finished getting all job info."; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + } + }; + + if (request.skip_is_running_tasks_field()) { + // Skipping RPCs to workers, just mark all job tasks as done. + const size_t job_count = reply->job_info_list_size(); + size_t updated_finished_tasks = + num_finished_tasks->fetch_add(job_count) + job_count; + try_send_reply(updated_finished_tasks); + } else { + for (const auto &data : reply->job_info_list()) { + auto job_id = JobID::FromBinary(data.job_id()); + WorkerID worker_id = WorkerID::FromBinary(data.driver_address().worker_id()); + + // If job is dead, no need to get. + if (data.is_dead()) { reply->mutable_job_info_list(i)->set_is_running_tasks(false); core_worker_clients_.Disconnect(worker_id); - (*num_processed_jobs)++; - try_send_reply(); + size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1; + try_send_reply(updated_finished_tasks); } else { // Get is_running_tasks from the core worker for the driver. - auto client = core_worker_clients_.GetOrConnect(data.second.driver_address()); + auto client = core_worker_clients_.GetOrConnect(data.driver_address()); auto request = std::make_unique(); constexpr int64_t kNumPendingTasksRequestTimeoutMs = 1000; RAY_LOG(DEBUG) << "Send NumPendingTasksRequest to worker " << worker_id << ", timeout " << kNumPendingTasksRequestTimeoutMs << " ms."; client->NumPendingTasks( std::move(request), - [job_id, worker_id, reply, i, num_processed_jobs, try_send_reply]( + [job_id, worker_id, reply, i, num_finished_tasks, try_send_reply]( const Status &status, const rpc::NumPendingTasksReply &num_pending_tasks_reply) { RAY_LOG(DEBUG).WithField(worker_id) @@ -324,25 +336,25 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, bool is_running_tasks = num_pending_tasks_reply.num_pending_tasks() > 0; reply->mutable_job_info_list(i)->set_is_running_tasks(is_running_tasks); } - (*num_processed_jobs)++; - try_send_reply(); + size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1; + try_send_reply(updated_finished_tasks); }, kNumPendingTasksRequestTimeoutMs); } - } else { - (*num_processed_jobs)++; - try_send_reply(); } - i++; } - if (!request.skip_submission_job_info_field()) { + if (request.skip_submission_job_info_field()) { + // Skipping MultiKVGet, just mark the counter. + size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1; + try_send_reply(updated_finished_tasks); + } else { // Load the JobInfo for jobs submitted via the Ray Job API. auto kv_multi_get_callback = [reply, send_reply_callback, job_data_key_to_indices, - kv_callback_done, + num_finished_tasks, try_send_reply](std::unordered_map &&result) { for (const auto &data : result) { const std::string &job_data_key = data.first; @@ -365,13 +377,10 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, } } } - *kv_callback_done = true; - try_send_reply(); + size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1; + try_send_reply(updated_finished_tasks); }; internal_kv_.MultiGet("job", job_api_data_keys, kv_multi_get_callback); - } else { - *kv_callback_done = true; - try_send_reply(); } }; Status status = gcs_table_storage_->JobTable().GetAll(on_done); From bbf02cd4a3504d06e3b2f730128fafd24825609a Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 24 Sep 2024 16:06:50 -0700 Subject: [PATCH 06/26] fix Signed-off-by: Ruiyang Wang --- src/ray/gcs/gcs_server/gcs_job_manager.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index ccc7df3ce6ac..8146295e60ab 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -303,7 +303,8 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, num_finished_tasks->fetch_add(job_count) + job_count; try_send_reply(updated_finished_tasks); } else { - for (const auto &data : reply->job_info_list()) { + for (size_t i = 0; i < reply->job_info_list_size(); i++) { + const auto &data = reply->job_info_list(i); auto job_id = JobID::FromBinary(data.job_id()); WorkerID worker_id = WorkerID::FromBinary(data.driver_address().worker_id()); From 99f7ba9f698415619528b85f956536050feab1da Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 25 Sep 2024 10:23:26 -0700 Subject: [PATCH 07/26] size_t -> int for proto Signed-off-by: Ruiyang Wang --- src/ray/gcs/gcs_server/gcs_job_manager.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index 8146295e60ab..d3d044c8c7fa 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -303,7 +303,7 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, num_finished_tasks->fetch_add(job_count) + job_count; try_send_reply(updated_finished_tasks); } else { - for (size_t i = 0; i < reply->job_info_list_size(); i++) { + for (int i = 0; i < reply->job_info_list_size(); i++) { const auto &data = reply->job_info_list(i); auto job_id = JobID::FromBinary(data.job_id()); WorkerID worker_id = WorkerID::FromBinary(data.driver_address().worker_id()); From a1ab6c613e10d1fcd5c9058ad09a1abc5f19a86c Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 25 Sep 2024 14:19:59 -0700 Subject: [PATCH 08/26] fix atomics in periodical_runner Signed-off-by: Ruiyang Wang --- src/ray/common/asio/periodical_runner.cc | 65 ++++++++++-------------- src/ray/common/asio/periodical_runner.h | 8 +-- 2 files changed, 30 insertions(+), 43 deletions(-) diff --git a/src/ray/common/asio/periodical_runner.cc b/src/ray/common/asio/periodical_runner.cc index 7de7fafc82c6..4c1babadc293 100644 --- a/src/ray/common/asio/periodical_runner.cc +++ b/src/ray/common/asio/periodical_runner.cc @@ -20,7 +20,7 @@ namespace ray { PeriodicalRunner::PeriodicalRunner(instrumented_io_context &io_service) - : io_service_(io_service), mutex_(), stopped_(std::make_shared(false)) {} + : io_service_(io_service) {} PeriodicalRunner::~PeriodicalRunner() { RAY_LOG(DEBUG) << "PeriodicalRunner is destructed"; @@ -29,7 +29,7 @@ PeriodicalRunner::~PeriodicalRunner() { void PeriodicalRunner::Clear() { absl::MutexLock lock(&mutex_); - *stopped_ = true; + stopped_ = true; for (const auto &timer : timers_) { timer->cancel(); } @@ -38,8 +38,8 @@ void PeriodicalRunner::Clear() { void PeriodicalRunner::RunFnPeriodically(std::function fn, uint64_t period_ms, - const std::string name) { - *stopped_ = false; + const std::string &name) { + stopped_ = false; if (period_ms > 0) { auto timer = std::make_shared(io_service_); { @@ -47,13 +47,8 @@ void PeriodicalRunner::RunFnPeriodically(std::function fn, timers_.push_back(timer); } io_service_.post( - [this, - stopped = stopped_, - fn = std::move(fn), - period_ms, - name, - timer = std::move(timer)]() { - if (*stopped) { + [this, fn = std::move(fn), period_ms, name, timer = std::move(timer)]() { + if (this->stopped_) { return; } if (RayConfig::instance().event_stats()) { @@ -74,28 +69,27 @@ void PeriodicalRunner::DoRunFnPeriodically( fn(); absl::MutexLock lock(&mutex_); timer->expires_from_now(period); - timer->async_wait( - [this, stopped = stopped_, fn = std::move(fn), period, timer = std::move(timer)]( - const boost::system::error_code &error) { - if (*stopped) { - return; - } - if (error == boost::asio::error::operation_aborted) { - // `operation_aborted` is set when `timer` is canceled or destroyed. - // The Monitor lifetime may be short than the object who use it. (e.g. - // gcs_server) - return; - } - RAY_CHECK(!error) << error.message(); - DoRunFnPeriodically(fn, period, timer); - }); + timer->async_wait([this, fn, period, timer = std::move(timer)]( + const boost::system::error_code &error) { + if (stopped_) { + return; + } + if (error == boost::asio::error::operation_aborted) { + // `operation_aborted` is set when `timer` is canceled or destroyed. + // The Monitor lifetime may be short than the object who use it. (e.g. + // gcs_server) + return; + } + RAY_CHECK(!error) << error.message(); + DoRunFnPeriodically(fn, period, timer); + }); } void PeriodicalRunner::DoRunFnPeriodicallyInstrumented( const std::function &fn, boost::posix_time::milliseconds period, std::shared_ptr timer, - const std::string name) { + const std::string &name) { fn(); absl::MutexLock lock(&mutex_); timer->expires_from_now(period); @@ -104,24 +98,17 @@ void PeriodicalRunner::DoRunFnPeriodicallyInstrumented( // event loop. auto stats_handle = io_service_.stats().RecordStart(name, period.total_nanoseconds()); timer->async_wait([this, - fn = std::move(fn), - stopped = stopped_, + fn, period, timer = std::move(timer), stats_handle = std::move(stats_handle), name](const boost::system::error_code &error) { - if (*stopped) { + if (this->stopped_) { return; } io_service_.stats().RecordExecution( - [this, - stopped = stopped_, - fn = std::move(fn), - error, - period, - timer = std::move(timer), - name]() { - if (*stopped) { + [this, fn, error, period, timer, name]() { + if (this->stopped_) { return; } if (error == boost::asio::error::operation_aborted) { @@ -133,7 +120,7 @@ void PeriodicalRunner::DoRunFnPeriodicallyInstrumented( RAY_CHECK(!error) << error.message(); DoRunFnPeriodicallyInstrumented(fn, period, timer, name); }, - std::move(stats_handle)); + stats_handle); }); } diff --git a/src/ray/common/asio/periodical_runner.h b/src/ray/common/asio/periodical_runner.h index c67b02620330..c4ae74dcafa3 100644 --- a/src/ray/common/asio/periodical_runner.h +++ b/src/ray/common/asio/periodical_runner.h @@ -30,7 +30,7 @@ namespace ray { /// All registered functions will stop running once this object is destructed. class PeriodicalRunner { public: - PeriodicalRunner(instrumented_io_context &io_service); + explicit PeriodicalRunner(instrumented_io_context &io_service); ~PeriodicalRunner(); @@ -38,7 +38,7 @@ class PeriodicalRunner { void RunFnPeriodically(std::function fn, uint64_t period_ms, - const std::string name) ABSL_LOCKS_EXCLUDED(mutex_); + const std::string &name) ABSL_LOCKS_EXCLUDED(mutex_); private: void DoRunFnPeriodically(const std::function &fn, @@ -49,14 +49,14 @@ class PeriodicalRunner { void DoRunFnPeriodicallyInstrumented(const std::function &fn, boost::posix_time::milliseconds period, std::shared_ptr timer, - const std::string name) + const std::string &name) ABSL_LOCKS_EXCLUDED(mutex_); instrumented_io_context &io_service_; mutable absl::Mutex mutex_; std::vector> timers_ ABSL_GUARDED_BY(mutex_); - std::shared_ptr stopped_; + std::atomic stopped_ = false; }; } // namespace ray From cf3f343f1f6a8af75977a03adc482250ce2ecc76 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Wed, 25 Sep 2024 14:33:23 -0700 Subject: [PATCH 09/26] update doc Signed-off-by: Ruiyang Wang --- src/ray/gcs/gcs_server/gcs_job_manager.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index d3d044c8c7fa..7b9f8af23da8 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -273,8 +273,16 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, i++; } - // Jobs are filtered. Now, optionally populate is_running_tasks and job_info. A - // `asyncio.gather` is needed but we are in C++; so we use atomic counters. + // Jobs are filtered. Now, optionally populate is_running_tasks and job_info. We + // do async calls to: + // + // - N outbound RPCs, one to each jobs' core workers on GcsServer::main_service_. + // - One InternalKV MultiGet call on GcsServer::kv_service_. + // + // And then we wait all by examining an atomic num_finished_tasks counter and then + // reply. The wait counter is written from 2 different thread, which requires an + // atomic read-and-increment. Each thread performs read-and-increment, and check + // the atomic readout to ensure try_send_reply is executed exactly once. // Atomic counter of pending async tasks before sending the reply. // Once it reaches total_tasks, the reply is sent. @@ -284,10 +292,6 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, // N tasks for N jobs; and 1 task for the MultiKVGet. If either is skipped the counter // still increments. const size_t total_tasks = reply->job_info_list_size() + 1; - - // Those async tasks need to atomically read-and-increment the counter, so this - // callback can't capture the atomic variable directly. Instead, it asks for an - // regular variable argument coming from the read-and-increment caller. auto try_send_reply = [reply, send_reply_callback, total_tasks](size_t finished_tasks) { if (finished_tasks == total_tasks) { From 6d006e9de9abc508f6f8b438c858bba4d0daf89b Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Thu, 26 Sep 2024 14:12:18 -0700 Subject: [PATCH 10/26] stopped -> shared_ptr> Signed-off-by: Ruiyang Wang --- src/ray/common/asio/periodical_runner.cc | 24 +++++++++++++++--------- src/ray/common/asio/periodical_runner.h | 4 +++- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/ray/common/asio/periodical_runner.cc b/src/ray/common/asio/periodical_runner.cc index 4c1babadc293..855590445e98 100644 --- a/src/ray/common/asio/periodical_runner.cc +++ b/src/ray/common/asio/periodical_runner.cc @@ -29,7 +29,7 @@ PeriodicalRunner::~PeriodicalRunner() { void PeriodicalRunner::Clear() { absl::MutexLock lock(&mutex_); - stopped_ = true; + *stopped_ = true; for (const auto &timer : timers_) { timer->cancel(); } @@ -39,7 +39,7 @@ void PeriodicalRunner::Clear() { void PeriodicalRunner::RunFnPeriodically(std::function fn, uint64_t period_ms, const std::string &name) { - stopped_ = false; + *stopped_ = false; if (period_ms > 0) { auto timer = std::make_shared(io_service_); { @@ -47,8 +47,13 @@ void PeriodicalRunner::RunFnPeriodically(std::function fn, timers_.push_back(timer); } io_service_.post( - [this, fn = std::move(fn), period_ms, name, timer = std::move(timer)]() { - if (this->stopped_) { + [this, + stopped = stopped_, + fn = std::move(fn), + period_ms, + name, + timer = std::move(timer)]() { + if (*stopped) { return; } if (RayConfig::instance().event_stats()) { @@ -69,9 +74,9 @@ void PeriodicalRunner::DoRunFnPeriodically( fn(); absl::MutexLock lock(&mutex_); timer->expires_from_now(period); - timer->async_wait([this, fn, period, timer = std::move(timer)]( + timer->async_wait([this, stopped = stopped_, fn, period, timer = std::move(timer)]( const boost::system::error_code &error) { - if (stopped_) { + if (*stopped) { return; } if (error == boost::asio::error::operation_aborted) { @@ -99,16 +104,17 @@ void PeriodicalRunner::DoRunFnPeriodicallyInstrumented( auto stats_handle = io_service_.stats().RecordStart(name, period.total_nanoseconds()); timer->async_wait([this, fn, + stopped = stopped_, period, timer = std::move(timer), stats_handle = std::move(stats_handle), name](const boost::system::error_code &error) { - if (this->stopped_) { + if (*stopped) { return; } io_service_.stats().RecordExecution( - [this, fn, error, period, timer, name]() { - if (this->stopped_) { + [this, stopped = stopped, fn, error, period, timer, name]() { + if (*stopped) { return; } if (error == boost::asio::error::operation_aborted) { diff --git a/src/ray/common/asio/periodical_runner.h b/src/ray/common/asio/periodical_runner.h index c4ae74dcafa3..ec469f6c1352 100644 --- a/src/ray/common/asio/periodical_runner.h +++ b/src/ray/common/asio/periodical_runner.h @@ -56,7 +56,9 @@ class PeriodicalRunner { mutable absl::Mutex mutex_; std::vector> timers_ ABSL_GUARDED_BY(mutex_); - std::atomic stopped_ = false; + // `stopped_` is copied to the timer callback, and may outlive `this`. + std::shared_ptr> stopped_ = + std::make_shared>(false); }; } // namespace ray From 110ae3ec03ffa76bf26e010b8ac40068c52b0891 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Thu, 26 Sep 2024 17:33:19 -0700 Subject: [PATCH 11/26] rename Signed-off-by: Ruiyang Wang --- src/ray/common/asio/asio_util.h | 14 +++++++------- src/ray/gcs/gcs_client/gcs_client.cc | 2 +- src/ray/gcs/gcs_server/gcs_server.h | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/ray/common/asio/asio_util.h b/src/ray/common/asio/asio_util.h index 6dc8bbe8cd8b..0dae0a972a6b 100644 --- a/src/ray/common/asio/asio_util.h +++ b/src/ray/common/asio/asio_util.h @@ -45,13 +45,13 @@ std::shared_ptr execute_after( * The constructor takes a thread name and starts the thread. * The destructor stops the io_service and joins the thread. */ -class InstrumentedIoContextWithThread { +class InstrumentedIOContextWithThread { public: /** * Constructor. * @param thread_name The name of the thread. */ - explicit InstrumentedIoContextWithThread(const std::string &thread_name) + explicit InstrumentedIOContextWithThread(const std::string &thread_name) : io_service_(), work_(io_service_) { io_thread_ = std::thread([this, thread_name] { SetThreadName(thread_name); @@ -59,14 +59,14 @@ class InstrumentedIoContextWithThread { }); } - ~InstrumentedIoContextWithThread() { Stop(); } + ~InstrumentedIOContextWithThread() { Stop(); } // Non-movable and non-copyable. - InstrumentedIoContextWithThread(const InstrumentedIoContextWithThread &) = delete; - InstrumentedIoContextWithThread &operator=(const InstrumentedIoContextWithThread &) = + InstrumentedIOContextWithThread(const InstrumentedIOContextWithThread &) = delete; + InstrumentedIOContextWithThread &operator=(const InstrumentedIOContextWithThread &) = delete; - InstrumentedIoContextWithThread(InstrumentedIoContextWithThread &&) = delete; - InstrumentedIoContextWithThread &operator=(InstrumentedIoContextWithThread &&) = delete; + InstrumentedIOContextWithThread(InstrumentedIOContextWithThread &&) = delete; + InstrumentedIOContextWithThread &operator=(InstrumentedIOContextWithThread &&) = delete; instrumented_io_context &GetIoService() { return io_service_; } diff --git a/src/ray/gcs/gcs_client/gcs_client.cc b/src/ray/gcs/gcs_client/gcs_client.cc index 8825dd0ab73f..e46007f819ba 100644 --- a/src/ray/gcs/gcs_client/gcs_client.cc +++ b/src/ray/gcs/gcs_client/gcs_client.cc @@ -719,7 +719,7 @@ std::unordered_map PythonGetNodeLabels( } Status ConnectOnSingletonIoContext(GcsClient &gcs_client, int64_t timeout_ms) { - static InstrumentedIoContextWithThread io_context("gcs_client_io_service"); + static InstrumentedIOContextWithThread io_context("gcs_client_io_service"); instrumented_io_context &io_service = io_context.GetIoService(); return gcs_client.Connect(io_service, timeout_ms); } diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index 2cf994cddb2f..59a45a95ea31 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -213,13 +213,13 @@ class GcsServer { /// The main io service to drive event posted from grpc threads. instrumented_io_context &main_service_; /// The io service used by Pubsub, for isolation from other workload. - InstrumentedIoContextWithThread pubsub_io_context_; + InstrumentedIOContextWithThread pubsub_io_context_; // The io service used by internal KV service, table storage and the StoreClient. - InstrumentedIoContextWithThread kv_io_context_; + InstrumentedIOContextWithThread kv_io_context_; // The io service used by task manager. - InstrumentedIoContextWithThread task_io_context_; + InstrumentedIOContextWithThread task_io_context_; // The io service used by ray syncer. - InstrumentedIoContextWithThread ray_syncer_io_context_; + InstrumentedIOContextWithThread ray_syncer_io_context_; /// The grpc server rpc::GrpcServer rpc_server_; /// The `ClientCallManager` object that is shared by all `NodeManagerWorkerClient`s. From 3461330a962efc7235f585f29d4235e6005aaaa9 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Thu, 26 Sep 2024 17:37:26 -0700 Subject: [PATCH 12/26] fit lint Signed-off-by: Ruiyang Wang --- src/ray/gcs/store_client/redis_store_client.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index 4836b34e31d9..0a6e53f66f48 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -22,10 +22,10 @@ #include "absl/cleanup/cleanup.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" -#include "ray/common/asio/asio_util.h" #include "ray/gcs/redis_context.h" #include "ray/util/container_util.h" #include "ray/util/logging.h" + namespace ray { namespace gcs { From a6075288d902f25ab55f0103689bbeafb3bf4a7f Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 29 Oct 2024 15:42:03 -0700 Subject: [PATCH 13/26] type traits and policy for kv Signed-off-by: Ruiyang Wang --- src/ray/common/asio/asio_util.h | 14 +++++++-- src/ray/gcs/gcs_server/gcs_server.cc | 14 ++++----- .../gcs_server/gcs_server_io_context_policy.h | 8 +++-- src/ray/util/type_traits.h | 31 +++++++++++++++++++ 4 files changed, 56 insertions(+), 11 deletions(-) create mode 100644 src/ray/util/type_traits.h diff --git a/src/ray/common/asio/asio_util.h b/src/ray/common/asio/asio_util.h index c7df71405758..963ac9296ddd 100644 --- a/src/ray/common/asio/asio_util.h +++ b/src/ray/common/asio/asio_util.h @@ -23,6 +23,7 @@ #include "ray/common/asio/instrumented_io_context.h" #include "ray/util/array.h" +#include "ray/util/type_traits.h" #include "ray/util/util.h" template @@ -133,15 +134,24 @@ class IOContextProvider { } } + template + struct Wrapper { + static constexpr int value = N; + }; + // Gets IOContext registered for type T. If the type is not registered in // Policy::kAllDedicatedIOContextNames, it's a compile error. template instrumented_io_context &GetIOContext() const { constexpr int index = Policy::template GetDedicatedIOContextIndex(); static_assert( - index >= -1 && index < Policy::kAllDedicatedIOContextNames.size(), + (index == -1) || + (index >= 0 && + static_cast(index) < Policy::kAllDedicatedIOContextNames.size()) || + // To show index in compile error... + ray::AlwaysFalseValue, "index out of bound, invalid GetDedicatedIOContextIndex implementation! Index " - "can only be -1 or within range of kAllDedicatedIOContextNames"); + "can only be -1 or within range of kAllDedicatedIOContextNames: "); if constexpr (index == -1) { return default_io_context_; diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 719eba869837..aed5210dd838 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -69,8 +69,8 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, periodical_runner_(io_context_provider_.GetDefaultIOContext()), is_started_(false), is_stopped_(false) { - // Init GCS table storage. Note this is on main_service_, not kv_io_context_, to avoid - // congestion on the kv_io_context_. + // Init GCS table storage. Note this is on the default io context, not the one with + // GcsInternalKVManager, to avoid congestion on the latter. RAY_LOG(INFO) << "GCS storage type is " << storage_type_; switch (storage_type_) { case StorageType::IN_MEMORY: @@ -78,8 +78,8 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, io_context_provider_.GetDefaultIOContext()); break; case StorageType::REDIS_PERSIST: - gcs_table_storage_ = - std::make_shared(GetOrConnectRedis(main_service_)); + gcs_table_storage_ = std::make_shared( + GetOrConnectRedis(io_context_provider_.GetDefaultIOContext())); break; default: RAY_LOG(FATAL) << "Unexpected storage type: " << storage_type_; @@ -560,12 +560,12 @@ void GcsServer::InitKVManager() { switch (storage_type_) { case (StorageType::REDIS_PERSIST): instance = std::make_unique(std::make_unique( - GetOrConnectRedis(kv_io_context_.GetIoService()))); + GetOrConnectRedis(io_context_provider_.GetIOContext()))); break; case (StorageType::IN_MEMORY): instance = std::make_unique( std::make_unique(std::make_unique( - io_context_provider_.GetDefaultIOContext()))); + io_context_provider_.GetIOContext()))); break; default: RAY_LOG(FATAL) << "Unexpected storage type! " << storage_type_; @@ -578,7 +578,7 @@ void GcsServer::InitKVManager() { void GcsServer::InitKVService() { RAY_CHECK(kv_manager_); kv_service_ = std::make_unique( - io_context_provider_.GetDefaultIOContext(), *kv_manager_); + io_context_provider_.GetIOContext(), *kv_manager_); // Register service. rpc_server_.RegisterService(*kv_service_, false /* token_auth */); } diff --git a/src/ray/gcs/gcs_server/gcs_server_io_context_policy.h b/src/ray/gcs/gcs_server/gcs_server_io_context_policy.h index 9c146d811353..f4e45f3e2acd 100644 --- a/src/ray/gcs/gcs_server/gcs_server_io_context_policy.h +++ b/src/ray/gcs/gcs_server/gcs_server_io_context_policy.h @@ -22,6 +22,7 @@ #include "ray/gcs/gcs_server/gcs_task_manager.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" #include "ray/util/array.h" +#include "ray/util/type_traits.h" namespace ray { namespace gcs { @@ -40,10 +41,13 @@ struct GcsServerIOContextPolicy { return IndexOf("pubsub_io_context"); } else if constexpr (std::is_same_v) { return IndexOf("ray_syncer_io_context"); + } else if constexpr (std::is_same_v) { + // default io context + return -1; } else { // Due to if-constexpr limitations, this have to be in an else block. - // Using this tuple_size_v to put T into compile error message. - static_assert(std::tuple_size_v> == 0, "unknown type"); + // Using this template to put T into compile error message. + static_assert(AlwaysFalse, "unknown type"); } } diff --git a/src/ray/util/type_traits.h b/src/ray/util/type_traits.h new file mode 100644 index 000000000000..dc5e366af05f --- /dev/null +++ b/src/ray/util/type_traits.h @@ -0,0 +1,31 @@ +// Copyright 2024 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace ray { + +template +constexpr bool AlwaysFalse = false; + +template +constexpr bool AlwaysTrue = true; + +template +constexpr bool AlwaysFalseValue = false; + +template +constexpr bool AlwaysTrueValue = true; + +} // namespace ray From 698cbe9c187a47c194e6cc424deb2d934525a8c2 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 29 Oct 2024 15:51:39 -0700 Subject: [PATCH 14/26] remove temp code Signed-off-by: Ruiyang Wang --- src/ray/common/asio/asio_util.h | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/ray/common/asio/asio_util.h b/src/ray/common/asio/asio_util.h index 963ac9296ddd..4e8beea50f3e 100644 --- a/src/ray/common/asio/asio_util.h +++ b/src/ray/common/asio/asio_util.h @@ -134,11 +134,6 @@ class IOContextProvider { } } - template - struct Wrapper { - static constexpr int value = N; - }; - // Gets IOContext registered for type T. If the type is not registered in // Policy::kAllDedicatedIOContextNames, it's a compile error. template @@ -151,7 +146,7 @@ class IOContextProvider { // To show index in compile error... ray::AlwaysFalseValue, "index out of bound, invalid GetDedicatedIOContextIndex implementation! Index " - "can only be -1 or within range of kAllDedicatedIOContextNames: "); + "can only be -1 or within range of kAllDedicatedIOContextNames"); if constexpr (index == -1) { return default_io_context_; From a5e3d571b9babf551687b4bbe8c2e8ad97c36745 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Mon, 4 Nov 2024 09:57:56 -0800 Subject: [PATCH 15/26] fix GetOrConnectRedis Signed-off-by: Ruiyang Wang --- src/ray/gcs/gcs_server/gcs_server.cc | 31 +++++++++---------- src/ray/gcs/gcs_server/gcs_server.h | 4 +-- .../gcs/store_client/redis_store_client.cc | 1 - 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index aed5210dd838..3e3b641f7aea 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -78,8 +78,14 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, io_context_provider_.GetDefaultIOContext()); break; case StorageType::REDIS_PERSIST: - gcs_table_storage_ = std::make_shared( - GetOrConnectRedis(io_context_provider_.GetDefaultIOContext())); + auto redis_client = CreateRedisClient(io_context_provider_.GetDefaultIOContext()); + gcs_table_storage_ = std::make_shared(redis_client); + // Init redis failure detector. + gcs_redis_failure_detector_ = std::make_shared( + io_context_provider_.GetDefaultIOContext(), redis_client, []() { + RAY_LOG(FATAL) << "Redis connection failed. Shutdown GCS."; + }); + gcs_redis_failure_detector_->Start(); break; default: RAY_LOG(FATAL) << "Unexpected storage type: " << storage_type_; @@ -560,7 +566,7 @@ void GcsServer::InitKVManager() { switch (storage_type_) { case (StorageType::REDIS_PERSIST): instance = std::make_unique(std::make_unique( - GetOrConnectRedis(io_context_provider_.GetIOContext()))); + CreateRedisClient(io_context_provider_.GetIOContext()))); break; case (StorageType::IN_MEMORY): instance = std::make_unique( @@ -821,21 +827,12 @@ std::string GcsServer::GetDebugState() const { return stream.str(); } -std::shared_ptr GcsServer::GetOrConnectRedis( +std::shared_ptr GcsServer::CreateRedisClient( instrumented_io_context &io_service) { - if (redis_client_ == nullptr) { - redis_client_ = std::make_shared(GetRedisClientOptions()); - auto status = redis_client_->Connect(io_context_provider_.GetDefaultIOContext()); - RAY_CHECK(status.ok()) << "Failed to init redis gcs client as " << status; - - // Init redis failure detector. - gcs_redis_failure_detector_ = std::make_shared( - io_context_provider_.GetDefaultIOContext(), redis_client_, []() { - RAY_LOG(FATAL) << "Redis connection failed. Shutdown GCS."; - }); - gcs_redis_failure_detector_->Start(); - } - return redis_client_; + auto redis_client = std::make_shared(GetRedisClientOptions()); + auto status = redis_client->Connect(io_service); + RAY_CHECK(status.ok()) << "Failed to init redis gcs client as " << status; + return redis_client; } void GcsServer::PrintAsioStats() { diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index d675d53e6263..95358ee712c2 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -203,7 +203,7 @@ class GcsServer { void PrintAsioStats(); /// Get or connect to a redis server - std::shared_ptr GetOrConnectRedis(instrumented_io_context &io_service); + std::shared_ptr CreateRedisClient(instrumented_io_context &io_service); void TryGlobalGC(); @@ -281,8 +281,6 @@ class GcsServer { std::unique_ptr task_info_service_; /// Gcs Autoscaler state manager. std::unique_ptr autoscaler_state_service_; - /// Backend client. - std::shared_ptr redis_client_; /// A publisher for publishing gcs messages. std::shared_ptr gcs_publisher_; /// Grpc based pubsub's periodical runner. diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index 0a6e53f66f48..08a84921d566 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -325,7 +325,6 @@ Status RedisStoreClient::DeleteByKeys(const std::string &table, auto total_count = del_cmds.size(); auto finished_count = std::make_shared(0); auto num_deleted = std::make_shared(0); - auto context = redis_client_->GetPrimaryContext(); for (auto &command : del_cmds) { // `callback` is copied to each `delete_callback` lambda. Don't move. auto delete_callback = [num_deleted, finished_count, total_count, callback]( From 265f30584252b5cf0430ddbbcd40f5804b1b4d30 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 5 Nov 2024 13:40:01 -0800 Subject: [PATCH 16/26] Dispatchable class and InternalKVInterface Signed-off-by: Ruiyang Wang --- python/ray/includes/global_state_accessor.pxd | 4 +- src/ray/common/BUILD | 1 + src/ray/common/asio/dispatchable.h | 71 +++++++++ .../common/asio/instrumented_io_context.cc | 4 +- src/ray/common/asio/instrumented_io_context.h | 4 +- src/ray/common/runtime_env_manager.h | 3 +- src/ray/gcs/gcs_server/gcs_function_manager.h | 12 +- src/ray/gcs/gcs_server/gcs_job_manager.cc | 3 +- src/ray/gcs/gcs_server/gcs_job_manager.h | 3 + src/ray/gcs/gcs_server/gcs_kv_manager.cc | 15 +- src/ray/gcs/gcs_server/gcs_kv_manager.h | 21 ++- .../gcs_server/gcs_redis_failure_detector.cc | 6 +- src/ray/gcs/gcs_server/gcs_server.cc | 144 ++++++++++-------- src/ray/gcs/gcs_server/gcs_table_storage.h | 3 +- src/ray/gcs/gcs_server/store_client_kv.cc | 76 +++++---- src/ray/gcs/gcs_server/store_client_kv.h | 19 +-- src/ray/gcs/gcs_server/usage_stats_client.cc | 18 ++- src/ray/gcs/gcs_server/usage_stats_client.h | 4 +- src/ray/gcs/redis_context.cc | 5 +- src/ray/gcs/redis_context.h | 13 +- 20 files changed, 276 insertions(+), 153 deletions(-) create mode 100644 src/ray/common/asio/dispatchable.h diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd index a38db9fb0403..481cab9af20c 100644 --- a/python/ray/includes/global_state_accessor.pxd +++ b/python/ray/includes/global_state_accessor.pxd @@ -101,7 +101,7 @@ cdef extern from * namespace "ray::gcs" nogil: std::make_unique(std::move(redis_client))); bool ret_val = false; - cli->Get("session", key, [&](std::optional result) { + cli->Get("session", key, {[&](std::optional result) { if (result.has_value()) { *data = result.value(); ret_val = true; @@ -110,7 +110,7 @@ cdef extern from * namespace "ray::gcs" nogil: << " from persistent storage."; ret_val = false; } - }); + }, io_service}); io_service.run_for(std::chrono::milliseconds(1000)); return ret_val; diff --git a/src/ray/common/BUILD b/src/ray/common/BUILD index 4358a2a9470f..3d8c118eea34 100644 --- a/src/ray/common/BUILD +++ b/src/ray/common/BUILD @@ -196,6 +196,7 @@ ray_cc_library( hdrs = [ "asio/asio_chaos.h", "asio/asio_util.h", + "asio/dispatchable.h", "asio/instrumented_io_context.h", "asio/io_service_pool.h", "asio/periodical_runner.h", diff --git a/src/ray/common/asio/dispatchable.h b/src/ray/common/asio/dispatchable.h new file mode 100644 index 000000000000..71d812efeee6 --- /dev/null +++ b/src/ray/common/asio/dispatchable.h @@ -0,0 +1,71 @@ +#pragma once +#include +#include + +#include "ray/common/asio/instrumented_io_context.h" + +namespace ray { + +// Wrapper for a std::function that includes an instrumented_io_context for dispatching. +// On `Dispatch`, the function is called with the provided arguments, dispatched onto the +// provided io_context. +// +// The io_context must outlive the Dispatchable object. +template +class Dispatchable { + public: + Dispatchable(std::function func, instrumented_io_context &io_context) + : func_(std::move(func)), io_context_(&io_context) {} + + explicit Dispatchable(nullptr_t) : func_(nullptr), io_context_(nullptr) {} + + template + void DispatchIfNonNull(const std::string &name, Args &&...args) const { + if (func_) { + RAY_CHECK(io_context_ != nullptr); + io_context_->dispatch( + [func = func_, + args_tuple = std::make_tuple(std::forward(args)...)]() mutable { + std::apply(func, std::move(args_tuple)); + }, + name); + } + } + + std::function AsDispatchedFunction(const std::string &name) const & { + auto copied = *this; + return [copied, name](auto &&...args) { + copied.DispatchIfNonNull(name, std::forward(args)...); + }; + } + + std::function AsDispatchedFunction(const std::string &name) && { + return [moved = std::move(*this), name](auto &&...args) { + moved.DispatchIfNonNull(name, std::forward(args)...); + }; + } + + bool operator==(std::nullptr_t) const { return func_ == nullptr; } + bool operator!=(std::nullptr_t) const { return func_ != nullptr; } + + std::function func_; + instrumented_io_context *const io_context_; +}; + +namespace internal { + +template +struct ToDispatchableHelper; + +template +struct ToDispatchableHelper> { + using type = Dispatchable; +}; + +} // namespace internal + +// using ToDispatchable> = Dispatchable; +template +using ToDispatchable = typename internal::ToDispatchableHelper::type; + +} // namespace ray diff --git a/src/ray/common/asio/instrumented_io_context.cc b/src/ray/common/asio/instrumented_io_context.cc index 30c3b288ef16..665bf5eec762 100644 --- a/src/ray/common/asio/instrumented_io_context.cc +++ b/src/ray/common/asio/instrumented_io_context.cc @@ -24,7 +24,7 @@ #include "ray/common/asio/asio_util.h" void instrumented_io_context::post(std::function handler, - const std::string name, + const std::string &name, int64_t delay_us) { if (RayConfig::instance().event_stats()) { // References are only invalidated upon deletion of the corresponding item from the @@ -47,7 +47,7 @@ void instrumented_io_context::post(std::function handler, } void instrumented_io_context::dispatch(std::function handler, - const std::string name) { + const std::string &name) { if (!RayConfig::instance().event_stats()) { return boost::asio::io_context::post(std::move(handler)); } diff --git a/src/ray/common/asio/instrumented_io_context.h b/src/ray/common/asio/instrumented_io_context.h index 088fa557d3ce..3d3424ec51ce 100644 --- a/src/ray/common/asio/instrumented_io_context.h +++ b/src/ray/common/asio/instrumented_io_context.h @@ -37,7 +37,7 @@ class instrumented_io_context : public boost::asio::io_context { /// \param name A human-readable name for the handler, to be used for viewing stats /// for the provided handler. /// \param delay_us Delay time before the handler will be executed. - void post(std::function handler, const std::string name, int64_t delay_us = 0); + void post(std::function handler, const std::string &name, int64_t delay_us = 0); /// A proxy post function that collects count, queueing, and execution statistics for /// the given handler. @@ -45,7 +45,7 @@ class instrumented_io_context : public boost::asio::io_context { /// \param handler The handler to be posted to the event loop. /// \param name A human-readable name for the handler, to be used for viewing stats /// for the provided handler. - void dispatch(std::function handler, const std::string name); + void dispatch(std::function handler, const std::string &name); EventTracker &stats() const { return *event_stats_; }; diff --git a/src/ray/common/runtime_env_manager.h b/src/ray/common/runtime_env_manager.h index a6b282863307..df50850a3dfb 100644 --- a/src/ray/common/runtime_env_manager.h +++ b/src/ray/common/runtime_env_manager.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once #include +#include #include "absl/container/flat_hash_map.h" #include "ray/common/id.h" @@ -32,7 +33,7 @@ class RuntimeEnvManager { public: using DeleteFunc = std::function)>; - explicit RuntimeEnvManager(DeleteFunc deleter) : deleter_(deleter) {} + explicit RuntimeEnvManager(DeleteFunc deleter) : deleter_(std::move(deleter)) {} /// Increase the reference of URI by job or actor ID and runtime_env. /// diff --git a/src/ray/gcs/gcs_server/gcs_function_manager.h b/src/ray/gcs/gcs_server/gcs_function_manager.h index 4e409941d926..fe6b13c070ba 100644 --- a/src/ray/gcs/gcs_server/gcs_function_manager.h +++ b/src/ray/gcs/gcs_server/gcs_function_manager.h @@ -46,12 +46,18 @@ class GcsFunctionManager { private: void RemoveExportedFunctions(const JobID &job_id) { auto job_id_hex = job_id.Hex(); - kv_.Del("fun", "RemoteFunction:" + job_id_hex + ":", true, nullptr); - kv_.Del("fun", "ActorClass:" + job_id_hex + ":", true, nullptr); + kv_.Del("fun", + "RemoteFunction:" + job_id_hex + ":", + true, + Dispatchable{nullptr}); + kv_.Del("fun", + "ActorClass:" + job_id_hex + ":", + true, + Dispatchable{nullptr}); kv_.Del("fun", std::string(kWorkerSetupHookKeyName) + ":" + job_id_hex + ":", true, - nullptr); + Dispatchable{nullptr}); } // Handler for internal KV diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index f68a764f600c..32b34c4b9745 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -420,7 +420,8 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, size_t updated_finished_tasks = num_finished_tasks->fetch_add(1) + 1; try_send_reply(updated_finished_tasks); }; - internal_kv_.MultiGet("job", job_api_data_keys, kv_multi_get_callback); + internal_kv_.MultiGet( + "job", job_api_data_keys, {std::move(kv_multi_get_callback), io_context_}); } }; Status status = gcs_table_storage_->JobTable().GetAll(on_done); diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.h b/src/ray/gcs/gcs_server/gcs_job_manager.h index 95f43c7e27ad..c1d5055dd8a5 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.h +++ b/src/ray/gcs/gcs_server/gcs_job_manager.h @@ -55,12 +55,14 @@ class GcsJobManager : public rpc::JobInfoHandler { RuntimeEnvManager &runtime_env_manager, GcsFunctionManager &function_manager, InternalKVInterface &internal_kv, + instrumented_io_context &io_context, rpc::ClientFactoryFn client_factory = nullptr) : gcs_table_storage_(std::move(gcs_table_storage)), gcs_publisher_(std::move(gcs_publisher)), runtime_env_manager_(runtime_env_manager), function_manager_(function_manager), internal_kv_(internal_kv), + io_context_(io_context), core_worker_clients_(client_factory) {} void Initialize(const GcsInitData &gcs_init_data); @@ -130,6 +132,7 @@ class GcsJobManager : public rpc::JobInfoHandler { ray::RuntimeEnvManager &runtime_env_manager_; GcsFunctionManager &function_manager_; InternalKVInterface &internal_kv_; + instrumented_io_context &io_context_; /// The cached core worker clients which are used to communicate with workers. rpc::CoreWorkerClientPool core_worker_clients_; diff --git a/src/ray/gcs/gcs_server/gcs_kv_manager.cc b/src/ray/gcs/gcs_server/gcs_kv_manager.cc index 47f7146710d2..90aa4dc4d49d 100644 --- a/src/ray/gcs/gcs_server/gcs_kv_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_kv_manager.cc @@ -39,7 +39,8 @@ void GcsInternalKVManager::HandleInternalKVGet( send_reply_callback, reply, Status::NotFound("Failed to find the key")); } }; - kv_instance_->Get(request.namespace_(), request.key(), std::move(callback)); + kv_instance_->Get( + request.namespace_(), request.key(), {std::move(callback), io_context_}); } } @@ -64,7 +65,7 @@ void GcsInternalKVManager::HandleInternalKVMultiGet( GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); }; std::vector keys(request.keys().begin(), request.keys().end()); - kv_instance_->MultiGet(request.namespace_(), keys, std::move(callback)); + kv_instance_->MultiGet(request.namespace_(), keys, {std::move(callback), io_context_}); } void GcsInternalKVManager::HandleInternalKVPut( @@ -83,7 +84,7 @@ void GcsInternalKVManager::HandleInternalKVPut( request.key(), request.value(), request.overwrite(), - std::move(callback)); + {std::move(callback), io_context_}); } } @@ -102,7 +103,7 @@ void GcsInternalKVManager::HandleInternalKVDel( kv_instance_->Del(request.namespace_(), request.key(), request.del_by_prefix(), - std::move(callback)); + {std::move(callback), io_context_}); } } @@ -119,7 +120,8 @@ void GcsInternalKVManager::HandleInternalKVExists( reply->set_exists(exists); GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); }; - kv_instance_->Exists(request.namespace_(), request.key(), std::move(callback)); + kv_instance_->Exists( + request.namespace_(), request.key(), {std::move(callback), io_context_}); } } @@ -137,7 +139,8 @@ void GcsInternalKVManager::HandleInternalKVKeys( } GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); }; - kv_instance_->Keys(request.namespace_(), request.prefix(), std::move(callback)); + kv_instance_->Keys( + request.namespace_(), request.prefix(), {std::move(callback), io_context_}); } } diff --git a/src/ray/gcs/gcs_server/gcs_kv_manager.h b/src/ray/gcs/gcs_server/gcs_kv_manager.h index 9b8830102a91..d6b24ecaea21 100644 --- a/src/ray/gcs/gcs_server/gcs_kv_manager.h +++ b/src/ray/gcs/gcs_server/gcs_kv_manager.h @@ -17,6 +17,7 @@ #include "absl/container/btree_map.h" #include "absl/synchronization/mutex.h" +#include "ray/common/asio/dispatchable.h" #include "ray/gcs/redis_client.h" #include "ray/gcs/store_client/redis_store_client.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" @@ -37,7 +38,7 @@ class InternalKVInterface { /// \param callback Returns the value or null if the key doesn't exist. virtual void Get(const std::string &ns, const std::string &key, - std::function)> callback) = 0; + Dispatchable)> callback) = 0; /// Get the values associated with `keys`. /// @@ -47,7 +48,7 @@ class InternalKVInterface { virtual void MultiGet( const std::string &ns, const std::vector &keys, - std::function)> callback) = 0; + Dispatchable)> callback) = 0; /// Associate a key with the specified value. /// @@ -62,7 +63,7 @@ class InternalKVInterface { const std::string &key, const std::string &value, bool overwrite, - std::function callback) = 0; + Dispatchable callback) = 0; /// Delete the key from the store. /// @@ -74,7 +75,7 @@ class InternalKVInterface { virtual void Del(const std::string &ns, const std::string &key, bool del_by_prefix, - std::function callback) = 0; + Dispatchable callback) = 0; /// Check whether the key exists in the store. /// @@ -83,7 +84,7 @@ class InternalKVInterface { /// \param callback Callback function. virtual void Exists(const std::string &ns, const std::string &key, - std::function callback) = 0; + Dispatchable callback) = 0; /// Get the keys for a given prefix. /// @@ -92,7 +93,7 @@ class InternalKVInterface { /// \param callback return all the keys matching the prefix. virtual void Keys(const std::string &ns, const std::string &prefix, - std::function)> callback) = 0; + Dispatchable)> callback) = 0; virtual ~InternalKVInterface() = default; }; @@ -101,8 +102,11 @@ class InternalKVInterface { class GcsInternalKVManager : public rpc::InternalKVHandler { public: explicit GcsInternalKVManager(std::unique_ptr kv_instance, - const std::string &raylet_config_list) - : kv_instance_(std::move(kv_instance)), raylet_config_list_(raylet_config_list) {} + const std::string &raylet_config_list, + instrumented_io_context &io_context) + : kv_instance_(std::move(kv_instance)), + raylet_config_list_(raylet_config_list), + io_context_(io_context) {} void HandleInternalKVGet(rpc::InternalKVGetRequest request, rpc::InternalKVGetReply *reply, @@ -138,6 +142,7 @@ class GcsInternalKVManager : public rpc::InternalKVHandler { private: std::unique_ptr kv_instance_; const std::string raylet_config_list_; + instrumented_io_context &io_context_; Status ValidateKey(const std::string &key) const; }; diff --git a/src/ray/gcs/gcs_server/gcs_redis_failure_detector.cc b/src/ray/gcs/gcs_server/gcs_redis_failure_detector.cc index 6a601a610255..9ea8d46b90c7 100644 --- a/src/ray/gcs/gcs_server/gcs_redis_failure_detector.cc +++ b/src/ray/gcs/gcs_server/gcs_redis_failure_detector.cc @@ -14,6 +14,8 @@ #include "ray/gcs/gcs_server/gcs_redis_failure_detector.h" +#include + #include "ray/common/ray_config.h" namespace ray { @@ -24,7 +26,7 @@ GcsRedisFailureDetector::GcsRedisFailureDetector( std::shared_ptr redis_client, std::function callback) : io_service_(io_service), - redis_client_(redis_client), + redis_client_(std::move(redis_client)), callback_(std::move(callback)) {} void GcsRedisFailureDetector::Start() { @@ -45,7 +47,7 @@ void GcsRedisFailureDetector::DetectRedis() { auto redis_callback = [this](const std::shared_ptr &reply) { if (reply->IsNil()) { RAY_LOG(ERROR) << "Redis is inactive."; - callback_(); + this->io_service_.dispatch(this->callback_, "GcsRedisFailureDetector.DetectRedis"); } }; auto cxt = redis_client_->GetPrimaryContext(); diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 3e3b641f7aea..5ef57e905f95 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -15,6 +15,7 @@ #include "ray/gcs/gcs_server/gcs_server.h" #include +#include #include "ray/common/asio/asio_util.h" #include "ray/common/asio/instrumented_io_context.h" @@ -72,21 +73,23 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, // Init GCS table storage. Note this is on the default io context, not the one with // GcsInternalKVManager, to avoid congestion on the latter. RAY_LOG(INFO) << "GCS storage type is " << storage_type_; + auto &io_context = io_context_provider_.GetDefaultIOContext(); switch (storage_type_) { case StorageType::IN_MEMORY: - gcs_table_storage_ = std::make_shared( - io_context_provider_.GetDefaultIOContext()); + gcs_table_storage_ = std::make_shared(io_context); break; - case StorageType::REDIS_PERSIST: - auto redis_client = CreateRedisClient(io_context_provider_.GetDefaultIOContext()); - gcs_table_storage_ = std::make_shared(redis_client); + case StorageType::REDIS_PERSIST: { + auto redis_client = CreateRedisClient(io_context); + gcs_table_storage_ = + std::make_shared(redis_client, io_context); // Init redis failure detector. - gcs_redis_failure_detector_ = std::make_shared( - io_context_provider_.GetDefaultIOContext(), redis_client, []() { + gcs_redis_failure_detector_ = + std::make_shared(io_context, redis_client, []() { RAY_LOG(FATAL) << "Redis connection failed. Shutdown GCS."; }); gcs_redis_failure_detector_->Start(); break; + } default: RAY_LOG(FATAL) << "Unexpected storage type: " << storage_type_; } @@ -160,31 +163,34 @@ void GcsServer::Start() { void GcsServer::GetOrGenerateClusterId( std::function &&continuation) { static std::string const kTokenNamespace = "cluster"; + auto &io_context = io_context_provider_.GetIOContext(); kv_manager_->GetInstance().Get( kTokenNamespace, kClusterIdKey, - [this, continuation = std::move(continuation)]( - std::optional provided_cluster_id) mutable { - if (!provided_cluster_id.has_value()) { - ClusterID cluster_id = ClusterID::FromRandom(); - RAY_LOG(INFO) << "No existing server cluster ID found. Generating new ID: " - << cluster_id.Hex(); - kv_manager_->GetInstance().Put( - kTokenNamespace, - kClusterIdKey, - cluster_id.Binary(), - false, - [cluster_id, - continuation = std::move(continuation)](bool added_entry) mutable { - RAY_CHECK(added_entry) << "Failed to persist new cluster ID!"; - continuation(cluster_id); - }); - } else { - ClusterID cluster_id = ClusterID::FromBinary(provided_cluster_id.value()); - RAY_LOG(INFO) << "Found existing server token: " << cluster_id; - continuation(cluster_id); - } - }); + {[this, &io_context, continuation = std::move(continuation)]( + std::optional provided_cluster_id) mutable { + if (!provided_cluster_id.has_value()) { + ClusterID cluster_id = ClusterID::FromRandom(); + RAY_LOG(INFO) << "No existing server cluster ID found. Generating new ID: " + << cluster_id.Hex(); + kv_manager_->GetInstance().Put( + kTokenNamespace, + kClusterIdKey, + cluster_id.Binary(), + false, + {[cluster_id, + continuation = std::move(continuation)](bool added_entry) mutable { + RAY_CHECK(added_entry) << "Failed to persist new cluster ID!"; + continuation(cluster_id); + }, + io_context}); + } else { + ClusterID cluster_id = ClusterID::FromBinary(provided_cluster_id.value()); + RAY_LOG(INFO) << "Found existing server token: " << cluster_id; + continuation(cluster_id); + } + }, + io_context}); } void GcsServer::DoStart(const GcsInitData &gcs_init_data) { @@ -422,12 +428,14 @@ void GcsServer::InitGcsJobManager(const GcsInitData &gcs_init_data) { return std::make_shared(address, client_call_manager_); }; RAY_CHECK(gcs_table_storage_ && gcs_publisher_); - gcs_job_manager_ = std::make_unique(gcs_table_storage_, - gcs_publisher_, - *runtime_env_manager_, - *function_manager_, - kv_manager_->GetInstance(), - client_factory); + gcs_job_manager_ = + std::make_unique(gcs_table_storage_, + gcs_publisher_, + *runtime_env_manager_, + *function_manager_, + kv_manager_->GetInstance(), + io_context_provider_.GetDefaultIOContext(), + client_factory); gcs_job_manager_->Initialize(gcs_init_data); // Register service. @@ -552,7 +560,8 @@ void GcsServer::InitFunctionManager() { } void GcsServer::InitUsageStatsClient() { - usage_stats_client_ = std::make_unique(kv_manager_->GetInstance()); + usage_stats_client_ = std::make_unique( + kv_manager_->GetInstance(), io_context_provider_.GetDefaultIOContext()); gcs_worker_manager_->SetUsageStatsClient(usage_stats_client_.get()); gcs_actor_manager_->SetUsageStatsClient(usage_stats_client_.get()); @@ -563,22 +572,23 @@ void GcsServer::InitUsageStatsClient() { void GcsServer::InitKVManager() { // TODO (yic): Use a factory with configs std::unique_ptr instance; + auto &io_context = io_context_provider_.GetIOContext(); switch (storage_type_) { case (StorageType::REDIS_PERSIST): - instance = std::make_unique(std::make_unique( - CreateRedisClient(io_context_provider_.GetIOContext()))); + instance = std::make_unique( + std::make_unique(CreateRedisClient(io_context))); break; case (StorageType::IN_MEMORY): - instance = std::make_unique( - std::make_unique(std::make_unique( - io_context_provider_.GetIOContext()))); + instance = + std::make_unique(std::make_unique( + std::make_unique(io_context))); break; default: RAY_LOG(FATAL) << "Unexpected storage type! " << storage_type_; } - kv_manager_ = std::make_unique(std::move(instance), - config_.raylet_config_list); + kv_manager_ = std::make_unique( + std::move(instance), config_.raylet_config_list, io_context); } void GcsServer::InitKVService() { @@ -599,8 +609,9 @@ void GcsServer::InitPubSubHandler() { } void GcsServer::InitRuntimeEnvManager() { + auto &io_context = io_context_provider_.GetDefaultIOContext(); runtime_env_manager_ = std::make_unique( - /*deleter=*/[this](const std::string &plugin_uri, auto callback) { + /*deleter=*/[this, &io_context](const std::string &plugin_uri, auto callback) { // A valid runtime env URI is of the form "protocol://hash". std::string protocol_sep = "://"; auto protocol_end_pos = plugin_uri.find(protocol_sep); @@ -619,7 +630,8 @@ void GcsServer::InitRuntimeEnvManager() { "" /* namespace */, plugin_uri /* key */, false /* del_by_prefix*/, - [callback = std::move(callback)](int64_t) { callback(false); }); + {[callback = std::move(callback)](int64_t) { callback(false); }, + io_context}); } } }); @@ -628,7 +640,7 @@ void GcsServer::InitRuntimeEnvManager() { *runtime_env_manager_, /*delay_executor=*/ [this](std::function task, uint32_t delay_ms) { return execute_after(io_context_provider_.GetDefaultIOContext(), - task, + std::move(task), std::chrono::milliseconds(delay_ms)); }); runtime_env_service_ = std::make_unique( @@ -651,28 +663,34 @@ void GcsServer::InitGcsAutoscalerStateManager(const GcsInitData &gcs_init_data) auto v2_enabled = std::to_string(RayConfig::instance().enable_autoscaler_v2()); RAY_LOG(INFO) << "Autoscaler V2 enabled: " << v2_enabled; + auto &io_context = io_context_provider_.GetDefaultIOContext(); + kv_manager_->GetInstance().Put( kGcsAutoscalerStateNamespace, kGcsAutoscalerV2EnabledKey, v2_enabled, /*overwrite=*/true, - [this, v2_enabled](bool new_value_put) { - if (!new_value_put) { - // NOTE(rickyx): We cannot know if an overwirte Put succeeds or fails (e.g. when - // GCS re-started), so we just try to get the value to check if it's correct. - // TODO(rickyx): We could probably load some system configs from internal kv - // when we initialize GCS from restart to avoid this. - kv_manager_->GetInstance().Get( - kGcsAutoscalerStateNamespace, - kGcsAutoscalerV2EnabledKey, - [v2_enabled](std::optional value) { - RAY_CHECK(value.has_value()) << "Autoscaler v2 feature flag wasn't found " - "in GCS, this is unexpected."; - RAY_CHECK(*value == v2_enabled) << "Autoscaler v2 feature flag in GCS " - "doesn't match the one we put."; - }); - } - }); + {[this, v2_enabled, &io_context](bool new_value_put) { + if (!new_value_put) { + // NOTE(rickyx): We cannot know if an overwirte Put succeeds or fails (e.g. + // when GCS re-started), so we just try to get the value to check if it's + // correct. + // TODO(rickyx): We could probably load some system configs from internal kv + // when we initialize GCS from restart to avoid this. + kv_manager_->GetInstance().Get( + kGcsAutoscalerStateNamespace, + kGcsAutoscalerV2EnabledKey, + {[v2_enabled](std::optional value) { + RAY_CHECK(value.has_value()) + << "Autoscaler v2 feature flag wasn't found " + "in GCS, this is unexpected."; + RAY_CHECK(*value == v2_enabled) << "Autoscaler v2 feature flag in GCS " + "doesn't match the one we put."; + }, + io_context}); + } + }, + io_context}); gcs_autoscaler_state_manager_ = std::make_unique(config_.session_name, diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.h b/src/ray/gcs/gcs_server/gcs_table_storage.h index 16133f290138..bc5cd581d1aa 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.h +++ b/src/ray/gcs/gcs_server/gcs_table_storage.h @@ -292,7 +292,8 @@ class GcsTableStorage { /// that uses redis as storage. class RedisGcsTableStorage : public GcsTableStorage { public: - explicit RedisGcsTableStorage(std::shared_ptr redis_client) + explicit RedisGcsTableStorage(std::shared_ptr redis_client, + instrumented_io_context &io_context) : GcsTableStorage(std::make_shared(std::move(redis_client))) {} }; diff --git a/src/ray/gcs/gcs_server/store_client_kv.cc b/src/ray/gcs/gcs_server/store_client_kv.cc index 17228bff6f86..3cfab04e5975 100644 --- a/src/ray/gcs/gcs_server/store_client_kv.cc +++ b/src/ray/gcs/gcs_server/store_client_kv.cc @@ -51,29 +51,24 @@ StoreClientInternalKV::StoreClientInternalKV(std::unique_ptr store_ : delegate_(std::move(store_client)), table_name_(TablePrefix_Name(TablePrefix::KV)) {} -void StoreClientInternalKV::Get( - const std::string &ns, - const std::string &key, - std::function)> callback) { - if (!callback) { - callback = [](auto) {}; - } +void StoreClientInternalKV::Get(const std::string &ns, + const std::string &key, + Dispatchable)> callback) { RAY_CHECK_OK(delegate_->AsyncGet( table_name_, MakeKey(ns, key), [callback = std::move(callback)](auto status, auto result) { - callback(result.has_value() ? std::optional(result.value()) - : std::optional()); + callback.DispatchIfNonNull("StoreClientInternalKV::Get", + result.has_value() + ? std::optional(result.value()) + : std::optional()); })); } void StoreClientInternalKV::MultiGet( const std::string &ns, const std::vector &keys, - std::function)> callback) { - if (!callback) { - callback = [](auto) {}; - } + Dispatchable)> callback) { std::vector prefixed_keys; prefixed_keys.reserve(keys.size()); for (const auto &key : keys) { @@ -85,7 +80,7 @@ void StoreClientInternalKV::MultiGet( for (const auto &item : result) { ret.emplace(ExtractKey(item.first), item.second); } - callback(std::move(ret)); + callback.DispatchIfNonNull("StoreClientInternalKV::MultiGet", std::move(ret)); })); } @@ -93,25 +88,28 @@ void StoreClientInternalKV::Put(const std::string &ns, const std::string &key, const std::string &value, bool overwrite, - std::function callback) { - if (!callback) { - callback = [](auto) {}; - } - RAY_CHECK_OK( - delegate_->AsyncPut(table_name_, MakeKey(ns, key), value, overwrite, callback)); + Dispatchable callback) { + RAY_CHECK_OK(delegate_->AsyncPut(table_name_, + MakeKey(ns, key), + value, + overwrite, + [callback = std::move(callback)](bool success) { + callback.DispatchIfNonNull( + "StoreClientInternalKV::Put", success); + })); } void StoreClientInternalKV::Del(const std::string &ns, const std::string &key, bool del_by_prefix, - std::function callback) { - if (!callback) { - callback = [](auto) {}; - } + Dispatchable callback) { + auto dispatch_and_call = callback.AsDispatchedFunction("StoreClientInternalKV::Del"); if (!del_by_prefix) { RAY_CHECK_OK(delegate_->AsyncDelete( - table_name_, MakeKey(ns, key), [callback = std::move(callback)](bool deleted) { - callback(deleted ? 1 : 0); + table_name_, + MakeKey(ns, key), + [dispatch_and_call = std::move(dispatch_and_call)](bool deleted) { + dispatch_and_call(deleted ? 1 : 0); })); return; } @@ -119,32 +117,28 @@ void StoreClientInternalKV::Del(const std::string &ns, RAY_CHECK_OK(delegate_->AsyncGetKeys( table_name_, MakeKey(ns, key), - [this, ns, callback = std::move(callback)](auto keys) { + [this, ns, dispatch_and_call = std::move(dispatch_and_call)](auto keys) { if (keys.empty()) { - callback(0); + dispatch_and_call(0); return; } - RAY_CHECK_OK(delegate_->AsyncBatchDelete(table_name_, keys, std::move(callback))); + RAY_CHECK_OK( + delegate_->AsyncBatchDelete(table_name_, keys, std::move(dispatch_and_call))); })); } void StoreClientInternalKV::Exists(const std::string &ns, const std::string &key, - std::function callback) { - if (!callback) { - callback = [](auto) {}; - } - - RAY_CHECK_OK( - delegate_->AsyncExists(table_name_, MakeKey(ns, key), std::move(callback))); + Dispatchable callback) { + RAY_CHECK_OK(delegate_->AsyncExists( + table_name_, + MakeKey(ns, key), + std::move(callback).AsDispatchedFunction("StoreClientInternalKV::Exists"))); } void StoreClientInternalKV::Keys(const std::string &ns, const std::string &prefix, - std::function)> callback) { - if (!callback) { - callback = [](auto) {}; - } + Dispatchable)> callback) { RAY_CHECK_OK(delegate_->AsyncGetKeys( table_name_, MakeKey(ns, prefix), @@ -154,7 +148,7 @@ void StoreClientInternalKV::Keys(const std::string &ns, for (auto &key : keys) { true_keys.emplace_back(ExtractKey(key)); } - callback(std::move(true_keys)); + callback.DispatchIfNonNull("StoreClientInternalKV::Keys", std::move(true_keys)); })); } diff --git a/src/ray/gcs/gcs_server/store_client_kv.h b/src/ray/gcs/gcs_server/store_client_kv.h index cdbb351df81d..8721102cf8d5 100644 --- a/src/ray/gcs/gcs_server/store_client_kv.h +++ b/src/ray/gcs/gcs_server/store_client_kv.h @@ -15,6 +15,7 @@ #pragma once #include +#include "ray/common/asio/dispatchable.h" #include "ray/gcs/gcs_server/gcs_kv_manager.h" #include "ray/gcs/store_client/store_client.h" @@ -31,31 +32,31 @@ class StoreClientInternalKV : public InternalKVInterface { void Get(const std::string &ns, const std::string &key, - std::function)> callback) override; + Dispatchable)> callback) override; - void MultiGet(const std::string &ns, - const std::vector &keys, - std::function)> - callback) override; + void MultiGet( + const std::string &ns, + const std::vector &keys, + Dispatchable)> callback) override; void Put(const std::string &ns, const std::string &key, const std::string &value, bool overwrite, - std::function callback) override; + Dispatchable callback) override; void Del(const std::string &ns, const std::string &key, bool del_by_prefix, - std::function callback) override; + Dispatchable callback) override; void Exists(const std::string &ns, const std::string &key, - std::function callback) override; + Dispatchable callback) override; void Keys(const std::string &ns, const std::string &prefix, - std::function)> callback) override; + Dispatchable)> callback) override; private: std::unique_ptr delegate_; diff --git a/src/ray/gcs/gcs_server/usage_stats_client.cc b/src/ray/gcs/gcs_server/usage_stats_client.cc index 72ffb77004c8..9eeebde84076 100644 --- a/src/ray/gcs/gcs_server/usage_stats_client.cc +++ b/src/ray/gcs/gcs_server/usage_stats_client.cc @@ -16,20 +16,22 @@ namespace ray { namespace gcs { -UsageStatsClient::UsageStatsClient(InternalKVInterface &internal_kv) - : internal_kv_(internal_kv) {} +UsageStatsClient::UsageStatsClient(InternalKVInterface &internal_kv, + instrumented_io_context &io_context) + : internal_kv_(internal_kv), io_context_(io_context) {} void UsageStatsClient::RecordExtraUsageTag(usage::TagKey key, const std::string &value) { internal_kv_.Put(kUsageStatsNamespace, kExtraUsageTagPrefix + absl::AsciiStrToLower(usage::TagKey_Name(key)), value, /*overwrite=*/true, - [](bool added) { - if (!added) { - RAY_LOG(DEBUG) - << "Did not add new extra usage tag, maybe overwritten"; - } - }); + {[](bool added) { + if (!added) { + RAY_LOG(DEBUG) + << "Did not add new extra usage tag, maybe overwritten"; + } + }, + io_context_}); } void UsageStatsClient::RecordExtraUsageCounter(usage::TagKey key, int64_t counter) { diff --git a/src/ray/gcs/gcs_server/usage_stats_client.h b/src/ray/gcs/gcs_server/usage_stats_client.h index ebddc0047b38..f9ee65566072 100644 --- a/src/ray/gcs/gcs_server/usage_stats_client.h +++ b/src/ray/gcs/gcs_server/usage_stats_client.h @@ -23,7 +23,8 @@ namespace ray { namespace gcs { class UsageStatsClient { public: - explicit UsageStatsClient(ray::gcs::InternalKVInterface &internal_kv); + explicit UsageStatsClient(ray::gcs::InternalKVInterface &internal_kv, + instrumented_io_context &io_context); /// C++ version of record_extra_usage_tag in usage_lib.py /// @@ -40,6 +41,7 @@ class UsageStatsClient { static constexpr char kUsageStatsNamespace[] = "usage_stats"; ray::gcs::InternalKVInterface &internal_kv_; + instrumented_io_context &io_context_; }; } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index 6de20bfe34af..a9db2388427d 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -543,6 +543,7 @@ Status RedisContext::Connect(const std::string &address, std::unique_ptr RedisContext::RunArgvSync( const std::vector &args) { RAY_CHECK(context_); + RAY_CHECK(sync_context_thread_checker_.IsOnSameThread()); // Build the arguments. std::vector argv; std::vector argc; @@ -568,7 +569,9 @@ void RedisContext::RunArgvAsync(std::vector args, std::move(redis_callback), redis_async_context_.get(), std::move(args)); - request_context->Run(); + // If we are already on the io_service thread, we can run the request immediately. + io_service_.dispatch([request_context]() { request_context->Run(); }, + "RedisContext::RunArgvAsync"); } } // namespace gcs diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index f1c5cd87095f..35629afecb98 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -26,6 +26,7 @@ #include "ray/common/status.h" #include "ray/gcs/redis_async_context.h" #include "ray/util/logging.h" +#include "ray/util/thread_checker.h" #include "src/ray/protobuf/gcs.pb.h" extern "C" { @@ -134,9 +135,14 @@ struct RedisRequestContext { std::vector argc_; }; +// RunArgvAsync is thread-safe, can be accessed from multiple threads. The work is +// dispatched to the io_service thread, and the callback is dispatched back to the +// argument io_context. +// +// RunArgvSync is not thread-safe, and CHECK-fails if accessed from multiple threads. class RedisContext { public: - RedisContext(instrumented_io_context &io_service); + explicit RedisContext(instrumented_io_context &io_service); ~RedisContext(); @@ -168,7 +174,7 @@ class RedisContext { RedisAsyncContext &async_context() { RAY_CHECK(redis_async_context_); - return *redis_async_context_.get(); + return *redis_async_context_; } instrumented_io_context &io_service() { return io_service_; } @@ -179,6 +185,9 @@ class RedisContext { std::unique_ptr context_; redisSSLContext *ssl_context_; std::unique_ptr redis_async_context_; + // Checks `context_` is always used in the same thread. No need to check the async + // context because it is thread-safe. + ThreadChecker sync_context_thread_checker_; }; } // namespace gcs From 69a1c487a632f3dc3feab81358c7ca3178258bf4 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 5 Nov 2024 14:04:28 -0800 Subject: [PATCH 17/26] remove copy-as, only move-as Signed-off-by: Ruiyang Wang --- src/ray/common/asio/dispatchable.h | 7 ------- src/ray/gcs/gcs_server/store_client_kv.cc | 3 ++- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/ray/common/asio/dispatchable.h b/src/ray/common/asio/dispatchable.h index 71d812efeee6..b8908c999951 100644 --- a/src/ray/common/asio/dispatchable.h +++ b/src/ray/common/asio/dispatchable.h @@ -32,13 +32,6 @@ class Dispatchable { } } - std::function AsDispatchedFunction(const std::string &name) const & { - auto copied = *this; - return [copied, name](auto &&...args) { - copied.DispatchIfNonNull(name, std::forward(args)...); - }; - } - std::function AsDispatchedFunction(const std::string &name) && { return [moved = std::move(*this), name](auto &&...args) { moved.DispatchIfNonNull(name, std::forward(args)...); diff --git a/src/ray/gcs/gcs_server/store_client_kv.cc b/src/ray/gcs/gcs_server/store_client_kv.cc index 3cfab04e5975..6dd965f85521 100644 --- a/src/ray/gcs/gcs_server/store_client_kv.cc +++ b/src/ray/gcs/gcs_server/store_client_kv.cc @@ -103,7 +103,8 @@ void StoreClientInternalKV::Del(const std::string &ns, const std::string &key, bool del_by_prefix, Dispatchable callback) { - auto dispatch_and_call = callback.AsDispatchedFunction("StoreClientInternalKV::Del"); + auto dispatch_and_call = + std::move(callback).AsDispatchedFunction("StoreClientInternalKV::Del"); if (!del_by_prefix) { RAY_CHECK_OK(delegate_->AsyncDelete( table_name_, From 97bc8eac3e525b427875325feba657e8d48d1c7d Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Thu, 21 Nov 2024 16:55:17 -0800 Subject: [PATCH 18/26] wip big postable Signed-off-by: Ruiyang Wang --- .../execution/operators/output_splitter.py | 2 +- src/ray/common/BUILD | 2 +- src/ray/common/asio/dispatchable.h | 64 ------- src/ray/gcs/gcs_server/gcs_function_manager.h | 12 +- src/ray/gcs/gcs_server/gcs_kv_manager.h | 14 +- src/ray/gcs/gcs_server/gcs_table_storage.cc | 53 +++--- src/ray/gcs/gcs_server/gcs_table_storage.h | 28 +-- src/ray/gcs/gcs_server/store_client_kv.cc | 43 ++--- src/ray/gcs/gcs_server/store_client_kv.h | 14 +- .../store_client/in_memory_store_client.cc | 63 ++----- .../gcs/store_client/in_memory_store_client.h | 28 ++- .../store_client/observable_store_client.cc | 91 ++++------ .../store_client/observable_store_client.h | 21 +-- .../gcs/store_client/redis_store_client.cc | 170 +++++++++--------- src/ray/gcs/store_client/redis_store_client.h | 47 +++-- src/ray/gcs/store_client/store_client.h | 25 +-- 16 files changed, 278 insertions(+), 399 deletions(-) delete mode 100644 src/ray/common/asio/dispatchable.h diff --git a/python/ray/data/_internal/execution/operators/output_splitter.py b/python/ray/data/_internal/execution/operators/output_splitter.py index f5a9b6c55d84..833557f6e203 100644 --- a/python/ray/data/_internal/execution/operators/output_splitter.py +++ b/python/ray/data/_internal/execution/operators/output_splitter.py @@ -161,7 +161,7 @@ def progress_str(self) -> str: def _dispatch_bundles(self, dispatch_all: bool = False) -> None: start_time = time.perf_counter() - # Dispatch all dispatchable bundles from the internal buffer. + # Dispatch all postable bundles from the internal buffer. # This may not dispatch all bundles when equal=True. while self._buffer and ( dispatch_all or len(self._buffer) >= self._min_buffer_size diff --git a/src/ray/common/BUILD b/src/ray/common/BUILD index 3d8c118eea34..94d794883fe0 100644 --- a/src/ray/common/BUILD +++ b/src/ray/common/BUILD @@ -196,7 +196,7 @@ ray_cc_library( hdrs = [ "asio/asio_chaos.h", "asio/asio_util.h", - "asio/dispatchable.h", + "asio/postable.h", "asio/instrumented_io_context.h", "asio/io_service_pool.h", "asio/periodical_runner.h", diff --git a/src/ray/common/asio/dispatchable.h b/src/ray/common/asio/dispatchable.h deleted file mode 100644 index b8908c999951..000000000000 --- a/src/ray/common/asio/dispatchable.h +++ /dev/null @@ -1,64 +0,0 @@ -#pragma once -#include -#include - -#include "ray/common/asio/instrumented_io_context.h" - -namespace ray { - -// Wrapper for a std::function that includes an instrumented_io_context for dispatching. -// On `Dispatch`, the function is called with the provided arguments, dispatched onto the -// provided io_context. -// -// The io_context must outlive the Dispatchable object. -template -class Dispatchable { - public: - Dispatchable(std::function func, instrumented_io_context &io_context) - : func_(std::move(func)), io_context_(&io_context) {} - - explicit Dispatchable(nullptr_t) : func_(nullptr), io_context_(nullptr) {} - - template - void DispatchIfNonNull(const std::string &name, Args &&...args) const { - if (func_) { - RAY_CHECK(io_context_ != nullptr); - io_context_->dispatch( - [func = func_, - args_tuple = std::make_tuple(std::forward(args)...)]() mutable { - std::apply(func, std::move(args_tuple)); - }, - name); - } - } - - std::function AsDispatchedFunction(const std::string &name) && { - return [moved = std::move(*this), name](auto &&...args) { - moved.DispatchIfNonNull(name, std::forward(args)...); - }; - } - - bool operator==(std::nullptr_t) const { return func_ == nullptr; } - bool operator!=(std::nullptr_t) const { return func_ != nullptr; } - - std::function func_; - instrumented_io_context *const io_context_; -}; - -namespace internal { - -template -struct ToDispatchableHelper; - -template -struct ToDispatchableHelper> { - using type = Dispatchable; -}; - -} // namespace internal - -// using ToDispatchable> = Dispatchable; -template -using ToDispatchable = typename internal::ToDispatchableHelper::type; - -} // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_function_manager.h b/src/ray/gcs/gcs_server/gcs_function_manager.h index fe6b13c070ba..ac05c98671f1 100644 --- a/src/ray/gcs/gcs_server/gcs_function_manager.h +++ b/src/ray/gcs/gcs_server/gcs_function_manager.h @@ -46,18 +46,12 @@ class GcsFunctionManager { private: void RemoveExportedFunctions(const JobID &job_id) { auto job_id_hex = job_id.Hex(); - kv_.Del("fun", - "RemoteFunction:" + job_id_hex + ":", - true, - Dispatchable{nullptr}); - kv_.Del("fun", - "ActorClass:" + job_id_hex + ":", - true, - Dispatchable{nullptr}); + kv_.Del("fun", "RemoteFunction:" + job_id_hex + ":", true, Postable{}); + kv_.Del("fun", "ActorClass:" + job_id_hex + ":", true, Postable{}); kv_.Del("fun", std::string(kWorkerSetupHookKeyName) + ":" + job_id_hex + ":", true, - Dispatchable{nullptr}); + Postable{}); } // Handler for internal KV diff --git a/src/ray/gcs/gcs_server/gcs_kv_manager.h b/src/ray/gcs/gcs_server/gcs_kv_manager.h index d6b24ecaea21..887ff4e73bda 100644 --- a/src/ray/gcs/gcs_server/gcs_kv_manager.h +++ b/src/ray/gcs/gcs_server/gcs_kv_manager.h @@ -17,7 +17,7 @@ #include "absl/container/btree_map.h" #include "absl/synchronization/mutex.h" -#include "ray/common/asio/dispatchable.h" +#include "ray/common/asio/postable.h" #include "ray/gcs/redis_client.h" #include "ray/gcs/store_client/redis_store_client.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" @@ -38,7 +38,7 @@ class InternalKVInterface { /// \param callback Returns the value or null if the key doesn't exist. virtual void Get(const std::string &ns, const std::string &key, - Dispatchable)> callback) = 0; + Postable)> callback) = 0; /// Get the values associated with `keys`. /// @@ -48,7 +48,7 @@ class InternalKVInterface { virtual void MultiGet( const std::string &ns, const std::vector &keys, - Dispatchable)> callback) = 0; + Postable)> callback) = 0; /// Associate a key with the specified value. /// @@ -63,7 +63,7 @@ class InternalKVInterface { const std::string &key, const std::string &value, bool overwrite, - Dispatchable callback) = 0; + Postable callback) = 0; /// Delete the key from the store. /// @@ -75,7 +75,7 @@ class InternalKVInterface { virtual void Del(const std::string &ns, const std::string &key, bool del_by_prefix, - Dispatchable callback) = 0; + Postable callback) = 0; /// Check whether the key exists in the store. /// @@ -84,7 +84,7 @@ class InternalKVInterface { /// \param callback Callback function. virtual void Exists(const std::string &ns, const std::string &key, - Dispatchable callback) = 0; + Postable callback) = 0; /// Get the keys for a given prefix. /// @@ -93,7 +93,7 @@ class InternalKVInterface { /// \param callback return all the keys matching the prefix. virtual void Keys(const std::string &ns, const std::string &prefix, - Dispatchable)> callback) = 0; + Postable)> callback) = 0; virtual ~InternalKVInterface() = default; }; diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.cc b/src/ray/gcs/gcs_server/gcs_table_storage.cc index 12a13e37d928..7464aeeb8059 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.cc +++ b/src/ray/gcs/gcs_server/gcs_table_storage.cc @@ -24,21 +24,18 @@ namespace gcs { template Status GcsTable::Put(const Key &key, const Data &value, - const StatusCallback &callback) { - return store_client_->AsyncPut(table_name_, - key.Binary(), - value.SerializeAsString(), - /*overwrite*/ true, - [callback](auto) { - if (callback) { - callback(Status::OK()); - } - }); + Postable callback) { + return store_client_->AsyncPut( + table_name_, + key.Binary(), + value.SerializeAsString(), + /*overwrite*/ true, + std::move(callback).Compose([](bool) { return Status::OK(); })); } template Status GcsTable::Get(const Key &key, - const OptionalItemCallback &callback) { + ToPostable> callback) { auto on_done = [callback](const Status &status, const std::optional &result) { if (!callback) { @@ -52,11 +49,23 @@ Status GcsTable::Get(const Key &key, } callback(status, std::move(value)); }; - return store_client_->AsyncGet(table_name_, key.Binary(), on_done); + return store_client_->AsyncGet( + table_name_, + key.Binary(), + std::move(callback).Compose( + [](const Status &status, const std::optional &result) { + std::optional value; + if (result) { + Data data; + data.ParseFromString(*result); + value = std::move(data); + } + re + })); } template -Status GcsTable::GetAll(const MapCallback &callback) { +Status GcsTable::GetAll(ToPostable> callback) { auto on_done = [callback](absl::flat_hash_map &&result) { if (!callback) { return; @@ -74,7 +83,7 @@ Status GcsTable::GetAll(const MapCallback &callback) { } template -Status GcsTable::Delete(const Key &key, const StatusCallback &callback) { +Status GcsTable::Delete(const Key &key, Postable callback) { return store_client_->AsyncDelete(table_name_, key.Binary(), [callback](auto) { if (callback) { callback(Status::OK()); @@ -84,7 +93,7 @@ Status GcsTable::Delete(const Key &key, const StatusCallback &callbac template Status GcsTable::BatchDelete(const std::vector &keys, - const StatusCallback &callback) { + Postable callback) { std::vector keys_to_delete; keys_to_delete.reserve(keys.size()); for (auto &key : keys) { @@ -101,7 +110,7 @@ Status GcsTable::BatchDelete(const std::vector &keys, template Status GcsTableWithJobId::Put(const Key &key, const Data &value, - const StatusCallback &callback) { + Postable callback) { { absl::MutexLock lock(&mutex_); index_[GetJobIdFromKey(key)].insert(key); @@ -119,8 +128,8 @@ Status GcsTableWithJobId::Put(const Key &key, } template -Status GcsTableWithJobId::GetByJobId(const JobID &job_id, - const MapCallback &callback) { +Status GcsTableWithJobId::GetByJobId( + const JobID &job_id, ToPostable> callback) { std::vector keys; { absl::MutexLock lock(&mutex_); @@ -146,7 +155,7 @@ Status GcsTableWithJobId::GetByJobId(const JobID &job_id, template Status GcsTableWithJobId::DeleteByJobId(const JobID &job_id, - const StatusCallback &callback) { + Postable callback) { std::vector keys; { absl::MutexLock lock(&mutex_); @@ -160,13 +169,13 @@ Status GcsTableWithJobId::DeleteByJobId(const JobID &job_id, template Status GcsTableWithJobId::Delete(const Key &key, - const StatusCallback &callback) { + Postable callback) { return BatchDelete({key}, callback); } template Status GcsTableWithJobId::BatchDelete(const std::vector &keys, - const StatusCallback &callback) { + Postable callback) { std::vector keys_to_delete; keys_to_delete.reserve(keys.size()); for (auto &key : keys) { @@ -188,7 +197,7 @@ Status GcsTableWithJobId::BatchDelete(const std::vector &keys, template Status GcsTableWithJobId::AsyncRebuildIndexAndGetAll( - const MapCallback &callback) { + ToPostable> callback) { return this->GetAll([this, callback](absl::flat_hash_map &&result) mutable { absl::MutexLock lock(&mutex_); index_.clear(); diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.h b/src/ray/gcs/gcs_server/gcs_table_storage.h index bc5cd581d1aa..0eab3b061f13 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.h +++ b/src/ray/gcs/gcs_server/gcs_table_storage.h @@ -56,27 +56,27 @@ class GcsTable { /// \param value The value of the key that will be written to the table. /// \param callback Callback that will be called after write finishes. /// \return Status - virtual Status Put(const Key &key, const Data &value, const StatusCallback &callback); + virtual Status Put(const Key &key, const Data &value, Postable callback); /// Get data from the table asynchronously. /// /// \param key The key to lookup from the table. /// \param callback Callback that will be called after read finishes. /// \return Status - Status Get(const Key &key, const OptionalItemCallback &callback); + Status Get(const Key &key, ToPostable> callback); /// Get all data from the table asynchronously. /// /// \param callback Callback that will be called after data has been received. /// \return Status - Status GetAll(const MapCallback &callback); + Status GetAll(ToPostable> callback); /// Delete data from the table asynchronously. /// /// \param key The key that will be deleted from the table. /// \param callback Callback that will be called after delete finishes. /// \return Status - virtual Status Delete(const Key &key, const StatusCallback &callback); + virtual Status Delete(const Key &key, Postable callback); /// Delete a batch of data from the table asynchronously. /// @@ -84,7 +84,7 @@ class GcsTable { /// \param callback Callback that will be called after delete finishes. /// \return Status virtual Status BatchDelete(const std::vector &keys, - const StatusCallback &callback); + Postable callback); protected: std::string table_name_; @@ -113,28 +113,28 @@ class GcsTableWithJobId : public GcsTable { /// \param value The value of the key that will be written to the table. /// \param callback Callback that will be called after write finishes, whether it /// succeeds or not. \return Status for issuing the asynchronous write operation. - Status Put(const Key &key, const Data &value, const StatusCallback &callback) override; + Status Put(const Key &key, const Data &value, Postable callback) override; /// Get all the data of the specified job id from the table asynchronously. /// /// \param job_id The key to lookup from the table. /// \param callback Callback that will be called after read finishes. /// \return Status - Status GetByJobId(const JobID &job_id, const MapCallback &callback); + Status GetByJobId(const JobID &job_id, ToPostable> callback); /// Delete all the data of the specified job id from the table asynchronously. /// /// \param job_id The key that will be deleted from the table. /// \param callback Callback that will be called after delete finishes. /// \return Status - Status DeleteByJobId(const JobID &job_id, const StatusCallback &callback); + Status DeleteByJobId(const JobID &job_id, Postable callback); /// Delete data and index from the table asynchronously. /// /// \param key The key that will be deleted from the table. /// \param callback Callback that will be called after delete finishes. /// \return Status - Status Delete(const Key &key, const StatusCallback &callback) override; + Status Delete(const Key &key, Postable callback) override; /// Delete a batch of data and index from the table asynchronously. /// @@ -142,10 +142,10 @@ class GcsTableWithJobId : public GcsTable { /// \param callback Callback that will be called after delete finishes. /// \return Status Status BatchDelete(const std::vector &keys, - const StatusCallback &callback) override; + Postable callback) override; /// Rebuild the index during startup. - Status AsyncRebuildIndexAndGetAll(const MapCallback &callback); + Status AsyncRebuildIndexAndGetAll(ToPostable> callback); protected: virtual JobID GetJobIdFromKey(const Key &key) = 0; @@ -304,7 +304,11 @@ class InMemoryGcsTableStorage : public GcsTableStorage { public: explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service) : GcsTableStorage(std::make_shared( - std::make_unique(main_io_service))) {} + std::make_unique()) + ), io_service_(main_io_service) {} + + // All methods are posted to this io_service. + instrumented_io_context &io_service_; }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/store_client_kv.cc b/src/ray/gcs/gcs_server/store_client_kv.cc index 6dd965f85521..95e3670f8148 100644 --- a/src/ray/gcs/gcs_server/store_client_kv.cc +++ b/src/ray/gcs/gcs_server/store_client_kv.cc @@ -53,22 +53,22 @@ StoreClientInternalKV::StoreClientInternalKV(std::unique_ptr store_ void StoreClientInternalKV::Get(const std::string &ns, const std::string &key, - Dispatchable)> callback) { + Postable)> callback) { RAY_CHECK_OK(delegate_->AsyncGet( table_name_, MakeKey(ns, key), [callback = std::move(callback)](auto status, auto result) { - callback.DispatchIfNonNull("StoreClientInternalKV::Get", - result.has_value() - ? std::optional(result.value()) - : std::optional()); + callback.PostIfNonNull("StoreClientInternalKV::Get", + result.has_value() + ? std::optional(result.value()) + : std::optional()); })); } void StoreClientInternalKV::MultiGet( const std::string &ns, const std::vector &keys, - Dispatchable)> callback) { + Postable)> callback) { std::vector prefixed_keys; prefixed_keys.reserve(keys.size()); for (const auto &key : keys) { @@ -80,7 +80,7 @@ void StoreClientInternalKV::MultiGet( for (const auto &item : result) { ret.emplace(ExtractKey(item.first), item.second); } - callback.DispatchIfNonNull("StoreClientInternalKV::MultiGet", std::move(ret)); + callback.PostIfNonNull("StoreClientInternalKV::MultiGet", std::move(ret)); })); } @@ -88,21 +88,21 @@ void StoreClientInternalKV::Put(const std::string &ns, const std::string &key, const std::string &value, bool overwrite, - Dispatchable callback) { + Postable callback) { RAY_CHECK_OK(delegate_->AsyncPut(table_name_, MakeKey(ns, key), value, overwrite, [callback = std::move(callback)](bool success) { - callback.DispatchIfNonNull( - "StoreClientInternalKV::Put", success); + callback.PostIfNonNull("StoreClientInternalKV::Put", + success); })); } void StoreClientInternalKV::Del(const std::string &ns, const std::string &key, bool del_by_prefix, - Dispatchable callback) { + Postable callback) { auto dispatch_and_call = std::move(callback).AsDispatchedFunction("StoreClientInternalKV::Del"); if (!del_by_prefix) { @@ -130,7 +130,7 @@ void StoreClientInternalKV::Del(const std::string &ns, void StoreClientInternalKV::Exists(const std::string &ns, const std::string &key, - Dispatchable callback) { + Postable callback) { RAY_CHECK_OK(delegate_->AsyncExists( table_name_, MakeKey(ns, key), @@ -139,18 +139,19 @@ void StoreClientInternalKV::Exists(const std::string &ns, void StoreClientInternalKV::Keys(const std::string &ns, const std::string &prefix, - Dispatchable)> callback) { + Postable)> callback) { RAY_CHECK_OK(delegate_->AsyncGetKeys( table_name_, MakeKey(ns, prefix), - [callback = std::move(callback)](std::vector keys) { - std::vector true_keys; - true_keys.reserve(keys.size()); - for (auto &key : keys) { - true_keys.emplace_back(ExtractKey(key)); - } - callback.DispatchIfNonNull("StoreClientInternalKV::Keys", std::move(true_keys)); - })); + std::move(callback).Compose( + [](const std::vector &keys) -> std::vector { + std::vector true_keys; + true_keys.reserve(keys.size()); + for (auto &key : keys) { + true_keys.emplace_back(ExtractKey(key)); + } + return true_keys; + }))); } } // namespace gcs diff --git a/src/ray/gcs/gcs_server/store_client_kv.h b/src/ray/gcs/gcs_server/store_client_kv.h index 8721102cf8d5..3b103725cc94 100644 --- a/src/ray/gcs/gcs_server/store_client_kv.h +++ b/src/ray/gcs/gcs_server/store_client_kv.h @@ -15,7 +15,7 @@ #pragma once #include -#include "ray/common/asio/dispatchable.h" +#include "ray/common/asio/postable.h" #include "ray/gcs/gcs_server/gcs_kv_manager.h" #include "ray/gcs/store_client/store_client.h" @@ -32,31 +32,31 @@ class StoreClientInternalKV : public InternalKVInterface { void Get(const std::string &ns, const std::string &key, - Dispatchable)> callback) override; + Postable)> callback) override; void MultiGet( const std::string &ns, const std::vector &keys, - Dispatchable)> callback) override; + Postable)> callback) override; void Put(const std::string &ns, const std::string &key, const std::string &value, bool overwrite, - Dispatchable callback) override; + Postable callback) override; void Del(const std::string &ns, const std::string &key, bool del_by_prefix, - Dispatchable callback) override; + Postable callback) override; void Exists(const std::string &ns, const std::string &key, - Dispatchable callback) override; + Postable callback) override; void Keys(const std::string &ns, const std::string &prefix, - Dispatchable)> callback) override; + Postable)> callback) override; private: std::unique_ptr delegate_; diff --git a/src/ray/gcs/store_client/in_memory_store_client.cc b/src/ray/gcs/store_client/in_memory_store_client.cc index 39306b1254c9..85ec2208181e 100644 --- a/src/ray/gcs/store_client/in_memory_store_client.cc +++ b/src/ray/gcs/store_client/in_memory_store_client.cc @@ -22,7 +22,7 @@ Status InMemoryStoreClient::AsyncPut(const std::string &table_name, const std::string &key, const std::string &data, bool overwrite, - std::function callback) { + Postable callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); auto it = table->records_.find(key); @@ -35,17 +35,14 @@ Status InMemoryStoreClient::AsyncPut(const std::string &table_name, table->records_[key] = data; inserted = true; } - if (callback != nullptr) { - main_io_service_.post([callback, inserted]() { callback(inserted); }, - "GcsInMemoryStore.Put"); - } + callback.Post("GcsInMemoryStore.Put", inserted); return Status::OK(); } -Status InMemoryStoreClient::AsyncGet(const std::string &table_name, - const std::string &key, - const OptionalItemCallback &callback) { - RAY_CHECK(callback != nullptr); +Status InMemoryStoreClient::AsyncGet( + const std::string &table_name, + const std::string &key, + ToPostable> callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); auto iter = table->records_.find(key); @@ -53,34 +50,25 @@ Status InMemoryStoreClient::AsyncGet(const std::string &table_name, if (iter != table->records_.end()) { data = iter->second; } - - main_io_service_.post( - [callback, data = std::move(data)]() mutable // allow data to be moved - { callback(Status::OK(), std::move(data)); }, - "GcsInMemoryStore.Get"); - + callback.Post("GcsInMemoryStore.Get", Status::OK(), std::move(data)); return Status::OK(); } Status InMemoryStoreClient::AsyncGetAll( const std::string &table_name, - const MapCallback &callback) { - RAY_CHECK(callback); + ToPostable> callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); auto result = absl::flat_hash_map(); result.insert(table->records_.begin(), table->records_.end()); - main_io_service_.post( - [result = std::move(result), callback]() mutable { callback(std::move(result)); }, - "GcsInMemoryStore.GetAll"); + callback.Post("GcsInMemoryStore.GetAll", std::move(result)); return Status::OK(); } Status InMemoryStoreClient::AsyncMultiGet( const std::string &table_name, const std::vector &keys, - const MapCallback &callback) { - RAY_CHECK(callback); + ToPostable> callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); auto result = absl::flat_hash_map(); @@ -91,38 +79,30 @@ Status InMemoryStoreClient::AsyncMultiGet( } result[key] = it->second; } - main_io_service_.post( - [result = std::move(result), callback]() mutable { callback(std::move(result)); }, - "GcsInMemoryStore.GetAll"); + callback.Post("GcsInMemoryStore.GetAll", std::move(result)); return Status::OK(); } Status InMemoryStoreClient::AsyncDelete(const std::string &table_name, const std::string &key, - std::function callback) { + Postable callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); auto num = table->records_.erase(key); - if (callback != nullptr) { - main_io_service_.post([callback, num]() { callback(num > 0); }, - "GcsInMemoryStore.Delete"); - } + callback.Post("GcsInMemoryStore.Delete", num > 0); return Status::OK(); } Status InMemoryStoreClient::AsyncBatchDelete(const std::string &table_name, const std::vector &keys, - std::function callback) { + Postable callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); int64_t num = 0; for (auto &key : keys) { num += table->records_.erase(key); } - if (callback != nullptr) { - main_io_service_.post([callback, num]() { callback(num); }, - "GcsInMemoryStore.BatchDelete"); - } + callback.Post("GcsInMemoryStore.BatchDelete", num); return Status::OK(); } @@ -148,8 +128,7 @@ std::shared_ptr InMemoryStoreClient::GetOrCr Status InMemoryStoreClient::AsyncGetKeys( const std::string &table_name, const std::string &prefix, - std::function)> callback) { - RAY_CHECK(callback); + Postable)> callback) { auto table = GetOrCreateTable(table_name); std::vector result; absl::MutexLock lock(&(table->mutex_)); @@ -158,21 +137,17 @@ Status InMemoryStoreClient::AsyncGetKeys( result.push_back(pair.first); } } - main_io_service_.post( - [result = std::move(result), callback]() mutable { callback(std::move(result)); }, - "GcsInMemoryStore.Keys"); + callback.Post("GcsInMemoryStore.Keys", std::move(result)); return Status::OK(); } Status InMemoryStoreClient::AsyncExists(const std::string &table_name, const std::string &key, - std::function callback) { - RAY_CHECK(callback); + Postable callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); bool result = table->records_.contains(key); - main_io_service_.post([result, callback]() mutable { callback(result); }, - "GcsInMemoryStore.Exists"); + callback.Post("GcsInMemoryStore.Exists", result); return Status::OK(); } diff --git a/src/ray/gcs/store_client/in_memory_store_client.h b/src/ray/gcs/store_client/in_memory_store_client.h index a4ea7bc47ac6..612ab8577feb 100644 --- a/src/ray/gcs/store_client/in_memory_store_client.h +++ b/src/ray/gcs/store_client/in_memory_store_client.h @@ -30,43 +30,43 @@ namespace gcs { /// This class is thread safe. class InMemoryStoreClient : public StoreClient { public: - explicit InMemoryStoreClient(instrumented_io_context &main_io_service) - : main_io_service_(main_io_service) {} + explicit InMemoryStoreClient() = default; Status AsyncPut(const std::string &table_name, const std::string &key, const std::string &data, bool overwrite, - std::function callback) override; + Postable callback) override; Status AsyncGet(const std::string &table_name, const std::string &key, - const OptionalItemCallback &callback) override; + ToPostable> callback) override; Status AsyncGetAll(const std::string &table_name, - const MapCallback &callback) override; + ToPostable> callback) override; - Status AsyncMultiGet(const std::string &table_name, - const std::vector &keys, - const MapCallback &callback) override; + Status AsyncMultiGet( + const std::string &table_name, + const std::vector &keys, + ToPostable> callback) override; Status AsyncDelete(const std::string &table_name, const std::string &key, - std::function callback) override; + Postable callback) override; Status AsyncBatchDelete(const std::string &table_name, const std::vector &keys, - std::function callback) override; + Postable callback) override; int GetNextJobID() override; Status AsyncGetKeys(const std::string &table_name, const std::string &prefix, - std::function)> callback) override; + Postable)> callback) override; Status AsyncExists(const std::string &table_name, const std::string &key, - std::function callback) override; + Postable callback) override; private: struct InMemoryTable { @@ -84,10 +84,6 @@ class InMemoryStoreClient : public StoreClient { absl::flat_hash_map> tables_ ABSL_GUARDED_BY(mutex_); - /// Async API Callback needs to post to main_io_service_ to ensure the orderly execution - /// of the callback. - instrumented_io_context &main_io_service_; - int job_id_ = 0; }; diff --git a/src/ray/gcs/store_client/observable_store_client.cc b/src/ray/gcs/store_client/observable_store_client.cc index 147c9191a824..98a5371431ca 100644 --- a/src/ray/gcs/store_client/observable_store_client.cc +++ b/src/ray/gcs/store_client/observable_store_client.cc @@ -26,101 +26,80 @@ Status ObservableStoreClient::AsyncPut(const std::string &table_name, const std::string &key, const std::string &data, bool overwrite, - std::function callback) { + Postable callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "Put"); return delegate_->AsyncPut( - table_name, - key, - data, - overwrite, - [start, callback = std::move(callback)](auto result) { + table_name, key, data, overwrite, std::move(callback).OnInvocation([start]() { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Put"); - if (callback) { - callback(std::move(result)); - } - }); + })); } Status ObservableStoreClient::AsyncGet( const std::string &table_name, const std::string &key, - const OptionalItemCallback &callback) { + ToPostable> callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "Get"); - return delegate_->AsyncGet( - table_name, key, [start, callback](auto status, auto result) { - auto end = absl::GetCurrentTimeNanos(); - STATS_gcs_storage_operation_latency_ms.Record( - absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Get"); - if (callback) { - callback(status, std::move(result)); - } - }); + return delegate_->AsyncGet(table_name, key, std::move(callback).OnInvocation([start]() { + auto end = absl::GetCurrentTimeNanos(); + STATS_gcs_storage_operation_latency_ms.Record( + absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Get"); + })); } Status ObservableStoreClient::AsyncGetAll( const std::string &table_name, - const MapCallback &callback) { + ToPostable> callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "GetAll"); - return delegate_->AsyncGetAll(table_name, [start, callback](auto result) { + return delegate_->AsyncGetAll(table_name, std::move(callback).OnInvocation([start]() { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "GetAll"); - if (callback) { - callback(std::move(result)); - } - }); + })); } + Status ObservableStoreClient::AsyncMultiGet( const std::string &table_name, const std::vector &keys, - const MapCallback &callback) { + ToPostable> callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "MultiGet"); - return delegate_->AsyncMultiGet(table_name, keys, [start, callback](auto result) { - auto end = absl::GetCurrentTimeNanos(); - STATS_gcs_storage_operation_latency_ms.Record( - absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "MultiGet"); - if (callback) { - callback(std::move(result)); - } - }); + return delegate_->AsyncMultiGet( + table_name, keys, std::move(callback).OnInvocation([start]() { + auto end = absl::GetCurrentTimeNanos(); + STATS_gcs_storage_operation_latency_ms.Record( + absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "MultiGet"); + })); } Status ObservableStoreClient::AsyncDelete(const std::string &table_name, const std::string &key, - std::function callback) { + Postable callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "Delete"); return delegate_->AsyncDelete( - table_name, key, [start, callback = std::move(callback)](auto result) { + table_name, key, std::move(callback).OnInvocation([start]() { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Delete"); - if (callback) { - callback(std::move(result)); - } - }); + })); } Status ObservableStoreClient::AsyncBatchDelete(const std::string &table_name, const std::vector &keys, - std::function callback) { + Postable callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "BatchDelete"); return delegate_->AsyncBatchDelete( - table_name, keys, [start, callback = std::move(callback)](auto result) { + table_name, keys, std::move(callback).OnInvocation([start]() { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "BatchDelete"); - if (callback) { - callback(std::move(result)); - } - }); + })); } int ObservableStoreClient::GetNextJobID() { return delegate_->GetNextJobID(); } @@ -128,34 +107,28 @@ int ObservableStoreClient::GetNextJobID() { return delegate_->GetNextJobID(); } Status ObservableStoreClient::AsyncGetKeys( const std::string &table_name, const std::string &prefix, - std::function)> callback) { + Postable)> callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "GetKeys"); return delegate_->AsyncGetKeys( - table_name, prefix, [start, callback = std::move(callback)](auto result) { + table_name, prefix, std::move(callback).OnInvocation([start]() { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "GetKeys"); - if (callback) { - callback(std::move(result)); - } - }); + })); } Status ObservableStoreClient::AsyncExists(const std::string &table_name, const std::string &key, - std::function callback) { + Postable callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "Exists"); return delegate_->AsyncExists( - table_name, key, [start, callback = std::move(callback)](auto result) { + table_name, key, std::move(callback).OnInvocation([start]() { auto end = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_latency_ms.Record( absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Exists"); - if (callback) { - callback(std::move(result)); - } - }); + })); } } // namespace gcs diff --git a/src/ray/gcs/store_client/observable_store_client.h b/src/ray/gcs/store_client/observable_store_client.h index c6dd756b3ef1..388ad7cc3271 100644 --- a/src/ray/gcs/store_client/observable_store_client.h +++ b/src/ray/gcs/store_client/observable_store_client.h @@ -30,36 +30,37 @@ class ObservableStoreClient : public StoreClient { const std::string &key, const std::string &data, bool overwrite, - std::function callback) override; + Postable callback) override; Status AsyncGet(const std::string &table_name, const std::string &key, - const OptionalItemCallback &callback) override; + ToPostable> callback) override; Status AsyncGetAll(const std::string &table_name, - const MapCallback &callback) override; + ToPostable> callback) override; - Status AsyncMultiGet(const std::string &table_name, - const std::vector &keys, - const MapCallback &callback) override; + Status AsyncMultiGet( + const std::string &table_name, + const std::vector &keys, + ToPostable> callback) override; Status AsyncDelete(const std::string &table_name, const std::string &key, - std::function callback) override; + Postable callback) override; Status AsyncBatchDelete(const std::string &table_name, const std::vector &keys, - std::function callback) override; + Postable callback) override; int GetNextJobID() override; Status AsyncGetKeys(const std::string &table_name, const std::string &prefix, - std::function)> callback) override; + Postable)> callback) override; Status AsyncExists(const std::string &table_name, const std::string &key, - std::function callback) override; + Postable callback) override; private: std::unique_ptr delegate_; diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index 08a84921d566..634ec06ce782 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -71,41 +71,6 @@ RedisMatchPattern RedisMatchPattern::Prefix(const std::string &prefix) { return RedisMatchPattern(absl::StrCat(EscapeMatchPattern(prefix), "*")); } -void RedisStoreClient::MGetValues(const std::string &table_name, - const std::vector &keys, - const MapCallback &callback) { - // The `HMGET` command for each shard. - auto batched_commands = GenCommandsBatched( - "HMGET", RedisKey{external_storage_namespace_, table_name}, keys); - auto total_count = batched_commands.size(); - auto finished_count = std::make_shared(0); - auto key_value_map = std::make_shared>(); - - for (auto &command : batched_commands) { - auto mget_callback = [finished_count, - total_count, - // Copies! - args = command.args, - callback, - key_value_map](const std::shared_ptr &reply) { - if (!reply->IsNil()) { - auto value = reply->ReadAsStringArray(); - for (size_t index = 0; index < value.size(); ++index) { - if (value[index].has_value()) { - (*key_value_map)[args[index]] = *(value[index]); - } - } - } - - ++(*finished_count); - if (*finished_count == total_count) { - callback(std::move(*key_value_map)); - } - }; - SendRedisCmdArgsAsKeys(std::move(command), std::move(mget_callback)); - } -} - RedisStoreClient::RedisStoreClient(std::shared_ptr redis_client) : external_storage_namespace_(::RayConfig::instance().external_storage_namespace()), redis_client_(std::move(redis_client)) { @@ -118,35 +83,32 @@ Status RedisStoreClient::AsyncPut(const std::string &table_name, const std::string &key, const std::string &data, bool overwrite, - std::function callback) { + Postable callback) { RedisCommand command{/*command=*/overwrite ? "HSET" : "HSETNX", RedisKey{external_storage_namespace_, table_name}, /*args=*/{key, data}}; - RedisCallback write_callback = nullptr; - if (callback) { - write_callback = - [callback = std::move(callback)](const std::shared_ptr &reply) { - auto added_num = reply->ReadAsInteger(); - callback(added_num != 0); - }; - } + RedisCallback write_callback = + [callback = std::move(callback)](const std::shared_ptr &reply) { + auto added_num = reply->ReadAsInteger(); + callback.Post("RedisStoreClient.AsyncPut", added_num != 0); + }; SendRedisCmdWithKeys({key}, std::move(command), std::move(write_callback)); return Status::OK(); } -Status RedisStoreClient::AsyncGet(const std::string &table_name, - const std::string &key, - const OptionalItemCallback &callback) { - RAY_CHECK(callback != nullptr); - - auto redis_callback = [callback](const std::shared_ptr &reply) { +Status RedisStoreClient::AsyncGet( + const std::string &table_name, + const std::string &key, + ToPostable> callback) { + auto redis_callback = [callback = std::move(callback)]( + const std::shared_ptr &reply) { std::optional result; if (!reply->IsNil()) { result = reply->ReadAsString(); } RAY_CHECK(!reply->IsError()) << "Failed to get from Redis with status: " << reply->ReadAsStatus(); - callback(Status::OK(), std::move(result)); + callback.Post("RedisStoreClient.AsyncGet", Status::OK(), std::move(result)); }; RedisCommand command{/*command=*/"HGET", @@ -158,8 +120,7 @@ Status RedisStoreClient::AsyncGet(const std::string &table_name, Status RedisStoreClient::AsyncGetAll( const std::string &table_name, - const MapCallback &callback) { - RAY_CHECK(callback); + ToPostable> callback) { RedisScanner::ScanKeysAndValues(redis_client_, RedisKey{external_storage_namespace_, table_name}, RedisMatchPattern::Any(), @@ -169,41 +130,71 @@ Status RedisStoreClient::AsyncGetAll( Status RedisStoreClient::AsyncDelete(const std::string &table_name, const std::string &key, - std::function callback) { - return AsyncBatchDelete(table_name, {key}, [callback](int64_t cnt) { - if (callback != nullptr) { - callback(cnt > 0); - } - }); + Postable callback) { + return AsyncBatchDelete(table_name, {key}, std::move(callback).Compose([](int64_t cnt) { + return cnt > 0; + })); } Status RedisStoreClient::AsyncBatchDelete(const std::string &table_name, const std::vector &keys, - std::function callback) { + Postable callback) { if (keys.empty()) { - if (callback) { - callback(0); - } + callback.Post("RedisStoreClient.AsyncBatchDelete", 0); return Status::OK(); } - return DeleteByKeys(table_name, keys, callback); + return DeleteByKeys(table_name, keys, std::move(callback)); } Status RedisStoreClient::AsyncMultiGet( const std::string &table_name, const std::vector &keys, - const MapCallback &callback) { - RAY_CHECK(callback); + ToPostable> callback) { if (keys.empty()) { - callback({}); + callback.Post("RedisStoreClient.AsyncMultiGet", + absl::flat_hash_map{}); return Status::OK(); } - MGetValues(table_name, keys, callback); + // HMGET external_storage_namespace@table_name key1 key2 ... + // `keys` are chunked to multiple HMGET commands by + // RAY_maximum_gcs_storage_operation_batch_size. + + // The `HMGET` command for each shard. + auto batched_commands = GenCommandsBatched( + "HMGET", RedisKey{external_storage_namespace_, table_name}, keys); + auto total_count = batched_commands.size(); + auto finished_count = std::make_shared(0); + auto key_value_map = std::make_shared>(); + + for (auto &command : batched_commands) { + auto mget_callback = [finished_count, + total_count, + // Copies! + args = command.args, + // Copies! + callback, + key_value_map](const std::shared_ptr &reply) { + if (!reply->IsNil()) { + auto value = reply->ReadAsStringArray(); + for (size_t index = 0; index < value.size(); ++index) { + if (value[index].has_value()) { + (*key_value_map)[args[index]] = *(value[index]); + } + } + } + + ++(*finished_count); + if (*finished_count == total_count) { + callback.Post("RedisStoreClient.AsyncMultiGet", std::move(*key_value_map)); + } + }; + SendRedisCmdArgsAsKeys(std::move(command), std::move(mget_callback)); + } return Status::OK(); } size_t RedisStoreClient::PushToSendingQueue(const std::vector &keys, - std::function send_request) { + const std::function &send_request) { size_t queue_added = 0; for (const auto &key : keys) { auto [op_iter, added] = @@ -319,7 +310,7 @@ void RedisStoreClient::SendRedisCmdWithKeys(std::vector keys, Status RedisStoreClient::DeleteByKeys(const std::string &table, const std::vector &keys, - std::function callback) { + Postable callback) { auto del_cmds = GenCommandsBatched("HDEL", RedisKey{external_storage_namespace_, table}, keys); auto total_count = del_cmds.size(); @@ -332,9 +323,7 @@ Status RedisStoreClient::DeleteByKeys(const std::string &table, (*num_deleted) += reply->ReadAsInteger(); ++(*finished_count); if (*finished_count == total_count) { - if (callback) { - callback(*num_deleted); - } + callback.Post("RedisStoreClient.AsyncBatchDelete", *num_deleted); } }; SendRedisCmdArgsAsKeys(std::move(command), std::move(delete_callback)); @@ -347,7 +336,7 @@ RedisStoreClient::RedisScanner::RedisScanner( std::shared_ptr redis_client, RedisKey redis_key, RedisMatchPattern match_pattern, - MapCallback callback) + ToPostable> callback) : redis_key_(std::move(redis_key)), match_pattern_(std::move(match_pattern)), redis_client_(std::move(redis_client)), @@ -360,7 +349,7 @@ void RedisStoreClient::RedisScanner::ScanKeysAndValues( std::shared_ptr redis_client, RedisKey redis_key, RedisMatchPattern match_pattern, - MapCallback callback) { + ToPostable> callback) { auto scanner = std::make_shared(PrivateCtorTag(), std::move(redis_client), std::move(redis_key), @@ -376,7 +365,7 @@ void RedisStoreClient::RedisScanner::Scan() { // we should consider using a reader-writer lock. absl::MutexLock lock(&mutex_); if (!cursor_.has_value()) { - callback_(std::move(results_)); + callback_.Post("RedisStoreClient.RedisScanner.Scan", std::move(results_)); self_ref_.reset(); return; } @@ -446,35 +435,36 @@ int RedisStoreClient::GetNextJobID() { return static_cast(reply->ReadAsInteger()); } -Status RedisStoreClient::AsyncGetKeys( - const std::string &table_name, - const std::string &prefix, - std::function)> callback) { +Status RedisStoreClient::AsyncGetKeys(const std::string &table_name, + const std::string &prefix, + Postable)> callback) { RedisScanner::ScanKeysAndValues( redis_client_, RedisKey{external_storage_namespace_, table_name}, RedisMatchPattern::Prefix(prefix), - [callback](absl::flat_hash_map &&result) { - std::vector keys; - keys.reserve(result.size()); - for (const auto &[k, v] : result) { - keys.push_back(k); - } - callback(std::move(keys)); - }); + std::move(callback).Compose( + [](absl::flat_hash_map &&result) + -> std::vector { + std::vector keys; + keys.reserve(result.size()); + for (const auto &[k, v] : result) { + keys.push_back(k); + } + return keys; + })); return Status::OK(); } Status RedisStoreClient::AsyncExists(const std::string &table_name, const std::string &key, - std::function callback) { + Postable callback) { RedisCommand command = { "HEXISTS", RedisKey{external_storage_namespace_, table_name}, {key}}; SendRedisCmdArgsAsKeys( std::move(command), [callback = std::move(callback)](const std::shared_ptr &reply) { bool exists = reply->ReadAsInteger() > 0; - callback(exists); + callback.Post("RedisStoreClient.AsyncExists", exists); }); return Status::OK(); } diff --git a/src/ray/gcs/store_client/redis_store_client.h b/src/ray/gcs/store_client/redis_store_client.h index b9fc2643d7af..3454da037652 100644 --- a/src/ray/gcs/store_client/redis_store_client.h +++ b/src/ray/gcs/store_client/redis_store_client.h @@ -20,6 +20,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/synchronization/mutex.h" +#include "ray/common/asio/postable.h" #include "ray/common/ray_config.h" #include "ray/gcs/redis_client.h" #include "ray/gcs/redis_context.h" @@ -111,36 +112,37 @@ class RedisStoreClient : public StoreClient { const std::string &key, const std::string &data, bool overwrite, - std::function callback) override; + Postable callback) override; Status AsyncGet(const std::string &table_name, const std::string &key, - const OptionalItemCallback &callback) override; + ToPostable> callback) override; Status AsyncGetAll(const std::string &table_name, - const MapCallback &callback) override; + ToPostable> callback) override; - Status AsyncMultiGet(const std::string &table_name, - const std::vector &keys, - const MapCallback &callback) override; + Status AsyncMultiGet( + const std::string &table_name, + const std::vector &keys, + ToPostable> callback) override; Status AsyncDelete(const std::string &table_name, const std::string &key, - std::function callback) override; + Postable callback) override; Status AsyncBatchDelete(const std::string &table_name, const std::vector &keys, - std::function callback) override; + Postable callback) override; int GetNextJobID() override; Status AsyncGetKeys(const std::string &table_name, const std::string &prefix, - std::function)> callback) override; + Postable)> callback) override; Status AsyncExists(const std::string &table_name, const std::string &key, - std::function callback) override; + Postable callback) override; private: /// \class RedisScanner @@ -157,16 +159,18 @@ class RedisStoreClient : public StoreClient { struct PrivateCtorTag {}; public: + // Don't call this. Use ScanKeysAndValues instead. explicit RedisScanner(PrivateCtorTag tag, std::shared_ptr redis_client, RedisKey redis_key, RedisMatchPattern match_pattern, - MapCallback callback); + ToPostable> callback); - static void ScanKeysAndValues(std::shared_ptr redis_client, - RedisKey redis_key, - RedisMatchPattern match_pattern, - MapCallback callback); + static void ScanKeysAndValues( + std::shared_ptr redis_client, + RedisKey redis_key, + RedisMatchPattern match_pattern, + ToPostable> callback); private: // Scans the keys and values, one batch a time. Once all keys are scanned, the @@ -196,7 +200,7 @@ class RedisStoreClient : public StoreClient { std::shared_ptr redis_client_; - MapCallback callback_; + ToPostable> callback_; // Holds a self-ref until the scan is done. std::shared_ptr self_ref_; @@ -210,7 +214,7 @@ class RedisStoreClient : public StoreClient { // \return The number of queues newly added. A queue will be added // only when there is no in-flight request for the key. size_t PushToSendingQueue(const std::vector &keys, - std::function send_request) + const std::function &send_request) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Take requests from the sending queue and erase the queue if it's @@ -224,7 +228,7 @@ class RedisStoreClient : public StoreClient { Status DeleteByKeys(const std::string &table_name, const std::vector &keys, - std::function callback); + Postable callback); // Send the redis command to the server. This method will make request to be // serialized for each key in keys. At a given time, only one request for a {table_name, @@ -243,13 +247,6 @@ class RedisStoreClient : public StoreClient { // hence command.args may become empty. void SendRedisCmdArgsAsKeys(RedisCommand command, RedisCallback redis_callback); - // HMGET external_storage_namespace@table_name key1 key2 ... - // `keys` are chunked to multiple HMGET commands by - // RAY_maximum_gcs_storage_operation_batch_size. - void MGetValues(const std::string &table_name, - const std::vector &keys, - const MapCallback &callback); - std::string external_storage_namespace_; std::shared_ptr redis_client_; absl::Mutex mu_; diff --git a/src/ray/gcs/store_client/store_client.h b/src/ray/gcs/store_client/store_client.h index 23386cd95ebc..d0f7b87eddb2 100644 --- a/src/ray/gcs/store_client/store_client.h +++ b/src/ray/gcs/store_client/store_client.h @@ -18,6 +18,7 @@ #include #include "ray/common/asio/io_service_pool.h" +#include "ray/common/asio/postable.h" #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/gcs/callback.h" @@ -48,7 +49,7 @@ class StoreClient { const std::string &key, const std::string &data, bool overwrite, - std::function callback) = 0; + Postable callback) = 0; /// Get data from the given table asynchronously. /// @@ -58,15 +59,16 @@ class StoreClient { /// \return Status virtual Status AsyncGet(const std::string &table_name, const std::string &key, - const OptionalItemCallback &callback) = 0; + ToPostable> callback) = 0; /// Get all data from the given table asynchronously. /// /// \param table_name The name of the table to be read. /// \param callback returns the key value pairs in a map. /// \return Status - virtual Status AsyncGetAll(const std::string &table_name, - const MapCallback &callback) = 0; + virtual Status AsyncGetAll( + const std::string &table_name, + ToPostable> callback) = 0; /// Get all data from the given table asynchronously. /// @@ -74,9 +76,10 @@ class StoreClient { /// \param keys The keys to look up from the table. /// \param callback returns the key value pairs in a map for those keys that exist. /// \return Status - virtual Status AsyncMultiGet(const std::string &table_name, - const std::vector &keys, - const MapCallback &callback) = 0; + virtual Status AsyncMultiGet( + const std::string &table_name, + const std::vector &keys, + ToPostable> callback) = 0; /// Delete data from the given table asynchronously. /// @@ -86,7 +89,7 @@ class StoreClient { /// \return Status virtual Status AsyncDelete(const std::string &table_name, const std::string &key, - std::function callback) = 0; + Postable callback) = 0; /// Batch delete data from the given table asynchronously. /// @@ -96,7 +99,7 @@ class StoreClient { /// \return Status virtual Status AsyncBatchDelete(const std::string &table_name, const std::vector &keys, - std::function callback) = 0; + Postable callback) = 0; /// Get next job id by `INCR` "JobCounter" key synchronously. /// @@ -111,7 +114,7 @@ class StoreClient { /// \return Status virtual Status AsyncGetKeys(const std::string &table_name, const std::string &prefix, - std::function)> callback) = 0; + Postable)> callback) = 0; /// Check whether the key exists in the table. /// @@ -120,7 +123,7 @@ class StoreClient { /// \param callback Returns true if such key exists. virtual Status AsyncExists(const std::string &table_name, const std::string &key, - std::function callback) = 0; + Postable callback) = 0; protected: StoreClient() = default; From d07ffef83d447ecce8ab53bc2d13536edfeee837 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Mon, 25 Nov 2024 16:05:05 -0800 Subject: [PATCH 19/26] std function based Signed-off-by: Ruiyang Wang --- src/ray/gcs/gcs_server/gcs_actor_manager.cc | 41 +++---- src/ray/gcs/gcs_server/gcs_actor_manager.h | 3 + src/ray/gcs/gcs_server/gcs_table_storage.cc | 106 ++++++++---------- src/ray/gcs/gcs_server/store_client_kv.cc | 53 +++++---- .../gcs/store_client/redis_store_client.cc | 15 +-- 5 files changed, 102 insertions(+), 116 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index 946769884a8a..69dc20a8c4cb 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -329,12 +329,14 @@ GcsActorManager::GcsActorManager( RuntimeEnvManager &runtime_env_manager, GcsFunctionManager &function_manager, std::function destroy_owned_placement_group_if_needed, + instrumented_io_context &io_context, const rpc::ClientFactoryFn &worker_client_factory) : gcs_actor_scheduler_(std::move(scheduler)), gcs_table_storage_(std::move(gcs_table_storage)), gcs_publisher_(std::move(gcs_publisher)), worker_client_factory_(worker_client_factory), destroy_owned_placement_group_if_needed_(destroy_owned_placement_group_if_needed), + io_context_(io_context), runtime_env_manager_(runtime_env_manager), function_manager_(function_manager), actor_gc_delay_(RayConfig::instance().gcs_actor_table_min_duration_ms()) { @@ -1110,25 +1112,26 @@ void GcsActorManager::DestroyActor(const ActorID &actor_id, RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put( actor->GetActorID(), *actor_table_data, - [this, - actor, - actor_id, - actor_table_data, - is_restartable, - done_callback = std::move(done_callback)](Status status) { - if (done_callback) { - done_callback(); - } - RAY_CHECK_OK(gcs_publisher_->PublishActor( - actor_id, GenActorDataOnlyWithStates(*actor_table_data), nullptr)); - if (!is_restartable) { - RAY_CHECK_OK( - gcs_table_storage_->ActorTaskSpecTable().Delete(actor_id, nullptr)); - } - actor->WriteActorExportEvent(); - // Destroy placement group owned by this actor. - destroy_owned_placement_group_if_needed_(actor_id); - })); + {[this, + actor, + actor_id, + actor_table_data, + is_restartable, + done_callback = std::move(done_callback)](Status status) { + if (done_callback) { + done_callback(); + } + RAY_CHECK_OK(gcs_publisher_->PublishActor( + actor_id, GenActorDataOnlyWithStates(*actor_table_data), nullptr)); + if (!is_restartable) { + RAY_CHECK_OK( + gcs_table_storage_->ActorTaskSpecTable().Delete(actor_id, nullptr)); + } + actor->WriteActorExportEvent(); + // Destroy placement group owned by this actor. + destroy_owned_placement_group_if_needed_(actor_id); + }, + io_context_})); // Inform all creation callbacks that the actor was cancelled, not created. RunAndClearActorCreationCallbacks( diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.h b/src/ray/gcs/gcs_server/gcs_actor_manager.h index f31943ccb717..642a8178aaba 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.h +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.h @@ -319,6 +319,7 @@ class GcsActorManager : public rpc::ActorInfoHandler { RuntimeEnvManager &runtime_env_manager, GcsFunctionManager &function_manager, std::function destroy_owned_placement_group_if_needed, + instrumented_io_context &io_context, const rpc::ClientFactoryFn &worker_client_factory = nullptr); ~GcsActorManager() = default; @@ -697,6 +698,8 @@ class GcsActorManager : public rpc::ActorInfoHandler { /// This method MUST BE IDEMPOTENT because it can be called multiple times during /// actor destroy process. std::function destroy_owned_placement_group_if_needed_; + /// The io context for the InternalKV callbacks. + instrumented_io_context &io_context_; /// Runtime environment manager for GC purpose RuntimeEnvManager &runtime_env_manager_; /// Function manager for GC purpose diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.cc b/src/ray/gcs/gcs_server/gcs_table_storage.cc index 7464aeeb8059..7e78923581f6 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.cc +++ b/src/ray/gcs/gcs_server/gcs_table_storage.cc @@ -30,65 +30,53 @@ Status GcsTable::Put(const Key &key, key.Binary(), value.SerializeAsString(), /*overwrite*/ true, - std::move(callback).Compose([](bool) { return Status::OK(); })); + std::move(callback).TransformArg([](bool) { return Status::OK(); })); } template Status GcsTable::Get(const Key &key, ToPostable> callback) { - auto on_done = [callback](const Status &status, - const std::optional &result) { - if (!callback) { - return; - } - std::optional value; - if (result) { - Data data; - data.ParseFromString(*result); - value = std::move(data); - } - callback(status, std::move(value)); - }; return store_client_->AsyncGet( table_name_, key.Binary(), - std::move(callback).Compose( - [](const Status &status, const std::optional &result) { - std::optional value; - if (result) { - Data data; - data.ParseFromString(*result); - value = std::move(data); - } - re + std::move(callback).Rebind &&)>( + [](auto cb) { + return + [cb = std::move(cb)](Status status, std::optional &&result) { + std::optional value; + if (result) { + Data data; + data.ParseFromString(*result); + value = std::move(data); + } + cb(status, std::move(value)); + }; })); } template Status GcsTable::GetAll(ToPostable> callback) { - auto on_done = [callback](absl::flat_hash_map &&result) { - if (!callback) { - return; - } - absl::flat_hash_map values; - values.reserve(result.size()); - for (auto &item : result) { - if (!item.second.empty()) { - values[Key::FromBinary(item.first)].ParseFromString(item.second); - } - } - callback(std::move(values)); - }; - return store_client_->AsyncGetAll(table_name_, on_done); + return store_client_->AsyncGetAll( + table_name_, + std::move(callback).TransformArg &&>( + [](absl::flat_hash_map &&result) { + absl::flat_hash_map values; + values.reserve(result.size()); + for (auto &item : result) { + if (!item.second.empty()) { + values[Key::FromBinary(item.first)].ParseFromString(item.second); + } + } + return std::move(values); + })); } template Status GcsTable::Delete(const Key &key, Postable callback) { - return store_client_->AsyncDelete(table_name_, key.Binary(), [callback](auto) { - if (callback) { - callback(Status::OK()); - } - }); + return store_client_->AsyncDelete( + table_name_, key.Binary(), std::move(callback).TransformArg([](bool) { + return Status::OK(); + })); } template @@ -100,11 +88,9 @@ Status GcsTable::BatchDelete(const std::vector &keys, keys_to_delete.emplace_back(std::move(key.Binary())); } return this->store_client_->AsyncBatchDelete( - this->table_name_, keys_to_delete, [callback](auto) { - if (callback) { - callback(Status::OK()); - } - }); + this->table_name_, + keys_to_delete, + std::move(callback).TransformArg([](int64_t) { return Status::OK(); })); } template @@ -115,16 +101,12 @@ Status GcsTableWithJobId::Put(const Key &key, absl::MutexLock lock(&mutex_); index_[GetJobIdFromKey(key)].insert(key); } - return this->store_client_->AsyncPut(this->table_name_, - key.Binary(), - value.SerializeAsString(), - /*overwrite*/ true, - [callback](auto) { - if (!callback) { - return; - } - callback(Status::OK()); - }); + return this->store_client_->AsyncPut( + this->table_name_, + key.Binary(), + value.SerializeAsString(), + /*overwrite*/ true, + std::move(callback).TransformArg([](bool) { return Status::OK(); })); } template @@ -182,17 +164,17 @@ Status GcsTableWithJobId::BatchDelete(const std::vector &keys, keys_to_delete.push_back(key.Binary()); } return this->store_client_->AsyncBatchDelete( - this->table_name_, keys_to_delete, [this, callback, keys](auto) { + this->table_name_, + keys_to_delete, + std::move(callback).TransformArg([this, keys](auto) { { absl::MutexLock lock(&mutex_); for (auto &key : keys) { index_[GetJobIdFromKey(key)].erase(key); } } - if (callback) { - callback(Status::OK()); - } - }); + return Status::OK(); + })); } template diff --git a/src/ray/gcs/gcs_server/store_client_kv.cc b/src/ray/gcs/gcs_server/store_client_kv.cc index 95e3670f8148..4bbebcf5cdf1 100644 --- a/src/ray/gcs/gcs_server/store_client_kv.cc +++ b/src/ray/gcs/gcs_server/store_client_kv.cc @@ -58,10 +58,9 @@ void StoreClientInternalKV::Get(const std::string &ns, table_name_, MakeKey(ns, key), [callback = std::move(callback)](auto status, auto result) { - callback.PostIfNonNull("StoreClientInternalKV::Get", - result.has_value() - ? std::optional(result.value()) - : std::optional()); + callback.Post("StoreClientInternalKV::Get", + result.has_value() ? std::optional(result.value()) + : std::optional()); })); } @@ -80,7 +79,7 @@ void StoreClientInternalKV::MultiGet( for (const auto &item : result) { ret.emplace(ExtractKey(item.first), item.second); } - callback.PostIfNonNull("StoreClientInternalKV::MultiGet", std::move(ret)); + callback.Post("StoreClientInternalKV::MultiGet", std::move(ret)); })); } @@ -94,8 +93,7 @@ void StoreClientInternalKV::Put(const std::string &ns, value, overwrite, [callback = std::move(callback)](bool success) { - callback.PostIfNonNull("StoreClientInternalKV::Put", - success); + callback.Post("StoreClientInternalKV::Put", success); })); } @@ -103,38 +101,37 @@ void StoreClientInternalKV::Del(const std::string &ns, const std::string &key, bool del_by_prefix, Postable callback) { - auto dispatch_and_call = - std::move(callback).AsDispatchedFunction("StoreClientInternalKV::Del"); if (!del_by_prefix) { RAY_CHECK_OK(delegate_->AsyncDelete( - table_name_, - MakeKey(ns, key), - [dispatch_and_call = std::move(dispatch_and_call)](bool deleted) { - dispatch_and_call(deleted ? 1 : 0); - })); + table_name_, MakeKey(ns, key), std::move(callback).TransformArg([](bool deleted) { + return deleted ? 1 : 0; + }))); return; } - + // This one requires 2 async calls, so we can't just do `Rebind`, instead we need to + // manually use the io context. RAY_CHECK_OK(delegate_->AsyncGetKeys( table_name_, MakeKey(ns, key), - [this, ns, dispatch_and_call = std::move(dispatch_and_call)](auto keys) { - if (keys.empty()) { - dispatch_and_call(0); - return; - } - RAY_CHECK_OK( - delegate_->AsyncBatchDelete(table_name_, keys, std::move(dispatch_and_call))); - })); + Postable)>( + [this, ns, callback = std::move(callback)]( + const std::vector &keys) -> void { + if (keys.empty()) { + // We are directly calling this because we + // don't need another Post. + callback.func_(0); + return; + } + RAY_CHECK_OK(delegate_->AsyncBatchDelete(table_name_, keys, callback)); + }, + callback.io_context_))); } void StoreClientInternalKV::Exists(const std::string &ns, const std::string &key, Postable callback) { - RAY_CHECK_OK(delegate_->AsyncExists( - table_name_, - MakeKey(ns, key), - std::move(callback).AsDispatchedFunction("StoreClientInternalKV::Exists"))); + RAY_CHECK_OK( + delegate_->AsyncExists(table_name_, MakeKey(ns, key), std::move(callback))); } void StoreClientInternalKV::Keys(const std::string &ns, @@ -143,7 +140,7 @@ void StoreClientInternalKV::Keys(const std::string &ns, RAY_CHECK_OK(delegate_->AsyncGetKeys( table_name_, MakeKey(ns, prefix), - std::move(callback).Compose( + std::move(callback).TransformArg( [](const std::vector &keys) -> std::vector { std::vector true_keys; true_keys.reserve(keys.size()); diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index 634ec06ce782..f715692fe827 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -131,9 +131,10 @@ Status RedisStoreClient::AsyncGetAll( Status RedisStoreClient::AsyncDelete(const std::string &table_name, const std::string &key, Postable callback) { - return AsyncBatchDelete(table_name, {key}, std::move(callback).Compose([](int64_t cnt) { - return cnt > 0; - })); + return AsyncBatchDelete(table_name, + {key}, + std::move(callback).TransformArg( + std::function{[](int64_t cnt) { return cnt > 0; }})); } Status RedisStoreClient::AsyncBatchDelete(const std::string &table_name, @@ -442,16 +443,16 @@ Status RedisStoreClient::AsyncGetKeys(const std::string &table_name, redis_client_, RedisKey{external_storage_namespace_, table_name}, RedisMatchPattern::Prefix(prefix), - std::move(callback).Compose( - [](absl::flat_hash_map &&result) - -> std::vector { + std::move(callback).TransformArg( + std::function{[](absl::flat_hash_map &&result) + -> std::vector { std::vector keys; keys.reserve(result.size()); for (const auto &[k, v] : result) { keys.push_back(k); } return keys; - })); + }})); return Status::OK(); } From e6c9df976d6f7808f872fc450fb2a5c0cff94311 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Mon, 25 Nov 2024 21:57:29 -0800 Subject: [PATCH 20/26] Postable for all Signed-off-by: Ruiyang Wang --- src/ray/gcs/gcs_server/gcs_actor_manager.cc | 233 +++++++++--------- src/ray/gcs/gcs_server/gcs_actor_scheduler.cc | 19 +- src/ray/gcs/gcs_server/gcs_function_manager.h | 13 +- src/ray/gcs/gcs_server/gcs_init_data.cc | 52 ++-- src/ray/gcs/gcs_server/gcs_init_data.h | 17 +- src/ray/gcs/gcs_server/gcs_job_manager.cc | 48 ++-- src/ray/gcs/gcs_server/gcs_node_manager.cc | 21 +- src/ray/gcs/gcs_server/gcs_node_manager.h | 3 + .../gcs_server/gcs_placement_group_manager.cc | 139 ++++++----- .../gcs_placement_group_scheduler.cc | 13 +- src/ray/gcs/gcs_server/gcs_server.cc | 41 +-- src/ray/gcs/gcs_server/gcs_table_storage.cc | 100 ++++---- src/ray/gcs/gcs_server/gcs_table_storage.h | 12 +- src/ray/gcs/gcs_server/gcs_worker_manager.cc | 20 +- src/ray/gcs/gcs_server/gcs_worker_manager.h | 8 +- src/ray/gcs/gcs_server/store_client_kv.cc | 56 ++--- .../store_client/in_memory_store_client.cc | 4 +- .../gcs/store_client/in_memory_store_client.h | 7 +- .../store_client/observable_store_client.cc | 4 +- .../store_client/observable_store_client.h | 7 +- .../gcs/store_client/redis_store_client.cc | 41 ++- src/ray/gcs/store_client/redis_store_client.h | 22 +- src/ray/gcs/store_client/store_client.h | 4 +- 23 files changed, 472 insertions(+), 412 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index 69dc20a8c4cb..a6be76f630cb 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -602,34 +602,35 @@ void GcsActorManager::HandleGetAllActorInfo(rpc::GetAllActorInfoRequest request, // We don't maintain an in-memory cache of all actors which belong to dead // jobs, so fetch it from redis. Status status = gcs_table_storage_->ActorTable().GetAll( - [reply, send_reply_callback, limit, request = std::move(request), filter_fn]( - absl::flat_hash_map &&result) { - auto total_actors = result.size(); - - reply->set_total(total_actors); - auto arena = reply->GetArena(); - RAY_CHECK(arena != nullptr); - auto ptr = google::protobuf::Arena::Create< - absl::flat_hash_map>(arena, std::move(result)); - size_t count = 0; - size_t num_filtered = 0; - for (auto &pair : *ptr) { - if (count >= limit) { - break; - } - // With filters, skip the actor if it doesn't match the filter. - if (request.has_filters() && !filter_fn(request.filters(), pair.second)) { - ++num_filtered; - continue; - } - count += 1; - - reply->mutable_actor_table_data()->UnsafeArenaAddAllocated(&pair.second); - } - reply->set_num_filtered(num_filtered); - GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); - RAY_LOG(DEBUG) << "Finished getting all actor info."; - }); + {[reply, send_reply_callback, limit, request = std::move(request), filter_fn]( + absl::flat_hash_map &&result) { + auto total_actors = result.size(); + + reply->set_total(total_actors); + auto arena = reply->GetArena(); + RAY_CHECK(arena != nullptr); + auto ptr = google::protobuf::Arena::Create< + absl::flat_hash_map>(arena, std::move(result)); + size_t count = 0; + size_t num_filtered = 0; + for (auto &pair : *ptr) { + if (count >= limit) { + break; + } + // With filters, skip the actor if it doesn't match the filter. + if (request.has_filters() && !filter_fn(request.filters(), pair.second)) { + ++num_filtered; + continue; + } + count += 1; + + reply->mutable_actor_table_data()->UnsafeArenaAddAllocated(&pair.second); + } + reply->set_num_filtered(num_filtered); + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + RAY_LOG(DEBUG) << "Finished getting all actor info."; + }, + io_context_}); if (!status.ok()) { // Send the response to unblock the sender and free the request. GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); @@ -793,46 +794,49 @@ Status GcsActorManager::RegisterActor(const ray::rpc::RegisterActorRequest &requ RAY_CHECK_OK(gcs_table_storage_->ActorTaskSpecTable().Put( actor_id, request.task_spec(), - [this, actor, register_callback](const Status &status) { - RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put( - actor->GetActorID(), - *actor->GetMutableActorTableData(), - [this, actor, register_callback](const Status &status) { - // The backend storage is supposed to be reliable, so the status must be ok. - RAY_CHECK_OK(status); - actor->WriteActorExportEvent(); - auto registered_actor_it = registered_actors_.find(actor->GetActorID()); - auto reply_status = Status::OK(); - if (registered_actor_it == registered_actors_.end()) { - // NOTE(sang): This logic assumes that the ordering of backend call is - // guaranteed. It is currently true because we use a single TCP socket to - // call the default Redis backend. If ordering is not guaranteed, we - // should overwrite the actor state to DEAD to avoid race condition. - RAY_LOG(INFO).WithField(actor->GetActorID()) - << "Actor is killed before dependency is prepared."; - RAY_CHECK(actor_to_register_callbacks_.find(actor->GetActorID()) == - actor_to_register_callbacks_.end()); - register_callback( - actor, Status::SchedulingCancelled("Actor creation cancelled.")); - return; - } - - RAY_CHECK_OK(gcs_publisher_->PublishActor( - actor->GetActorID(), actor->GetActorTableData(), nullptr)); - // Invoke all callbacks for all registration requests of this actor - // (duplicated requests are included) and remove all of them from - // actor_to_register_callbacks_. - // Reply to the owner to indicate that the actor has been registered. - auto iter = actor_to_register_callbacks_.find(actor->GetActorID()); - RAY_CHECK(iter != actor_to_register_callbacks_.end() && - !iter->second.empty()); - auto callbacks = std::move(iter->second); - actor_to_register_callbacks_.erase(iter); - for (auto &callback : callbacks) { - callback(actor, Status::OK()); - } - })); - })); + {[this, actor, register_callback](const Status &status) { + RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put( + actor->GetActorID(), + *actor->GetMutableActorTableData(), + {[this, actor, register_callback](const Status &status) { + // The backend storage is supposed to be reliable, so the status must be + // ok. + RAY_CHECK_OK(status); + actor->WriteActorExportEvent(); + auto registered_actor_it = registered_actors_.find(actor->GetActorID()); + auto reply_status = Status::OK(); + if (registered_actor_it == registered_actors_.end()) { + // NOTE(sang): This logic assumes that the ordering of backend call is + // guaranteed. It is currently true because we use a single TCP socket + // to call the default Redis backend. If ordering is not guaranteed, we + // should overwrite the actor state to DEAD to avoid race condition. + RAY_LOG(INFO).WithField(actor->GetActorID()) + << "Actor is killed before dependency is prepared."; + RAY_CHECK(actor_to_register_callbacks_.find(actor->GetActorID()) == + actor_to_register_callbacks_.end()); + register_callback( + actor, Status::SchedulingCancelled("Actor creation cancelled.")); + return; + } + + RAY_CHECK_OK(gcs_publisher_->PublishActor( + actor->GetActorID(), actor->GetActorTableData(), nullptr)); + // Invoke all callbacks for all registration requests of this actor + // (duplicated requests are included) and remove all of them from + // actor_to_register_callbacks_. + // Reply to the owner to indicate that the actor has been registered. + auto iter = actor_to_register_callbacks_.find(actor->GetActorID()); + RAY_CHECK(iter != actor_to_register_callbacks_.end() && + !iter->second.empty()); + auto callbacks = std::move(iter->second); + actor_to_register_callbacks_.erase(iter); + for (auto &callback : callbacks) { + callback(actor, Status::OK()); + } + }, + io_context_})); + }, + io_context_})); return Status::OK(); } @@ -1124,8 +1128,8 @@ void GcsActorManager::DestroyActor(const ActorID &actor_id, RAY_CHECK_OK(gcs_publisher_->PublishActor( actor_id, GenActorDataOnlyWithStates(*actor_table_data), nullptr)); if (!is_restartable) { - RAY_CHECK_OK( - gcs_table_storage_->ActorTaskSpecTable().Delete(actor_id, nullptr)); + RAY_CHECK_OK(gcs_table_storage_->ActorTaskSpecTable().Delete( + actor_id, {[](Status) {}, io_context_})); } actor->WriteActorExportEvent(); // Destroy placement group owned by this actor. @@ -1354,10 +1358,13 @@ void GcsActorManager::SetPreemptedAndPublish(const NodeID &node_id) { const auto &actor_table_data = actor_iter->second->GetActorTableData(); RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put( - actor_id, actor_table_data, [this, actor_id, actor_table_data](Status status) { - RAY_CHECK_OK(gcs_publisher_->PublishActor( - actor_id, GenActorDataOnlyWithStates(actor_table_data), nullptr)); - })); + actor_id, + actor_table_data, + {[this, actor_id, actor_table_data](Status status) { + RAY_CHECK_OK(gcs_publisher_->PublishActor( + actor_id, GenActorDataOnlyWithStates(actor_table_data), nullptr)); + }, + io_context_})); } } @@ -1416,14 +1423,15 @@ void GcsActorManager::RestartActor(const ActorID &actor_id, RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put( actor_id, *mutable_actor_table_data, - [this, actor, actor_id, mutable_actor_table_data, done_callback](Status status) { - if (done_callback) { - done_callback(); - } - RAY_CHECK_OK(gcs_publisher_->PublishActor( - actor_id, GenActorDataOnlyWithStates(*mutable_actor_table_data), nullptr)); - actor->WriteActorExportEvent(); - })); + {[this, actor, actor_id, mutable_actor_table_data, done_callback](Status status) { + if (done_callback) { + done_callback(); + } + RAY_CHECK_OK(gcs_publisher_->PublishActor( + actor_id, GenActorDataOnlyWithStates(*mutable_actor_table_data), nullptr)); + actor->WriteActorExportEvent(); + }, + io_context_})); gcs_actor_scheduler_->Schedule(actor); } else { RemoveActorNameFromRegistry(actor); @@ -1437,23 +1445,24 @@ void GcsActorManager::RestartActor(const ActorID &actor_id, RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put( actor_id, *mutable_actor_table_data, - [this, actor, actor_id, mutable_actor_table_data, death_cause, done_callback]( - Status status) { - // If actor was an detached actor, make sure to destroy it. - // We need to do this because detached actors are not destroyed - // when its owners are dead because it doesn't have owners. - if (actor->IsDetached()) { - DestroyActor(actor_id, death_cause); - } - if (done_callback) { - done_callback(); - } - RAY_CHECK_OK(gcs_publisher_->PublishActor( - actor_id, GenActorDataOnlyWithStates(*mutable_actor_table_data), nullptr)); - RAY_CHECK_OK( - gcs_table_storage_->ActorTaskSpecTable().Delete(actor_id, nullptr)); - actor->WriteActorExportEvent(); - })); + {[this, actor, actor_id, mutable_actor_table_data, death_cause, done_callback]( + Status status) { + // If actor was an detached actor, make sure to destroy it. + // We need to do this because detached actors are not destroyed + // when its owners are dead because it doesn't have owners. + if (actor->IsDetached()) { + DestroyActor(actor_id, death_cause); + } + if (done_callback) { + done_callback(); + } + RAY_CHECK_OK(gcs_publisher_->PublishActor( + actor_id, GenActorDataOnlyWithStates(*mutable_actor_table_data), nullptr)); + RAY_CHECK_OK(gcs_table_storage_->ActorTaskSpecTable().Delete( + actor_id, {[](Status) {}, io_context_})); + actor->WriteActorExportEvent(); + }, + io_context_})); // The actor is dead, but we should not remove the entry from the // registered actors yet. If the actor is owned, we will destroy the actor // once the owner fails or notifies us that the actor has no references. @@ -1557,15 +1566,16 @@ void GcsActorManager::OnActorCreationSuccess(const std::shared_ptr &ac RAY_CHECK_OK(gcs_table_storage_->ActorTable().Put( actor_id, actor_table_data, - [this, actor_id, actor_table_data, actor, reply](Status status) { - RAY_CHECK_OK(gcs_publisher_->PublishActor( - actor_id, GenActorDataOnlyWithStates(actor_table_data), nullptr)); - actor->WriteActorExportEvent(); - // Invoke all callbacks for all registration requests of this actor (duplicated - // requests are included) and remove all of them from - // actor_to_create_callbacks_. - RunAndClearActorCreationCallbacks(actor, reply, Status::OK()); - })); + {[this, actor_id, actor_table_data, actor, reply](Status status) { + RAY_CHECK_OK(gcs_publisher_->PublishActor( + actor_id, GenActorDataOnlyWithStates(actor_table_data), nullptr)); + actor->WriteActorExportEvent(); + // Invoke all callbacks for all registration requests of this actor (duplicated + // requests are included) and remove all of them from + // actor_to_create_callbacks_. + RunAndClearActorCreationCallbacks(actor, reply, Status::OK()); + }, + io_context_})); } void GcsActorManager::SchedulePendingActors() { @@ -1631,8 +1641,8 @@ void GcsActorManager::Initialize(const GcsInitData &gcs_init_data) { } } if (!dead_actors.empty()) { - RAY_CHECK_OK( - gcs_table_storage_->ActorTaskSpecTable().BatchDelete(dead_actors, nullptr)); + RAY_CHECK_OK(gcs_table_storage_->ActorTaskSpecTable().BatchDelete( + dead_actors, {[](Status) {}, io_context_})); } sorted_destroyed_actor_list_.sort([](const std::pair &left, const std::pair &right) { @@ -1774,7 +1784,8 @@ void GcsActorManager::AddDestroyedActorToCache(const std::shared_ptr & if (destroyed_actors_.size() >= RayConfig::instance().maximum_gcs_destroyed_actor_cached_count()) { const auto &actor_id = sorted_destroyed_actor_list_.front().first; - RAY_CHECK_OK(gcs_table_storage_->ActorTable().Delete(actor_id, nullptr)); + RAY_CHECK_OK( + gcs_table_storage_->ActorTable().Delete(actor_id, {[](Status) {}, io_context_})); destroyed_actors_.erase(actor_id); sorted_destroyed_actor_list_.pop_front(); } diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc index 3582152c5f66..7e167f30ba3c 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc @@ -414,15 +414,16 @@ void GcsActorScheduler::HandleWorkerLeaseGrantedReply( core_worker_clients_.GetOrConnect(leased_worker->GetAddress()); RAY_CHECK_OK(gcs_actor_table_.Put(actor->GetActorID(), actor->GetActorTableData(), - [this, actor, leased_worker](Status status) { - RAY_CHECK_OK(status); - if (actor->GetState() == - rpc::ActorTableData::DEAD) { - // Actor has already been killed. - return; - } - CreateActorOnWorker(actor, leased_worker); - })); + {[this, actor, leased_worker](Status status) { + RAY_CHECK_OK(status); + if (actor->GetState() == + rpc::ActorTableData::DEAD) { + // Actor has already been killed. + return; + } + CreateActorOnWorker(actor, leased_worker); + }, + io_context_})); } } diff --git a/src/ray/gcs/gcs_server/gcs_function_manager.h b/src/ray/gcs/gcs_server/gcs_function_manager.h index ac05c98671f1..72da87c3ab0f 100644 --- a/src/ray/gcs/gcs_server/gcs_function_manager.h +++ b/src/ray/gcs/gcs_server/gcs_function_manager.h @@ -29,7 +29,9 @@ namespace gcs { /// - function/actor code life cycle management. class GcsFunctionManager { public: - explicit GcsFunctionManager(InternalKVInterface &kv) : kv_(kv) {} + explicit GcsFunctionManager(InternalKVInterface &kv, + instrumented_io_context &io_context) + : kv_(kv), io_context_(io_context) {} void AddJobReference(const JobID &job_id) { job_counter_[job_id]++; } @@ -45,17 +47,20 @@ class GcsFunctionManager { private: void RemoveExportedFunctions(const JobID &job_id) { + Postable no_op([](int64_t) {}, io_context_); + auto job_id_hex = job_id.Hex(); - kv_.Del("fun", "RemoteFunction:" + job_id_hex + ":", true, Postable{}); - kv_.Del("fun", "ActorClass:" + job_id_hex + ":", true, Postable{}); + kv_.Del("fun", "RemoteFunction:" + job_id_hex + ":", true, no_op); + kv_.Del("fun", "ActorClass:" + job_id_hex + ":", true, no_op); kv_.Del("fun", std::string(kWorkerSetupHookKeyName) + ":" + job_id_hex + ":", true, - Postable{}); + no_op); } // Handler for internal KV InternalKVInterface &kv_; + instrumented_io_context &io_context_; // Counter to check whether the job has finished or not. // A job is defined to be in finished status if diff --git a/src/ray/gcs/gcs_server/gcs_init_data.cc b/src/ray/gcs/gcs_server/gcs_init_data.cc index 60aceeb4178c..e5be9ab6175f 100644 --- a/src/ray/gcs/gcs_server/gcs_init_data.cc +++ b/src/ray/gcs/gcs_server/gcs_init_data.cc @@ -16,7 +16,8 @@ namespace ray { namespace gcs { -void GcsInitData::AsyncLoad(const EmptyCallback &on_done) { +void GcsInitData::AsyncLoad(const EmptyCallback &on_done, + instrumented_io_context &io_context) { // There are 5 kinds of table data need to be loaded. auto count_down = std::make_shared(5); auto on_load_finished = [count_down, on_done] { @@ -27,80 +28,87 @@ void GcsInitData::AsyncLoad(const EmptyCallback &on_done) { } }; - AsyncLoadJobTableData(on_load_finished); + AsyncLoadJobTableData(on_load_finished, io_context); - AsyncLoadNodeTableData(on_load_finished); + AsyncLoadNodeTableData(on_load_finished, io_context); - AsyncLoadActorTableData(on_load_finished); + AsyncLoadActorTableData(on_load_finished, io_context); - AsyncLoadActorTaskSpecTableData(on_load_finished); + AsyncLoadActorTaskSpecTableData(on_load_finished, io_context); - AsyncLoadPlacementGroupTableData(on_load_finished); + AsyncLoadPlacementGroupTableData(on_load_finished, io_context); } -void GcsInitData::AsyncLoadJobTableData(const EmptyCallback &on_done) { +void GcsInitData::AsyncLoadJobTableData(const EmptyCallback &on_done, + instrumented_io_context &io_context) { RAY_LOG(INFO) << "Loading job table data."; auto load_job_table_data_callback = - [this, on_done](absl::flat_hash_map &&result) { + [this, on_done](absl::flat_hash_map result) { job_table_data_ = std::move(result); RAY_LOG(INFO) << "Finished loading job table data, size = " << job_table_data_.size(); on_done(); }; - RAY_CHECK_OK(gcs_table_storage_->JobTable().GetAll(load_job_table_data_callback)); + RAY_CHECK_OK( + gcs_table_storage_->JobTable().GetAll({load_job_table_data_callback, io_context})); } -void GcsInitData::AsyncLoadNodeTableData(const EmptyCallback &on_done) { +void GcsInitData::AsyncLoadNodeTableData(const EmptyCallback &on_done, + instrumented_io_context &io_context) { RAY_LOG(INFO) << "Loading node table data."; auto load_node_table_data_callback = - [this, on_done](absl::flat_hash_map &&result) { + [this, on_done](absl::flat_hash_map result) { node_table_data_ = std::move(result); RAY_LOG(INFO) << "Finished loading node table data, size = " << node_table_data_.size(); on_done(); }; - RAY_CHECK_OK(gcs_table_storage_->NodeTable().GetAll(load_node_table_data_callback)); + RAY_CHECK_OK(gcs_table_storage_->NodeTable().GetAll( + {load_node_table_data_callback, io_context})); } -void GcsInitData::AsyncLoadPlacementGroupTableData(const EmptyCallback &on_done) { +void GcsInitData::AsyncLoadPlacementGroupTableData(const EmptyCallback &on_done, + instrumented_io_context &io_context) { RAY_LOG(INFO) << "Loading placement group table data."; auto load_placement_group_table_data_callback = [this, on_done]( - absl::flat_hash_map &&result) { + absl::flat_hash_map result) { placement_group_table_data_ = std::move(result); RAY_LOG(INFO) << "Finished loading placement group table data, size = " << placement_group_table_data_.size(); on_done(); }; RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().GetAll( - load_placement_group_table_data_callback)); + {load_placement_group_table_data_callback, io_context})); } -void GcsInitData::AsyncLoadActorTableData(const EmptyCallback &on_done) { +void GcsInitData::AsyncLoadActorTableData(const EmptyCallback &on_done, + instrumented_io_context &io_context) { RAY_LOG(INFO) << "Loading actor table data."; auto load_actor_table_data_callback = - [this, on_done](absl::flat_hash_map &&result) { + [this, on_done](absl::flat_hash_map result) { actor_table_data_ = std::move(result); RAY_LOG(INFO) << "Finished loading actor table data, size = " << actor_table_data_.size(); on_done(); }; RAY_CHECK_OK(gcs_table_storage_->ActorTable().AsyncRebuildIndexAndGetAll( - load_actor_table_data_callback)); + {load_actor_table_data_callback, io_context})); } -void GcsInitData::AsyncLoadActorTaskSpecTableData(const EmptyCallback &on_done) { +void GcsInitData::AsyncLoadActorTaskSpecTableData(const EmptyCallback &on_done, + instrumented_io_context &io_context) { RAY_LOG(INFO) << "Loading actor task spec table data."; auto load_actor_task_spec_table_data_callback = - [this, on_done](const absl::flat_hash_map &result) { + [this, on_done](absl::flat_hash_map result) { actor_task_spec_table_data_ = std::move(result); RAY_LOG(INFO) << "Finished loading actor task spec table data, size = " << actor_task_spec_table_data_.size(); on_done(); }; RAY_CHECK_OK(gcs_table_storage_->ActorTaskSpecTable().GetAll( - load_actor_task_spec_table_data_callback)); + {load_actor_task_spec_table_data_callback, io_context})); } } // namespace gcs -} // namespace ray \ No newline at end of file +} // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_init_data.h b/src/ray/gcs/gcs_server/gcs_init_data.h index 7ee0df9da447..e1cb20b0bdbe 100644 --- a/src/ray/gcs/gcs_server/gcs_init_data.h +++ b/src/ray/gcs/gcs_server/gcs_init_data.h @@ -37,7 +37,7 @@ class GcsInitData { /// Load all required metadata from the store into memory at once asynchronously. /// /// \param on_done The callback when all metadatas are loaded successfully. - void AsyncLoad(const EmptyCallback &on_done); + void AsyncLoad(const EmptyCallback &on_done, instrumented_io_context &io_context); /// Get job metadata. const absl::flat_hash_map &Jobs() const { @@ -68,24 +68,29 @@ class GcsInitData { /// Load job metadata from the store into memory asynchronously. /// /// \param on_done The callback when job metadata is loaded successfully. - void AsyncLoadJobTableData(const EmptyCallback &on_done); + void AsyncLoadJobTableData(const EmptyCallback &on_done, + instrumented_io_context &io_context); /// Load node metadata from the store into memory asynchronously. /// /// \param on_done The callback when node metadata is loaded successfully. - void AsyncLoadNodeTableData(const EmptyCallback &on_done); + void AsyncLoadNodeTableData(const EmptyCallback &on_done, + instrumented_io_context &io_context); /// Load placement group metadata from the store into memory asynchronously. /// /// \param on_done The callback when placement group metadata is loaded successfully. - void AsyncLoadPlacementGroupTableData(const EmptyCallback &on_done); + void AsyncLoadPlacementGroupTableData(const EmptyCallback &on_done, + instrumented_io_context &io_context); /// Load actor metadata from the store into memory asynchronously. /// /// \param on_done The callback when actor metadata is loaded successfully. - void AsyncLoadActorTableData(const EmptyCallback &on_done); + void AsyncLoadActorTableData(const EmptyCallback &on_done, + instrumented_io_context &io_context); - void AsyncLoadActorTaskSpecTableData(const EmptyCallback &on_done); + void AsyncLoadActorTaskSpecTableData(const EmptyCallback &on_done, + instrumented_io_context &io_context); protected: /// The gcs table storage. diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index 32b34c4b9745..fbeb23bd9331 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -121,8 +121,8 @@ void GcsJobManager::HandleAddJob(rpc::AddJobRequest request, GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - Status status = - gcs_table_storage_->JobTable().Put(job_id, mutable_job_table_data, on_done); + Status status = gcs_table_storage_->JobTable().Put( + job_id, mutable_job_table_data, {on_done, io_context_}); if (!status.ok()) { on_done(status); } @@ -160,7 +160,8 @@ void GcsJobManager::MarkJobAsFinished(rpc::JobTableData job_table_data, done_callback(status); }; - Status status = gcs_table_storage_->JobTable().Put(job_id, job_table_data, on_done); + Status status = + gcs_table_storage_->JobTable().Put(job_id, job_table_data, {on_done, io_context_}); if (!status.ok()) { on_done(status); } @@ -178,24 +179,25 @@ void GcsJobManager::HandleMarkJobFinished(rpc::MarkJobFinishedRequest request, Status status = gcs_table_storage_->JobTable().Get( job_id, - [this, job_id, send_reply](const Status &status, - const std::optional &result) { - RAY_CHECK(thread_checker_.IsOnSameThread()); - - if (status.ok() && result) { - MarkJobAsFinished(*result, send_reply); - return; - } - - if (!result.has_value()) { - RAY_LOG(ERROR) << "Tried to mark job " << job_id - << " as finished, but there was no record of it starting!"; - } else if (!status.ok()) { - RAY_LOG(ERROR) << "Fails to mark job " << job_id << " as finished due to " - << status; - } - send_reply(status); - }); + {[this, job_id, send_reply](const Status &status, + const std::optional &result) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + + if (status.ok() && result) { + MarkJobAsFinished(*result, send_reply); + return; + } + + if (!result.has_value()) { + RAY_LOG(ERROR) << "Tried to mark job " << job_id + << " as finished, but there was no record of it starting!"; + } else if (!status.ok()) { + RAY_LOG(ERROR) << "Fails to mark job " << job_id << " as finished due to " + << status; + } + send_reply(status); + }, + io_context_}); if (!status.ok()) { send_reply(status); } @@ -424,7 +426,7 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, "job", job_api_data_keys, {std::move(kv_multi_get_callback), io_context_}); } }; - Status status = gcs_table_storage_->JobTable().GetAll(on_done); + Status status = gcs_table_storage_->JobTable().GetAll({on_done, io_context_}); if (!status.ok()) { on_done(absl::flat_hash_map()); } @@ -473,7 +475,7 @@ void GcsJobManager::OnNodeDead(const NodeID &node_id) { }; // make all jobs in current node to finished - RAY_CHECK_OK(gcs_table_storage_->JobTable().GetAll(on_done)); + RAY_CHECK_OK(gcs_table_storage_->JobTable().GetAll({on_done, io_context_})); } void GcsJobManager::RecordMetrics() { diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index 71993125f913..c6b86c72becb 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -33,10 +33,12 @@ GcsNodeManager::GcsNodeManager( std::shared_ptr gcs_publisher, std::shared_ptr gcs_table_storage, std::shared_ptr raylet_client_pool, + instrumented_io_context &io_context, const ClusterID &cluster_id) : gcs_publisher_(std::move(gcs_publisher)), gcs_table_storage_(std::move(gcs_table_storage)), raylet_client_pool_(std::move(raylet_client_pool)), + io_context_(io_context), cluster_id_(cluster_id) {} void GcsNodeManager::WriteNodeExportEvent(rpc::GcsNodeInfo node_info) const { @@ -115,15 +117,15 @@ void GcsNodeManager::HandleRegisterNode(rpc::RegisterNodeRequest request, [this, request, on_done, node_id](const Status &status) { RAY_CHECK_OK(status); RAY_CHECK_OK(gcs_table_storage_->NodeTable().Put( - node_id, request.node_info(), on_done)); + node_id, request.node_info(), {on_done, io_context_})); }); } else { - RAY_CHECK_OK( - gcs_table_storage_->NodeTable().Put(node_id, request.node_info(), on_done)); + RAY_CHECK_OK(gcs_table_storage_->NodeTable().Put( + node_id, request.node_info(), {on_done, io_context_})); } } else { - RAY_CHECK_OK( - gcs_table_storage_->NodeTable().Put(node_id, request.node_info(), on_done)); + RAY_CHECK_OK(gcs_table_storage_->NodeTable().Put( + node_id, request.node_info(), {on_done, io_context_})); } ++counts_[CountType::REGISTER_NODE_REQUEST]; } @@ -166,7 +168,8 @@ void GcsNodeManager::HandleUnregisterNode(rpc::UnregisterNodeRequest request, RAY_CHECK_OK(gcs_publisher_->PublishNodeInfo(node_id, *node_info_delta, nullptr)); WriteNodeExportEvent(*node); }; - RAY_CHECK_OK(gcs_table_storage_->NodeTable().Put(node_id, *node, on_put_done)); + RAY_CHECK_OK( + gcs_table_storage_->NodeTable().Put(node_id, *node, {on_put_done, io_context_})); GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); } @@ -429,7 +432,8 @@ void GcsNodeManager::OnNodeFailure(const NodeID &node_id, } RAY_CHECK_OK(gcs_publisher_->PublishNodeInfo(node_id, *node_info_delta, nullptr)); }; - RAY_CHECK_OK(gcs_table_storage_->NodeTable().Put(node_id, *node, on_done)); + RAY_CHECK_OK( + gcs_table_storage_->NodeTable().Put(node_id, *node, {on_done, io_context_})); } else if (node_table_updated_callback != nullptr) { node_table_updated_callback(Status::OK()); } @@ -466,7 +470,8 @@ void GcsNodeManager::Initialize(const GcsInitData &gcs_init_data) { void GcsNodeManager::AddDeadNodeToCache(std::shared_ptr node) { if (dead_nodes_.size() >= RayConfig::instance().maximum_gcs_dead_node_cached_count()) { const auto &node_id = sorted_dead_node_list_.begin()->first; - RAY_CHECK_OK(gcs_table_storage_->NodeTable().Delete(node_id, nullptr)); + RAY_CHECK_OK( + gcs_table_storage_->NodeTable().Delete(node_id, {[](Status) {}, io_context_})); dead_nodes_.erase(sorted_dead_node_list_.begin()->first); sorted_dead_node_list_.erase(sorted_dead_node_list_.begin()); } diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.h b/src/ray/gcs/gcs_server/gcs_node_manager.h index db258d4cb00c..912bde44b9b7 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/ray/gcs/gcs_server/gcs_node_manager.h @@ -51,6 +51,7 @@ class GcsNodeManager : public rpc::NodeInfoHandler { explicit GcsNodeManager(std::shared_ptr gcs_publisher, std::shared_ptr gcs_table_storage, std::shared_ptr raylet_client_pool, + instrumented_io_context &io_context, const ClusterID &cluster_id); /// Handle register rpc request come from raylet. @@ -249,6 +250,8 @@ class GcsNodeManager : public rpc::NodeInfoHandler { std::shared_ptr gcs_table_storage_; /// Raylet client pool. std::shared_ptr raylet_client_pool_; + /// IO context. + instrumented_io_context &io_context_; /// Cluster ID to be shared with clients when connecting. const ClusterID cluster_id_; diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc index c60bcd43cc45..659d974d3a30 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_manager.cc @@ -261,31 +261,33 @@ void GcsPlacementGroupManager::RegisterPlacementGroup( RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( placement_group_id, placement_group->GetPlacementGroupTableData(), - [this, placement_group_id, placement_group](Status status) { - // The backend storage is supposed to be reliable, so the status must be ok. - RAY_CHECK_OK(status); - if (registered_placement_groups_.contains(placement_group_id)) { - auto iter = placement_group_to_register_callbacks_.find(placement_group_id); - auto callbacks = std::move(iter->second); - placement_group_to_register_callbacks_.erase(iter); - for (const auto &callback : callbacks) { - callback(status); - } - SchedulePendingPlacementGroups(); - } else { - // The placement group registration is synchronous, so if we found the placement - // group was deleted here, it must be triggered by the abnormal exit of job, - // we will return directly in this case. - RAY_CHECK(placement_group_to_register_callbacks_.count(placement_group_id) == 0) - << "The placement group has been removed unexpectedly with an unknown " - "error. Please file a bug report on here: " - "https://github.com/ray-project/ray/issues"; - RAY_LOG(WARNING) << "Failed to create placement group '" - << placement_group->GetPlacementGroupID() - << "', because the placement group has been removed by GCS."; - return; - } - })); + {[this, placement_group_id, placement_group](Status status) { + // The backend storage is supposed to be reliable, so the status must be ok. + RAY_CHECK_OK(status); + if (registered_placement_groups_.contains(placement_group_id)) { + auto iter = placement_group_to_register_callbacks_.find(placement_group_id); + auto callbacks = std::move(iter->second); + placement_group_to_register_callbacks_.erase(iter); + for (const auto &callback : callbacks) { + callback(status); + } + SchedulePendingPlacementGroups(); + } else { + // The placement group registration is synchronous, so if we found the + // placement group was deleted here, it must be triggered by the abnormal exit + // of job, we will return directly in this case. + RAY_CHECK(placement_group_to_register_callbacks_.count(placement_group_id) == + 0) + << "The placement group has been removed unexpectedly with an unknown " + "error. Please file a bug report on here: " + "https://github.com/ray-project/ray/issues"; + RAY_LOG(WARNING) << "Failed to create placement group '" + << placement_group->GetPlacementGroupID() + << "', because the placement group has been removed by GCS."; + return; + } + }, + io_context_})); } PlacementGroupID GcsPlacementGroupManager::GetPlacementGroupIDByName( @@ -370,27 +372,28 @@ void GcsPlacementGroupManager::OnPlacementGroupCreationSuccess( RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( placement_group_id, placement_group->GetPlacementGroupTableData(), - [this, placement_group_id](Status status) { - RAY_CHECK_OK(status); - - if (RescheduleIfStillHasUnplacedBundles(placement_group_id)) { - // If all the bundles are not created yet, don't complete - // the creation and invoke a callback. - // The call back will be called when all bundles are created. - return; - } - // Invoke all callbacks for all `WaitPlacementGroupUntilReady` requests of this - // placement group and remove all of them from - // placement_group_to_create_callbacks_. - auto pg_to_create_iter = - placement_group_to_create_callbacks_.find(placement_group_id); - if (pg_to_create_iter != placement_group_to_create_callbacks_.end()) { - for (auto &callback : pg_to_create_iter->second) { - callback(status); - } - placement_group_to_create_callbacks_.erase(pg_to_create_iter); - } - })); + {[this, placement_group_id](Status status) { + RAY_CHECK_OK(status); + + if (RescheduleIfStillHasUnplacedBundles(placement_group_id)) { + // If all the bundles are not created yet, don't complete + // the creation and invoke a callback. + // The call back will be called when all bundles are created. + return; + } + // Invoke all callbacks for all `WaitPlacementGroupUntilReady` requests of this + // placement group and remove all of them from + // placement_group_to_create_callbacks_. + auto pg_to_create_iter = + placement_group_to_create_callbacks_.find(placement_group_id); + if (pg_to_create_iter != placement_group_to_create_callbacks_.end()) { + for (auto &callback : pg_to_create_iter->second) { + callback(status); + } + placement_group_to_create_callbacks_.erase(pg_to_create_iter); + } + }, + io_context_})); lifetime_num_placement_groups_created_++; io_context_.post([this] { SchedulePendingPlacementGroups(); }, "GcsPlacementGroupManager.SchedulePendingPlacementGroups"); @@ -558,20 +561,21 @@ void GcsPlacementGroupManager::RemovePlacementGroup( RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( placement_group->GetPlacementGroupID(), placement_group->GetPlacementGroupTableData(), - [this, on_placement_group_removed, placement_group_id](Status status) { - RAY_CHECK_OK(status); - // If there is a driver waiting for the creation done, then send a message that - // the placement group has been removed. - auto it = placement_group_to_create_callbacks_.find(placement_group_id); - if (it != placement_group_to_create_callbacks_.end()) { - for (auto &callback : it->second) { - callback( - Status::NotFound("Placement group is removed before it is created.")); - } - placement_group_to_create_callbacks_.erase(it); - } - on_placement_group_removed(status); - })); + {[this, on_placement_group_removed, placement_group_id](Status status) { + RAY_CHECK_OK(status); + // If there is a driver waiting for the creation done, then send a message that + // the placement group has been removed. + auto it = placement_group_to_create_callbacks_.find(placement_group_id); + if (it != placement_group_to_create_callbacks_.end()) { + for (auto &callback : it->second) { + callback( + Status::NotFound("Placement group is removed before it is created.")); + } + placement_group_to_create_callbacks_.erase(it); + } + on_placement_group_removed(status); + }, + io_context_})); } void GcsPlacementGroupManager::HandleGetPlacementGroup( @@ -598,8 +602,8 @@ void GcsPlacementGroupManager::HandleGetPlacementGroup( if (it != registered_placement_groups_.end()) { on_done(Status::OK(), it->second->GetPlacementGroupTableData()); } else { - Status status = - gcs_table_storage_->PlacementGroupTable().Get(placement_group_id, on_done); + Status status = gcs_table_storage_->PlacementGroupTable().Get(placement_group_id, + {on_done, io_context_}); if (!status.ok()) { on_done(status, std::nullopt); } @@ -669,7 +673,8 @@ void GcsPlacementGroupManager::HandleGetAllPlacementGroup( RAY_LOG(DEBUG) << "Finished getting all placement group info."; GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); }; - Status status = gcs_table_storage_->PlacementGroupTable().GetAll(on_done); + Status status = + gcs_table_storage_->PlacementGroupTable().GetAll({on_done, io_context_}); if (!status.ok()) { on_done(absl::flat_hash_map()); } @@ -729,8 +734,8 @@ void GcsPlacementGroupManager::WaitPlacementGroup( } }; - Status status = - gcs_table_storage_->PlacementGroupTable().Get(placement_group_id, on_done); + Status status = gcs_table_storage_->PlacementGroupTable().Get(placement_group_id, + {on_done, io_context_}); if (!status.ok()) { on_done(status, std::nullopt); } @@ -809,7 +814,7 @@ void GcsPlacementGroupManager::OnNodeDead(const NodeID &node_id) { RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( iter->second->GetPlacementGroupID(), iter->second->GetPlacementGroupTableData(), - [this](Status status) { SchedulePendingPlacementGroups(); })); + {[this](Status) { SchedulePendingPlacementGroups(); }, io_context_})); } } } @@ -1118,7 +1123,7 @@ bool GcsPlacementGroupManager::RescheduleIfStillHasUnplacedBundles( RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( placement_group->GetPlacementGroupID(), placement_group->GetPlacementGroupTableData(), - [this](Status status) { SchedulePendingPlacementGroups(); })); + {[this](Status status) { SchedulePendingPlacementGroups(); }, io_context_})); return true; } } diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc index 6bc2737c14a6..9cbf41f74ea9 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc @@ -398,12 +398,13 @@ void GcsPlacementGroupScheduler::OnAllBundlePrepareRequestReturned( RAY_CHECK_OK(gcs_table_storage_->PlacementGroupTable().Put( placement_group_id, placement_group->GetPlacementGroupTableData(), - [this, lease_status_tracker, schedule_failure_handler, schedule_success_handler]( - Status status) { - RAY_CHECK_OK(status); - CommitAllBundles( - lease_status_tracker, schedule_failure_handler, schedule_success_handler); - })); + {[this, lease_status_tracker, schedule_failure_handler, schedule_success_handler]( + Status status) { + RAY_CHECK_OK(status); + CommitAllBundles( + lease_status_tracker, schedule_failure_handler, schedule_success_handler); + }, + io_context_})); } void GcsPlacementGroupScheduler::OnAllBundleCommitRequestReturned( diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 5ef57e905f95..3e8b53552aca 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -102,7 +102,9 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, ray::rpc::StoredConfig stored_config; stored_config.set_config(config_.raylet_config_list); RAY_CHECK_OK(gcs_table_storage_->InternalConfigTable().Put( - ray::UniqueID::Nil(), stored_config, on_done)); + ray::UniqueID::Nil(), + stored_config, + {on_done, io_context_provider_.GetDefaultIOContext()})); // Here we need to make sure the Put of internal config is happening in sync // way. But since the storage API is async, we need to run the default io context // to block current thread. @@ -152,12 +154,14 @@ void GcsServer::Start() { // Init KV Manager. This needs to be initialized first here so that // it can be used to retrieve the cluster ID. InitKVManager(); - gcs_init_data->AsyncLoad([this, gcs_init_data] { - GetOrGenerateClusterId([this, gcs_init_data](ClusterID cluster_id) { - rpc_server_.SetClusterId(cluster_id); - DoStart(*gcs_init_data); - }); - }); + gcs_init_data->AsyncLoad( + [this, gcs_init_data] { + GetOrGenerateClusterId([this, gcs_init_data](ClusterID cluster_id) { + rpc_server_.SetClusterId(cluster_id); + DoStart(*gcs_init_data); + }); + }, + io_context_provider_.GetDefaultIOContext()); } void GcsServer::GetOrGenerateClusterId( @@ -301,10 +305,12 @@ void GcsServer::Stop() { void GcsServer::InitGcsNodeManager(const GcsInitData &gcs_init_data) { RAY_CHECK(gcs_table_storage_ && gcs_publisher_); - gcs_node_manager_ = std::make_unique(gcs_publisher_, - gcs_table_storage_, - raylet_client_pool_, - rpc_server_.GetClusterId()); + gcs_node_manager_ = + std::make_unique(gcs_publisher_, + gcs_table_storage_, + raylet_client_pool_, + io_context_provider_.GetDefaultIOContext(), + rpc_server_.GetClusterId()); // Initialize by gcs tables data. gcs_node_manager_->Initialize(gcs_init_data); // Register service. @@ -489,6 +495,7 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { [this](const ActorID &actor_id) { gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(actor_id); }, + io_context_provider_.GetDefaultIOContext(), [this](const rpc::Address &address) { return std::make_shared(address, client_call_manager_); }); @@ -556,7 +563,8 @@ void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) { } void GcsServer::InitFunctionManager() { - function_manager_ = std::make_unique(kv_manager_->GetInstance()); + function_manager_ = std::make_unique( + kv_manager_->GetInstance(), io_context_provider_.GetDefaultIOContext()); } void GcsServer::InitUsageStatsClient() { @@ -579,9 +587,8 @@ void GcsServer::InitKVManager() { std::make_unique(CreateRedisClient(io_context))); break; case (StorageType::IN_MEMORY): - instance = - std::make_unique(std::make_unique( - std::make_unique(io_context))); + instance = std::make_unique( + std::make_unique(std::make_unique())); break; default: RAY_LOG(FATAL) << "Unexpected storage type! " << storage_type_; @@ -650,8 +657,8 @@ void GcsServer::InitRuntimeEnvManager() { } void GcsServer::InitGcsWorkerManager() { - gcs_worker_manager_ = - std::make_unique(gcs_table_storage_, gcs_publisher_); + gcs_worker_manager_ = std::make_unique( + gcs_table_storage_, gcs_publisher_, io_context_provider_.GetDefaultIOContext()); // Register service. worker_info_service_.reset(new rpc::WorkerInfoGrpcService( io_context_provider_.GetDefaultIOContext(), *gcs_worker_manager_)); diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.cc b/src/ray/gcs/gcs_server/gcs_table_storage.cc index 7e78923581f6..3e8ef7f3ccbf 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.cc +++ b/src/ray/gcs/gcs_server/gcs_table_storage.cc @@ -37,29 +37,26 @@ template Status GcsTable::Get(const Key &key, ToPostable> callback) { return store_client_->AsyncGet( - table_name_, - key.Binary(), - std::move(callback).Rebind &&)>( - [](auto cb) { - return - [cb = std::move(cb)](Status status, std::optional &&result) { - std::optional value; - if (result) { - Data data; - data.ParseFromString(*result); - value = std::move(data); - } - cb(status, std::move(value)); - }; - })); + table_name_, key.Binary(), std::move(callback).Rebind([](auto cb) { + return [cb = std::move(cb)](Status status, std::optional &&result) { + std::optional value; + if (result) { + Data data; + data.ParseFromString(*result); + value = std::move(data); + } + cb(status, std::move(value)); + }; + })); } template -Status GcsTable::GetAll(ToPostable> callback) { +Status GcsTable::GetAll( + Postable)> callback) { return store_client_->AsyncGetAll( table_name_, - std::move(callback).TransformArg &&>( - [](absl::flat_hash_map &&result) { + std::move(callback).TransformArg( + [](absl::flat_hash_map result) { absl::flat_hash_map values; values.reserve(result.size()); for (auto &item : result) { @@ -67,14 +64,14 @@ Status GcsTable::GetAll(ToPostable> callback) values[Key::FromBinary(item.first)].ParseFromString(item.second); } } - return std::move(values); + return values; })); } template Status GcsTable::Delete(const Key &key, Postable callback) { return store_client_->AsyncDelete( - table_name_, key.Binary(), std::move(callback).TransformArg([](bool) { + table_name_, key.Binary(), std::move(callback).TransformArg([](bool) { return Status::OK(); })); } @@ -88,9 +85,9 @@ Status GcsTable::BatchDelete(const std::vector &keys, keys_to_delete.emplace_back(std::move(key.Binary())); } return this->store_client_->AsyncBatchDelete( - this->table_name_, - keys_to_delete, - std::move(callback).TransformArg([](int64_t) { return Status::OK(); })); + this->table_name_, keys_to_delete, std::move(callback).TransformArg([](int64_t) { + return Status::OK(); + })); } template @@ -106,12 +103,12 @@ Status GcsTableWithJobId::Put(const Key &key, key.Binary(), value.SerializeAsString(), /*overwrite*/ true, - std::move(callback).TransformArg([](bool) { return Status::OK(); })); + std::move(callback).TransformArg([](bool) { return Status::OK(); })); } template Status GcsTableWithJobId::GetByJobId( - const JobID &job_id, ToPostable> callback) { + const JobID &job_id, Postable)> callback) { std::vector keys; { absl::MutexLock lock(&mutex_); @@ -120,19 +117,20 @@ Status GcsTableWithJobId::GetByJobId( keys.push_back(key.Binary()); } } - auto on_done = [callback](absl::flat_hash_map &&result) { - if (!callback) { - return; - } - absl::flat_hash_map values; - for (auto &item : result) { - if (!item.second.empty()) { - values[Key::FromBinary(item.first)].ParseFromString(item.second); - } - } - callback(std::move(values)); - }; - return this->store_client_->AsyncMultiGet(this->table_name_, keys, on_done); + + return this->store_client_->AsyncMultiGet( + this->table_name_, + keys, + std::move(callback).TransformArg( + [](absl::flat_hash_map result) { + absl::flat_hash_map values; + for (auto &item : result) { + if (!item.second.empty()) { + values[Key::FromBinary(item.first)].ParseFromString(item.second); + } + } + return std::move(values); + })); } template @@ -166,7 +164,7 @@ Status GcsTableWithJobId::BatchDelete(const std::vector &keys, return this->store_client_->AsyncBatchDelete( this->table_name_, keys_to_delete, - std::move(callback).TransformArg([this, keys](auto) { + std::move(callback).TransformArg([this, keys](int64_t) { { absl::MutexLock lock(&mutex_); for (auto &key : keys) { @@ -179,19 +177,17 @@ Status GcsTableWithJobId::BatchDelete(const std::vector &keys, template Status GcsTableWithJobId::AsyncRebuildIndexAndGetAll( - ToPostable> callback) { - return this->GetAll([this, callback](absl::flat_hash_map &&result) mutable { - absl::MutexLock lock(&mutex_); - index_.clear(); - for (auto &item : result) { - auto key = item.first; - index_[GetJobIdFromKey(key)].insert(key); - } - if (!callback) { - return; - } - callback(std::move(result)); - }); + Postable)> callback) { + return this->GetAll( + std::move(callback).TransformArg([this](absl::flat_hash_map result) { + absl::MutexLock lock(&mutex_); + index_.clear(); + for (auto &item : result) { + auto key = item.first; + index_[GetJobIdFromKey(key)].insert(key); + } + return result; + })); } template class GcsTable; diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.h b/src/ray/gcs/gcs_server/gcs_table_storage.h index 0eab3b061f13..43f50f62ad7d 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.h +++ b/src/ray/gcs/gcs_server/gcs_table_storage.h @@ -69,7 +69,7 @@ class GcsTable { /// /// \param callback Callback that will be called after data has been received. /// \return Status - Status GetAll(ToPostable> callback); + Status GetAll(Postable)> callback); /// Delete data from the table asynchronously. /// @@ -120,7 +120,8 @@ class GcsTableWithJobId : public GcsTable { /// \param job_id The key to lookup from the table. /// \param callback Callback that will be called after read finishes. /// \return Status - Status GetByJobId(const JobID &job_id, ToPostable> callback); + Status GetByJobId(const JobID &job_id, + Postable)> callback); /// Delete all the data of the specified job id from the table asynchronously. /// @@ -145,7 +146,8 @@ class GcsTableWithJobId : public GcsTable { Postable callback) override; /// Rebuild the index during startup. - Status AsyncRebuildIndexAndGetAll(ToPostable> callback); + Status AsyncRebuildIndexAndGetAll( + Postable)> callback); protected: virtual JobID GetJobIdFromKey(const Key &key) = 0; @@ -304,8 +306,8 @@ class InMemoryGcsTableStorage : public GcsTableStorage { public: explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service) : GcsTableStorage(std::make_shared( - std::make_unique()) - ), io_service_(main_io_service) {} + std::make_unique())), + io_service_(main_io_service) {} // All methods are posted to this io_service. instrumented_io_context &io_service_; diff --git a/src/ray/gcs/gcs_server/gcs_worker_manager.cc b/src/ray/gcs/gcs_server/gcs_worker_manager.cc index 00bb62a9e35a..980b0813ea20 100644 --- a/src/ray/gcs/gcs_server/gcs_worker_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_worker_manager.cc @@ -104,7 +104,7 @@ void GcsWorkerManager::HandleReportWorkerFailure( // message, so we delete the get operation. Related issues: // https://github.com/ray-project/ray/pull/11599 Status status = gcs_table_storage_->WorkerTable().Put( - worker_id, *worker_failure_data, on_done); + worker_id, *worker_failure_data, {on_done, io_context_}); if (!status.ok()) { on_done(status); } @@ -196,7 +196,7 @@ void GcsWorkerManager::HandleGetAllWorkerInfo( RAY_LOG(DEBUG) << "Finished getting all worker info."; GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); }; - Status status = gcs_table_storage_->WorkerTable().GetAll(on_done); + Status status = gcs_table_storage_->WorkerTable().GetAll({on_done, io_context_}); if (!status.ok()) { on_done(absl::flat_hash_map()); } @@ -220,7 +220,8 @@ void GcsWorkerManager::HandleAddWorkerInfo(rpc::AddWorkerInfoRequest request, GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; - Status status = gcs_table_storage_->WorkerTable().Put(worker_id, *worker_data, on_done); + Status status = gcs_table_storage_->WorkerTable().Put( + worker_id, *worker_data, {on_done, io_context_}); if (!status.ok()) { on_done(status); } @@ -258,14 +259,15 @@ void GcsWorkerManager::HandleUpdateWorkerDebuggerPort( worker_data->CopyFrom(*result); worker_data->set_debugger_port(debugger_port); Status status = gcs_table_storage_->WorkerTable().Put( - worker_id, *worker_data, on_worker_update_done); + worker_id, *worker_data, {on_worker_update_done, io_context_}); if (!status.ok()) { GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); } } }; - Status status = gcs_table_storage_->WorkerTable().Get(worker_id, on_worker_get_done); + Status status = + gcs_table_storage_->WorkerTable().Get(worker_id, {on_worker_get_done, io_context_}); if (!status.ok()) { GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); } @@ -314,14 +316,15 @@ void GcsWorkerManager::HandleUpdateWorkerNumPausedThreads( worker_data->set_num_paused_threads(current_num_paused_threads + num_paused_threads_delta); Status status = gcs_table_storage_->WorkerTable().Put( - worker_id, *worker_data, on_worker_update_done); + worker_id, *worker_data, {on_worker_update_done, io_context_}); if (!status.ok()) { GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); } } }; - Status status = gcs_table_storage_->WorkerTable().Get(worker_id, on_worker_get_done); + Status status = + gcs_table_storage_->WorkerTable().Get(worker_id, {on_worker_get_done, io_context_}); if (!status.ok()) { GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); } @@ -347,7 +350,8 @@ void GcsWorkerManager::GetWorkerInfo( } }; - Status status = gcs_table_storage_->WorkerTable().Get(worker_id, on_done); + Status status = + gcs_table_storage_->WorkerTable().Get(worker_id, {on_done, io_context_}); if (!status.ok()) { on_done(status, std::nullopt); } diff --git a/src/ray/gcs/gcs_server/gcs_worker_manager.h b/src/ray/gcs/gcs_server/gcs_worker_manager.h index a835b892a682..68bdc07d65fe 100644 --- a/src/ray/gcs/gcs_server/gcs_worker_manager.h +++ b/src/ray/gcs/gcs_server/gcs_worker_manager.h @@ -27,8 +27,11 @@ namespace gcs { class GcsWorkerManager : public rpc::WorkerInfoHandler { public: explicit GcsWorkerManager(std::shared_ptr gcs_table_storage, - std::shared_ptr &gcs_publisher) - : gcs_table_storage_(gcs_table_storage), gcs_publisher_(gcs_publisher) {} + std::shared_ptr &gcs_publisher, + instrumented_io_context &io_context) + : gcs_table_storage_(gcs_table_storage), + gcs_publisher_(gcs_publisher), + io_context_(io_context) {} void HandleReportWorkerFailure(rpc::ReportWorkerFailureRequest request, rpc::ReportWorkerFailureReply *reply, @@ -70,6 +73,7 @@ class GcsWorkerManager : public rpc::WorkerInfoHandler { std::shared_ptr gcs_table_storage_; std::shared_ptr gcs_publisher_; + instrumented_io_context &io_context_; UsageStatsClient *usage_stats_client_; std::vector)>> worker_dead_listeners_; diff --git a/src/ray/gcs/gcs_server/store_client_kv.cc b/src/ray/gcs/gcs_server/store_client_kv.cc index 4bbebcf5cdf1..e8b6f08c83f2 100644 --- a/src/ray/gcs/gcs_server/store_client_kv.cc +++ b/src/ray/gcs/gcs_server/store_client_kv.cc @@ -57,11 +57,11 @@ void StoreClientInternalKV::Get(const std::string &ns, RAY_CHECK_OK(delegate_->AsyncGet( table_name_, MakeKey(ns, key), - [callback = std::move(callback)](auto status, auto result) { - callback.Post("StoreClientInternalKV::Get", - result.has_value() ? std::optional(result.value()) - : std::optional()); - })); + std::move(callback).Rebind([](std::function)> cb) { + return [cb = std::move(cb)](Status status, std::optional &&result) { + cb(std::move(result)); + }; + }))); } void StoreClientInternalKV::MultiGet( @@ -74,13 +74,16 @@ void StoreClientInternalKV::MultiGet( prefixed_keys.emplace_back(MakeKey(ns, key)); } RAY_CHECK_OK(delegate_->AsyncMultiGet( - table_name_, prefixed_keys, [callback = std::move(callback)](auto result) { - std::unordered_map ret; - for (const auto &item : result) { - ret.emplace(ExtractKey(item.first), item.second); - } - callback.Post("StoreClientInternalKV::MultiGet", std::move(ret)); - })); + table_name_, + prefixed_keys, + std::move(callback).TransformArg( + [](absl::flat_hash_map result) { + std::unordered_map ret; + for (const auto &item : result) { + ret.emplace(ExtractKey(item.first), item.second); + } + return ret; + }))); } void StoreClientInternalKV::Put(const std::string &ns, @@ -88,13 +91,8 @@ void StoreClientInternalKV::Put(const std::string &ns, const std::string &value, bool overwrite, Postable callback) { - RAY_CHECK_OK(delegate_->AsyncPut(table_name_, - MakeKey(ns, key), - value, - overwrite, - [callback = std::move(callback)](bool success) { - callback.Post("StoreClientInternalKV::Put", success); - })); + RAY_CHECK_OK(delegate_->AsyncPut( + table_name_, MakeKey(ns, key), value, overwrite, std::move(callback))); } void StoreClientInternalKV::Del(const std::string &ns, @@ -104,7 +102,8 @@ void StoreClientInternalKV::Del(const std::string &ns, if (!del_by_prefix) { RAY_CHECK_OK(delegate_->AsyncDelete( table_name_, MakeKey(ns, key), std::move(callback).TransformArg([](bool deleted) { - return deleted ? 1 : 0; + int64_t ret = deleted ? 1 : 0; + return ret; }))); return; } @@ -140,15 +139,14 @@ void StoreClientInternalKV::Keys(const std::string &ns, RAY_CHECK_OK(delegate_->AsyncGetKeys( table_name_, MakeKey(ns, prefix), - std::move(callback).TransformArg( - [](const std::vector &keys) -> std::vector { - std::vector true_keys; - true_keys.reserve(keys.size()); - for (auto &key : keys) { - true_keys.emplace_back(ExtractKey(key)); - } - return true_keys; - }))); + std::move(callback).TransformArg([](std::vector keys) { + std::vector true_keys; + true_keys.reserve(keys.size()); + for (auto &key : keys) { + true_keys.emplace_back(ExtractKey(key)); + } + return true_keys; + }))); } } // namespace gcs diff --git a/src/ray/gcs/store_client/in_memory_store_client.cc b/src/ray/gcs/store_client/in_memory_store_client.cc index 85ec2208181e..268cc83f6c1f 100644 --- a/src/ray/gcs/store_client/in_memory_store_client.cc +++ b/src/ray/gcs/store_client/in_memory_store_client.cc @@ -56,7 +56,7 @@ Status InMemoryStoreClient::AsyncGet( Status InMemoryStoreClient::AsyncGetAll( const std::string &table_name, - ToPostable> callback) { + Postable)> callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); auto result = absl::flat_hash_map(); @@ -68,7 +68,7 @@ Status InMemoryStoreClient::AsyncGetAll( Status InMemoryStoreClient::AsyncMultiGet( const std::string &table_name, const std::vector &keys, - ToPostable> callback) { + Postable)> callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); auto result = absl::flat_hash_map(); diff --git a/src/ray/gcs/store_client/in_memory_store_client.h b/src/ray/gcs/store_client/in_memory_store_client.h index 612ab8577feb..ce5b5798b4ff 100644 --- a/src/ray/gcs/store_client/in_memory_store_client.h +++ b/src/ray/gcs/store_client/in_memory_store_client.h @@ -42,13 +42,14 @@ class InMemoryStoreClient : public StoreClient { const std::string &key, ToPostable> callback) override; - Status AsyncGetAll(const std::string &table_name, - ToPostable> callback) override; + Status AsyncGetAll( + const std::string &table_name, + Postable)> callback) override; Status AsyncMultiGet( const std::string &table_name, const std::vector &keys, - ToPostable> callback) override; + Postable)> callback) override; Status AsyncDelete(const std::string &table_name, const std::string &key, diff --git a/src/ray/gcs/store_client/observable_store_client.cc b/src/ray/gcs/store_client/observable_store_client.cc index 98a5371431ca..1d6c78fe31b9 100644 --- a/src/ray/gcs/store_client/observable_store_client.cc +++ b/src/ray/gcs/store_client/observable_store_client.cc @@ -52,7 +52,7 @@ Status ObservableStoreClient::AsyncGet( Status ObservableStoreClient::AsyncGetAll( const std::string &table_name, - ToPostable> callback) { + Postable)> callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "GetAll"); return delegate_->AsyncGetAll(table_name, std::move(callback).OnInvocation([start]() { @@ -65,7 +65,7 @@ Status ObservableStoreClient::AsyncGetAll( Status ObservableStoreClient::AsyncMultiGet( const std::string &table_name, const std::vector &keys, - ToPostable> callback) { + Postable)> callback) { auto start = absl::GetCurrentTimeNanos(); STATS_gcs_storage_operation_count.Record(1, "MultiGet"); return delegate_->AsyncMultiGet( diff --git a/src/ray/gcs/store_client/observable_store_client.h b/src/ray/gcs/store_client/observable_store_client.h index 388ad7cc3271..9df47b04a0f2 100644 --- a/src/ray/gcs/store_client/observable_store_client.h +++ b/src/ray/gcs/store_client/observable_store_client.h @@ -36,13 +36,14 @@ class ObservableStoreClient : public StoreClient { const std::string &key, ToPostable> callback) override; - Status AsyncGetAll(const std::string &table_name, - ToPostable> callback) override; + Status AsyncGetAll( + const std::string &table_name, + Postable)> callback) override; Status AsyncMultiGet( const std::string &table_name, const std::vector &keys, - ToPostable> callback) override; + Postable)> callback) override; Status AsyncDelete(const std::string &table_name, const std::string &key, diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index f715692fe827..0dd37d01e6e6 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -100,16 +100,16 @@ Status RedisStoreClient::AsyncGet( const std::string &table_name, const std::string &key, ToPostable> callback) { - auto redis_callback = [callback = std::move(callback)]( - const std::shared_ptr &reply) { - std::optional result; - if (!reply->IsNil()) { - result = reply->ReadAsString(); - } - RAY_CHECK(!reply->IsError()) - << "Failed to get from Redis with status: " << reply->ReadAsStatus(); - callback.Post("RedisStoreClient.AsyncGet", Status::OK(), std::move(result)); - }; + auto redis_callback = + [callback = std::move(callback)](const std::shared_ptr &reply) { + std::optional result; + if (!reply->IsNil()) { + result = reply->ReadAsString(); + } + RAY_CHECK(!reply->IsError()) + << "Failed to get from Redis with status: " << reply->ReadAsStatus(); + callback.Post("RedisStoreClient.AsyncGet", Status::OK(), std::move(result)); + }; RedisCommand command{/*command=*/"HGET", RedisKey{external_storage_namespace_, table_name}, @@ -120,7 +120,7 @@ Status RedisStoreClient::AsyncGet( Status RedisStoreClient::AsyncGetAll( const std::string &table_name, - ToPostable> callback) { + Postable)> callback) { RedisScanner::ScanKeysAndValues(redis_client_, RedisKey{external_storage_namespace_, table_name}, RedisMatchPattern::Any(), @@ -131,10 +131,10 @@ Status RedisStoreClient::AsyncGetAll( Status RedisStoreClient::AsyncDelete(const std::string &table_name, const std::string &key, Postable callback) { - return AsyncBatchDelete(table_name, - {key}, - std::move(callback).TransformArg( - std::function{[](int64_t cnt) { return cnt > 0; }})); + return AsyncBatchDelete( + table_name, {key}, std::move(callback).TransformArg([](int64_t cnt) { + return cnt > 0; + })); } Status RedisStoreClient::AsyncBatchDelete(const std::string &table_name, @@ -150,7 +150,7 @@ Status RedisStoreClient::AsyncBatchDelete(const std::string &table_name, Status RedisStoreClient::AsyncMultiGet( const std::string &table_name, const std::vector &keys, - ToPostable> callback) { + Postable)> callback) { if (keys.empty()) { callback.Post("RedisStoreClient.AsyncMultiGet", absl::flat_hash_map{}); @@ -337,7 +337,7 @@ RedisStoreClient::RedisScanner::RedisScanner( std::shared_ptr redis_client, RedisKey redis_key, RedisMatchPattern match_pattern, - ToPostable> callback) + Postable)> callback) : redis_key_(std::move(redis_key)), match_pattern_(std::move(match_pattern)), redis_client_(std::move(redis_client)), @@ -350,7 +350,7 @@ void RedisStoreClient::RedisScanner::ScanKeysAndValues( std::shared_ptr redis_client, RedisKey redis_key, RedisMatchPattern match_pattern, - ToPostable> callback) { + Postable)> callback) { auto scanner = std::make_shared(PrivateCtorTag(), std::move(redis_client), std::move(redis_key), @@ -444,15 +444,14 @@ Status RedisStoreClient::AsyncGetKeys(const std::string &table_name, RedisKey{external_storage_namespace_, table_name}, RedisMatchPattern::Prefix(prefix), std::move(callback).TransformArg( - std::function{[](absl::flat_hash_map &&result) - -> std::vector { + [](absl::flat_hash_map result) { std::vector keys; keys.reserve(result.size()); for (const auto &[k, v] : result) { keys.push_back(k); } return keys; - }})); + })); return Status::OK(); } diff --git a/src/ray/gcs/store_client/redis_store_client.h b/src/ray/gcs/store_client/redis_store_client.h index 3454da037652..e52e1f77ec2c 100644 --- a/src/ray/gcs/store_client/redis_store_client.h +++ b/src/ray/gcs/store_client/redis_store_client.h @@ -118,13 +118,14 @@ class RedisStoreClient : public StoreClient { const std::string &key, ToPostable> callback) override; - Status AsyncGetAll(const std::string &table_name, - ToPostable> callback) override; + Status AsyncGetAll( + const std::string &table_name, + Postable)> callback) override; Status AsyncMultiGet( const std::string &table_name, const std::vector &keys, - ToPostable> callback) override; + Postable)> callback) override; Status AsyncDelete(const std::string &table_name, const std::string &key, @@ -160,17 +161,18 @@ class RedisStoreClient : public StoreClient { public: // Don't call this. Use ScanKeysAndValues instead. - explicit RedisScanner(PrivateCtorTag tag, - std::shared_ptr redis_client, - RedisKey redis_key, - RedisMatchPattern match_pattern, - ToPostable> callback); + explicit RedisScanner( + PrivateCtorTag tag, + std::shared_ptr redis_client, + RedisKey redis_key, + RedisMatchPattern match_pattern, + Postable)> callback); static void ScanKeysAndValues( std::shared_ptr redis_client, RedisKey redis_key, RedisMatchPattern match_pattern, - ToPostable> callback); + Postable)> callback); private: // Scans the keys and values, one batch a time. Once all keys are scanned, the @@ -200,7 +202,7 @@ class RedisStoreClient : public StoreClient { std::shared_ptr redis_client_; - ToPostable> callback_; + Postable)> callback_; // Holds a self-ref until the scan is done. std::shared_ptr self_ref_; diff --git a/src/ray/gcs/store_client/store_client.h b/src/ray/gcs/store_client/store_client.h index d0f7b87eddb2..9326d26df4db 100644 --- a/src/ray/gcs/store_client/store_client.h +++ b/src/ray/gcs/store_client/store_client.h @@ -68,7 +68,7 @@ class StoreClient { /// \return Status virtual Status AsyncGetAll( const std::string &table_name, - ToPostable> callback) = 0; + Postable)> callback) = 0; /// Get all data from the given table asynchronously. /// @@ -79,7 +79,7 @@ class StoreClient { virtual Status AsyncMultiGet( const std::string &table_name, const std::vector &keys, - ToPostable> callback) = 0; + Postable)> callback) = 0; /// Delete data from the given table asynchronously. /// From 1572f8f0e08175cbe1dae893665b85594e85b023 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Mon, 25 Nov 2024 22:15:23 -0800 Subject: [PATCH 21/26] add .h and revert accidental change Signed-off-by: Ruiyang Wang --- .../execution/operators/output_splitter.py | 2 +- src/ray/common/asio/postable.h | 110 ++++++++++++++++++ src/ray/util/function_traits.h | 68 +++++++++++ 3 files changed, 179 insertions(+), 1 deletion(-) create mode 100644 src/ray/common/asio/postable.h create mode 100644 src/ray/util/function_traits.h diff --git a/python/ray/data/_internal/execution/operators/output_splitter.py b/python/ray/data/_internal/execution/operators/output_splitter.py index 833557f6e203..f5a9b6c55d84 100644 --- a/python/ray/data/_internal/execution/operators/output_splitter.py +++ b/python/ray/data/_internal/execution/operators/output_splitter.py @@ -161,7 +161,7 @@ def progress_str(self) -> str: def _dispatch_bundles(self, dispatch_all: bool = False) -> None: start_time = time.perf_counter() - # Dispatch all postable bundles from the internal buffer. + # Dispatch all dispatchable bundles from the internal buffer. # This may not dispatch all bundles when equal=True. while self._buffer and ( dispatch_all or len(self._buffer) >= self._min_buffer_size diff --git a/src/ray/common/asio/postable.h b/src/ray/common/asio/postable.h new file mode 100644 index 000000000000..6fc0512e9a7d --- /dev/null +++ b/src/ray/common/asio/postable.h @@ -0,0 +1,110 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "ray/common/asio/instrumented_io_context.h" +#include "ray/util/function_traits.h" + +namespace ray { + +template +class Postable; + +namespace internal { + +template +struct ToPostableHelper; + +template +struct ToPostableHelper> { + using type = Postable; +}; + +} // namespace internal + +template +using ToPostable = typename internal::ToPostableHelper::type; + +template +class Postable { + public: + Postable(std::function func, instrumented_io_context &io_context) + : func_(std::move(func)), io_context_(io_context) {} + + template + void Post(const std::string &name, Args &&...args) const { + io_context_.post( + [func = func_, + args_tuple = std::make_tuple(std::forward(args)...)]() mutable { + std::apply(func, std::move(args_tuple)); + }, + name); + } + + // OnInvocation + Postable &&OnInvocation(std::function observer) && { + auto original_func = std::move(func_); + func_ = [observer = std::move(observer), + func = std::move(original_func)](auto &&...args) { + observer(); + return func(std::forward(args)...); + }; + return std::move(*this); + } + + // Transforms the argument by applying `arg_mapper` to the input argument. + // Basically, adds a arg_mapper and becomes io_context.Post(func(arg_mapper(input))). + // + // Constraints in template arguments: + // - `this->func_` must take exactly one argument. + // - `arg_mapper` must take one argument. + // - `arg_mapper` must return the same type as `this->func_`'s argument. + // + // Result: + // `this` is Postable + // `arg_mapper` is lambda or std::function: NewInputType -> OldInputType + // The result is Postable + template + auto TransformArg(ArgMapper arg_mapper) && { + // Ensure that func_ takes exactly one argument. + static_assert(function_traits::arity == 1, + "TransformArg requires function taking exactly one argument"); + + // Ensure that arg_mapper takes exactly one argument. + static_assert(function_traits::arity == 1, + "ArgMapper must be a function taking exactly one argument"); + // Define type aliases for clarity. + using OldInputType = typename function_traits::arg1_type; + using NewInputType = typename function_traits::arg1_type; + using ArgMapperResultType = typename function_traits::result_type; + + static_assert(std::is_same_v, + "ArgMapper's return value must == func_'s argument"); + + return std::move(*this).Rebind([arg_mapper = std::move(arg_mapper)](auto func) { + return [func = std::move(func), arg_mapper = std::move(arg_mapper)]( + NewInputType input) { return func(arg_mapper(std::move(input))); }; + }); + } + + // Rebind the function. + // `func_converter`: func_ -> NewFuncType + // The result is ToPostable + // + // Changed func_converter to be a template parameter to accept lambdas. + template + auto Rebind(FuncConverter func_converter) && { // -> Postable + using NewFuncType = typename function_traits()))>::type; + return Postable(func_converter(std::move(func_)), io_context_); + } + + std::function func_; + instrumented_io_context &io_context_; +}; + +} // namespace ray diff --git a/src/ray/util/function_traits.h b/src/ray/util/function_traits.h new file mode 100644 index 000000000000..2ac1706e4d8f --- /dev/null +++ b/src/ray/util/function_traits.h @@ -0,0 +1,68 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace ray { + +namespace internal { + +template +struct function_traits_helper : public boost::function_traits { + using type = T; + using std_function_type = std::function; +}; + +} // namespace internal + +// Generalized boost::function_traits to support more callable types. +// - function pointers +// - std::function +// - member function pointers +// - lambdas and callable objects, or anything with an operator() +template +struct function_traits; + +// Specialization for function pointers +template +struct function_traits : internal::function_traits_helper {}; + +template +struct function_traits : internal::function_traits_helper {}; + +// Specialization for std::function +template +struct function_traits> + : internal::function_traits_helper {}; + +// Specialization for member function pointers +template +struct function_traits : internal::function_traits_helper { +}; + +// Specialization for const member function pointers +template +struct function_traits + : internal::function_traits_helper {}; + +// Specialization for callable objects (e.g., lambdas and functors) +template +struct function_traits + : function_traits {}; + +} // namespace ray From e447b83b423982c044b7fdb9e4adaaa9b95f79eb Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Tue, 26 Nov 2024 12:22:16 -0800 Subject: [PATCH 22/26] lint Signed-off-by: Ruiyang Wang --- src/ray/common/asio/postable.h | 14 ++++++++++++++ src/ray/gcs/gcs_server/gcs_table_storage.cc | 4 ++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/ray/common/asio/postable.h b/src/ray/common/asio/postable.h index 6fc0512e9a7d..a88f045a4fc9 100644 --- a/src/ray/common/asio/postable.h +++ b/src/ray/common/asio/postable.h @@ -1,3 +1,17 @@ +// Copyright 2024 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.cc b/src/ray/gcs/gcs_server/gcs_table_storage.cc index 3e8ef7f3ccbf..bdd5daec70b0 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.cc +++ b/src/ray/gcs/gcs_server/gcs_table_storage.cc @@ -129,7 +129,7 @@ Status GcsTableWithJobId::GetByJobId( values[Key::FromBinary(item.first)].ParseFromString(item.second); } } - return std::move(values); + return values; })); } @@ -144,7 +144,7 @@ Status GcsTableWithJobId::DeleteByJobId(const JobID &job_id, keys.push_back(key); } } - return BatchDelete(keys, callback); + return BatchDelete(keys, std::move(callback)); } template From 5e0a080771e48559c8f31bbe8b984322fb1ccb5a Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Mon, 2 Dec 2024 15:06:40 -0800 Subject: [PATCH 23/26] add Dispatch() method for redis Signed-off-by: Ruiyang Wang --- src/ray/common/asio/postable.h | 18 +++- .../gcs/store_client/redis_store_client.cc | 90 ++++++++++--------- src/ray/gcs/store_client/redis_store_client.h | 7 ++ 3 files changed, 71 insertions(+), 44 deletions(-) diff --git a/src/ray/common/asio/postable.h b/src/ray/common/asio/postable.h index a88f045a4fc9..406a404b8066 100644 --- a/src/ray/common/asio/postable.h +++ b/src/ray/common/asio/postable.h @@ -43,6 +43,11 @@ struct ToPostableHelper> { template using ToPostable = typename internal::ToPostableHelper::type; +/// Postable wraps a std::function and an instrumented_io_context together, ensuring the +/// function can only be Post()ed or Dispatch()ed to that specific io_context. This +/// provides type safety and prevents accidentally running the function on the wrong +/// io_context. +/// template class Postable { public: @@ -59,8 +64,19 @@ class Postable { name); } + template + void Dispatch(const std::string &name, Args &&...args) const { + io_context_.dispatch( + [func = func_, + args_tuple = std::make_tuple(std::forward(args)...)]() mutable { + std::apply(func, std::move(args_tuple)); + }, + name); + } + // OnInvocation - Postable &&OnInvocation(std::function observer) && { + // Adds an observer that will be called on the io_context before the original function. + Postable OnInvocation(std::function observer) && { auto original_func = std::move(func_); func_ = [observer = std::move(observer), func = std::move(original_func)](auto &&...args) { diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index 0dd37d01e6e6..4d42229469bb 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -71,6 +71,44 @@ RedisMatchPattern RedisMatchPattern::Prefix(const std::string &prefix) { return RedisMatchPattern(absl::StrCat(EscapeMatchPattern(prefix), "*")); } +void RedisStoreClient::MGetValues( + const std::string &table_name, + const std::vector &keys, + Postable)> callback) { + // The `HMGET` command for each shard. + auto batched_commands = GenCommandsBatched( + "HMGET", RedisKey{external_storage_namespace_, table_name}, keys); + auto total_count = batched_commands.size(); + auto finished_count = std::make_shared(0); + auto key_value_map = std::make_shared>(); + + for (auto &command : batched_commands) { + auto mget_callback = [finished_count, + total_count, + // Copies! + args = command.args, + // Copies! + callback, + key_value_map](const std::shared_ptr &reply) { + if (!reply->IsNil()) { + auto value = reply->ReadAsStringArray(); + for (size_t index = 0; index < value.size(); ++index) { + if (value[index].has_value()) { + (*key_value_map)[args[index]] = *(value[index]); + } + } + } + + ++(*finished_count); + if (*finished_count == total_count) { + callback.Dispatch("RedisStoreClient.AsyncMultiGet", std::move(*key_value_map)); + } + }; + SendRedisCmdArgsAsKeys(std::move(command), std::move(mget_callback)); + } + return Status::OK(); +} + RedisStoreClient::RedisStoreClient(std::shared_ptr redis_client) : external_storage_namespace_(::RayConfig::instance().external_storage_namespace()), redis_client_(std::move(redis_client)) { @@ -90,7 +128,7 @@ Status RedisStoreClient::AsyncPut(const std::string &table_name, RedisCallback write_callback = [callback = std::move(callback)](const std::shared_ptr &reply) { auto added_num = reply->ReadAsInteger(); - callback.Post("RedisStoreClient.AsyncPut", added_num != 0); + callback.Dispatch("RedisStoreClient.AsyncPut", added_num != 0); }; SendRedisCmdWithKeys({key}, std::move(command), std::move(write_callback)); return Status::OK(); @@ -108,7 +146,7 @@ Status RedisStoreClient::AsyncGet( } RAY_CHECK(!reply->IsError()) << "Failed to get from Redis with status: " << reply->ReadAsStatus(); - callback.Post("RedisStoreClient.AsyncGet", Status::OK(), std::move(result)); + callback.Dispatch("RedisStoreClient.AsyncGet", Status::OK(), std::move(result)); }; RedisCommand command{/*command=*/"HGET", @@ -141,7 +179,7 @@ Status RedisStoreClient::AsyncBatchDelete(const std::string &table_name, const std::vector &keys, Postable callback) { if (keys.empty()) { - callback.Post("RedisStoreClient.AsyncBatchDelete", 0); + callback.Dispatch("RedisStoreClient.AsyncBatchDelete", 0); return Status::OK(); } return DeleteByKeys(table_name, keys, std::move(callback)); @@ -152,45 +190,11 @@ Status RedisStoreClient::AsyncMultiGet( const std::vector &keys, Postable)> callback) { if (keys.empty()) { - callback.Post("RedisStoreClient.AsyncMultiGet", - absl::flat_hash_map{}); + callback.Dispatch("RedisStoreClient.AsyncMultiGet", + absl::flat_hash_map{}); return Status::OK(); } - // HMGET external_storage_namespace@table_name key1 key2 ... - // `keys` are chunked to multiple HMGET commands by - // RAY_maximum_gcs_storage_operation_batch_size. - - // The `HMGET` command for each shard. - auto batched_commands = GenCommandsBatched( - "HMGET", RedisKey{external_storage_namespace_, table_name}, keys); - auto total_count = batched_commands.size(); - auto finished_count = std::make_shared(0); - auto key_value_map = std::make_shared>(); - - for (auto &command : batched_commands) { - auto mget_callback = [finished_count, - total_count, - // Copies! - args = command.args, - // Copies! - callback, - key_value_map](const std::shared_ptr &reply) { - if (!reply->IsNil()) { - auto value = reply->ReadAsStringArray(); - for (size_t index = 0; index < value.size(); ++index) { - if (value[index].has_value()) { - (*key_value_map)[args[index]] = *(value[index]); - } - } - } - - ++(*finished_count); - if (*finished_count == total_count) { - callback.Post("RedisStoreClient.AsyncMultiGet", std::move(*key_value_map)); - } - }; - SendRedisCmdArgsAsKeys(std::move(command), std::move(mget_callback)); - } + MGetValues(table_name, keys, std::move(callback)); return Status::OK(); } @@ -324,7 +328,7 @@ Status RedisStoreClient::DeleteByKeys(const std::string &table, (*num_deleted) += reply->ReadAsInteger(); ++(*finished_count); if (*finished_count == total_count) { - callback.Post("RedisStoreClient.AsyncBatchDelete", *num_deleted); + callback.Dispatch("RedisStoreClient.AsyncBatchDelete", *num_deleted); } }; SendRedisCmdArgsAsKeys(std::move(command), std::move(delete_callback)); @@ -366,7 +370,7 @@ void RedisStoreClient::RedisScanner::Scan() { // we should consider using a reader-writer lock. absl::MutexLock lock(&mutex_); if (!cursor_.has_value()) { - callback_.Post("RedisStoreClient.RedisScanner.Scan", std::move(results_)); + callback_.Dispatch("RedisStoreClient.RedisScanner.Scan", std::move(results_)); self_ref_.reset(); return; } @@ -464,7 +468,7 @@ Status RedisStoreClient::AsyncExists(const std::string &table_name, std::move(command), [callback = std::move(callback)](const std::shared_ptr &reply) { bool exists = reply->ReadAsInteger() > 0; - callback.Post("RedisStoreClient.AsyncExists", exists); + callback.Dispatch("RedisStoreClient.AsyncExists", exists); }); return Status::OK(); } diff --git a/src/ray/gcs/store_client/redis_store_client.h b/src/ray/gcs/store_client/redis_store_client.h index e52e1f77ec2c..617d318a484f 100644 --- a/src/ray/gcs/store_client/redis_store_client.h +++ b/src/ray/gcs/store_client/redis_store_client.h @@ -249,6 +249,13 @@ class RedisStoreClient : public StoreClient { // hence command.args may become empty. void SendRedisCmdArgsAsKeys(RedisCommand command, RedisCallback redis_callback); + // HMGET external_storage_namespace@table_name key1 key2 ... + // `keys` are chunked to multiple HMGET commands by + // RAY_maximum_gcs_storage_operation_batch_size. + void MGetValues(const std::string &table_name, + const std::vector &keys, + Postable)> callback); + std::string external_storage_namespace_; std::shared_ptr redis_client_; absl::Mutex mu_; From 9b57f07997a53f5079831bd924aa6bdb8c1bc2a8 Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Mon, 2 Dec 2024 15:15:35 -0800 Subject: [PATCH 24/26] lint Signed-off-by: Ruiyang Wang --- src/ray/gcs/gcs_server/gcs_node_manager.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index 7380b7dc0efc..65f2290e9458 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -32,7 +32,7 @@ namespace gcs { GcsNodeManager::GcsNodeManager(std::shared_ptr gcs_publisher, std::shared_ptr gcs_table_storage, rpc::NodeManagerClientPool *raylet_client_pool, - instrumented_io_context &io_context, + instrumented_io_context &io_context, const ClusterID &cluster_id) : gcs_publisher_(std::move(gcs_publisher)), gcs_table_storage_(std::move(gcs_table_storage)), From a1eede1559abbe4afd010ce60f836accdfe6b03d Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Mon, 2 Dec 2024 23:22:14 -0800 Subject: [PATCH 25/26] move only Post(), and unit tests Signed-off-by: Ruiyang Wang --- python/ray/_raylet.pyx | 2 +- python/ray/serve/tests/test_callback.py | 2 +- .../ray/gcs/gcs_server/gcs_actor_manager.h | 4 +- src/mock/ray/gcs/gcs_server/gcs_kv_manager.h | 58 ++-- .../ray/gcs/gcs_server/gcs_node_manager.h | 3 +- .../ray/gcs/gcs_server/gcs_resource_manager.h | 6 +- .../gcs/store_client/in_memory_store_client.h | 55 ++-- .../ray/gcs/store_client/redis_store_client.h | 9 +- src/mock/ray/gcs/store_client/store_client.h | 16 +- src/ray/common/asio/postable.h | 17 +- src/ray/gcs/gcs_server/gcs_server.cc | 2 +- src/ray/gcs/gcs_server/gcs_table_storage.h | 8 +- .../gcs_actor_manager_export_event_test.cc | 4 +- .../gcs_job_manager_export_event_test.cc | 13 +- .../gcs_node_manager_export_event_test.cc | 2 +- .../gcs_server/test/gcs_actor_manager_test.cc | 7 +- .../test/gcs_actor_scheduler_mock_test.cc | 17 +- .../test/gcs_actor_scheduler_test.cc | 11 +- .../test/gcs_autoscaler_state_manager_test.cc | 12 +- .../test/gcs_function_manager_test.cc | 6 +- .../gcs_server/test/gcs_job_manager_test.cc | 52 ++-- .../gcs_server/test/gcs_kv_manager_test.cc | 94 +++--- .../gcs_server/test/gcs_node_manager_test.cc | 16 +- .../gcs_placement_group_manager_mock_test.cc | 32 +- .../test/gcs_placement_group_manager_test.cc | 14 +- .../gcs_placement_group_scheduler_test.cc | 11 +- .../test/gcs_resource_manager_test.cc | 2 +- .../gcs_server/test/gcs_server_test_util.h | 10 +- .../test/gcs_table_storage_test_base.h | 10 +- .../test/gcs_worker_manager_test.cc | 6 +- .../test/in_memory_gcs_table_storage_test.cc | 3 +- .../test/redis_gcs_table_storage_test.cc | 3 +- .../test/usage_stats_client_test.cc | 28 +- .../store_client/in_memory_store_client.cc | 21 +- .../gcs/store_client/redis_store_client.cc | 43 ++- .../test/in_memory_store_client_test.cc | 2 +- .../test/observable_store_client_test.cc | 4 +- .../test/redis_store_client_test.cc | 275 ++++++++++-------- .../test/store_client_test_base.h | 50 ++-- 39 files changed, 528 insertions(+), 402 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index d0cb813c6ff5..197cd996e82f 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -5249,7 +5249,7 @@ cdef void async_callback(shared_ptr[CRayObject] obj, user_callback = user_callback_ptr user_callback(result) except Exception: - # Only log the error here because this calllback is called from Cpp + # Only log the error here because this callback is called from Cpp # and Cython will ignore the exception anyway logger.exception(f"failed to run async callback (user func)") finally: diff --git a/python/ray/serve/tests/test_callback.py b/python/ray/serve/tests/test_callback.py index 0bedd26d1d92..4f6a98c8b358 100644 --- a/python/ray/serve/tests/test_callback.py +++ b/python/ray/serve/tests/test_callback.py @@ -216,7 +216,7 @@ def test_http_proxy_return_aribitary_objects(ray_instance): ], indirect=True, ) -def test_http_proxy_calllback_failures(ray_instance, capsys): +def test_http_proxy_callback_failures(ray_instance, capsys): """Test http proxy keeps restarting when callback function fails""" try: diff --git a/src/mock/ray/gcs/gcs_server/gcs_actor_manager.h b/src/mock/ray/gcs/gcs_server/gcs_actor_manager.h index 528d0d1af6d9..dc05012dfd0d 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_actor_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_actor_manager.h @@ -28,7 +28,8 @@ namespace gcs { class MockGcsActorManager : public GcsActorManager { public: MockGcsActorManager(RuntimeEnvManager &runtime_env_manager, - GcsFunctionManager &function_manager) + GcsFunctionManager &function_manager, + instrumented_io_context &io_service) : GcsActorManager( nullptr, nullptr, @@ -36,6 +37,7 @@ class MockGcsActorManager : public GcsActorManager { runtime_env_manager, function_manager, [](const ActorID &) {}, + io_service, [](const rpc::Address &) { return nullptr; }) {} MOCK_METHOD(void, diff --git a/src/mock/ray/gcs/gcs_server/gcs_kv_manager.h b/src/mock/ray/gcs/gcs_server/gcs_kv_manager.h index 87abe2101555..25e6e381b451 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_kv_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_kv_manager.h @@ -20,47 +20,46 @@ namespace gcs { class MockInternalKVInterface : public ray::gcs::InternalKVInterface { public: - MockInternalKVInterface() {} + explicit MockInternalKVInterface() = default; MOCK_METHOD(void, Get, (const std::string &ns, const std::string &key, - std::function)> callback), + Postable)> callback), + (override)); + MOCK_METHOD(void, + MultiGet, + (const std::string &ns, + const std::vector &keys, + Postable)> callback), (override)); - MOCK_METHOD( - void, - MultiGet, - (const std::string &ns, - const std::vector &keys, - std::function)> callback), - (override)); MOCK_METHOD(void, Put, (const std::string &ns, const std::string &key, const std::string &value, bool overwrite, - std::function callback), + Postable callback), (override)); MOCK_METHOD(void, Del, (const std::string &ns, const std::string &key, bool del_by_prefix, - std::function callback), + Postable callback), (override)); MOCK_METHOD(void, Exists, (const std::string &ns, const std::string &key, - std::function callback), + Postable callback), (override)); MOCK_METHOD(void, Keys, (const std::string &ns, const std::string &prefix, - std::function)> callback), + Postable)> callback), (override)); }; @@ -68,30 +67,31 @@ class MockInternalKVInterface : public ray::gcs::InternalKVInterface { // Only supports Put and Get. // Warning: Naively prepends the namespace to the key, so e.g. // the (namespace, key) pairs ("a", "bc") and ("ab", "c") will collide which is a bug. - +// TODO(ryw): DO NOT SUBMIT. Get and Put used to be sync, now it's async. We need to use +// promise wait. class FakeInternalKVInterface : public ray::gcs::InternalKVInterface { public: - FakeInternalKVInterface() {} + explicit FakeInternalKVInterface() = default; // The C++ map. std::unordered_map kv_store_ = {}; void Get(const std::string &ns, const std::string &key, - std::function)> callback) override { + Postable)> callback) override { std::string full_key = ns + key; auto it = kv_store_.find(full_key); if (it == kv_store_.end()) { - callback(std::nullopt); + std::move(callback).Post("FakeInternalKVInterface::Get", std::nullopt); } else { - callback(it->second); + std::move(callback).Post("FakeInternalKVInterface::Get", it->second); } } - void MultiGet(const std::string &ns, - const std::vector &keys, - std::function)> - callback) override { + void MultiGet( + const std::string &ns, + const std::vector &keys, + Postable)> callback) override { std::unordered_map result; for (const auto &key : keys) { std::string full_key = ns + key; @@ -100,20 +100,20 @@ class FakeInternalKVInterface : public ray::gcs::InternalKVInterface { result[key] = it->second; } } - callback(result); + std::move(callback).Post("FakeInternalKVInterface::MultiGet", result); } void Put(const std::string &ns, const std::string &key, const std::string &value, bool overwrite, - std::function callback) override { + Postable callback) override { std::string full_key = ns + key; if (kv_store_.find(full_key) != kv_store_.end() && !overwrite) { - callback(false); + std::move(callback).Post("FakeInternalKVInterface::Put", false); } else { kv_store_[full_key] = value; - callback(true); + std::move(callback).Post("FakeInternalKVInterface::Put", true); } } @@ -122,19 +122,19 @@ class FakeInternalKVInterface : public ray::gcs::InternalKVInterface { (const std::string &ns, const std::string &key, bool del_by_prefix, - std::function callback), + Postable callback), (override)); MOCK_METHOD(void, Exists, (const std::string &ns, const std::string &key, - std::function callback), + Postable callback), (override)); MOCK_METHOD(void, Keys, (const std::string &ns, const std::string &prefix, - std::function)> callback), + Postable)> callback), (override)); }; diff --git a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h index 7a3efe197529..7828eabdc1de 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_node_manager.h @@ -18,10 +18,11 @@ namespace gcs { class MockGcsNodeManager : public GcsNodeManager { public: - MockGcsNodeManager() + MockGcsNodeManager(instrumented_io_context &io_context) : GcsNodeManager(/*gcs_publisher=*/nullptr, /*gcs_table_storage=*/nullptr, /*raylet_client_pool=*/nullptr, + io_context, /*cluster_id=*/ClusterID::Nil()) {} MOCK_METHOD(void, HandleRegisterNode, diff --git a/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h b/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h index 2464bc7cf9b1..920baba8c3bd 100644 --- a/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h +++ b/src/mock/ray/gcs/gcs_server/gcs_resource_manager.h @@ -18,10 +18,8 @@ namespace ray { namespace gcs { static instrumented_io_context __mock_io_context_; static ClusterResourceManager __mock_cluster_resource_manager_(__mock_io_context_); -static GcsNodeManager __mock_gcs_node_manager_(nullptr, - nullptr, - nullptr, - ClusterID::Nil()); +static GcsNodeManager __mock_gcs_node_manager_( + nullptr, nullptr, nullptr, __mock_io_context_, ClusterID::Nil()); class MockGcsResourceManager : public GcsResourceManager { public: diff --git a/src/mock/ray/gcs/store_client/in_memory_store_client.h b/src/mock/ray/gcs/store_client/in_memory_store_client.h index 78201a61b121..7b8c945eaa61 100644 --- a/src/mock/ray/gcs/store_client/in_memory_store_client.h +++ b/src/mock/ray/gcs/store_client/in_memory_store_client.h @@ -22,65 +22,58 @@ class MockInMemoryStoreClient : public InMemoryStoreClient { (const std::string &table_name, const std::string &key, const std::string &data, - const StatusCallback &callback), - (override)); - MOCK_METHOD(Status, - AsyncPutWithIndex, - (const std::string &table_name, - const std::string &key, - const std::string &index_key, - const std::string &data, - const StatusCallback &callback), + bool overwrite, + Postable callback), (override)); + MOCK_METHOD(Status, AsyncGet, (const std::string &table_name, const std::string &key, - const OptionalItemCallback &callback), - (override)); - MOCK_METHOD(Status, - AsyncGetByIndex, - (const std::string &table_name, - const std::string &index_key, - (const MapCallback &callback)), + ToPostable> callback), (override)); + MOCK_METHOD(Status, AsyncGetAll, (const std::string &table_name, - (const MapCallback &callback)), + Postable)> callback), (override)); + MOCK_METHOD(Status, - AsyncDelete, + AsyncMultiGet, (const std::string &table_name, - const std::string &key, - const StatusCallback &callback), + const std::vector &keys, + Postable)> callback), (override)); + MOCK_METHOD(Status, - AsyncDeleteWithIndex, + AsyncDelete, (const std::string &table_name, const std::string &key, - const std::string &index_key, - const StatusCallback &callback), + Postable callback), (override)); + MOCK_METHOD(Status, AsyncBatchDelete, (const std::string &table_name, const std::vector &keys, - const StatusCallback &callback), + Postable callback), (override)); + MOCK_METHOD(Status, - AsyncBatchDeleteWithIndex, + AsyncGetKeys, (const std::string &table_name, - const std::vector &keys, - const std::vector &index_keys, - const StatusCallback &callback), + const std::string &prefix, + Postable)> callback), (override)); + MOCK_METHOD(Status, - AsyncDeleteByIndex, + AsyncExists, (const std::string &table_name, - const std::string &index_key, - const StatusCallback &callback), + const std::string &key, + Postable callback), (override)); + MOCK_METHOD(int, GetNextJobID, (), (override)); }; diff --git a/src/mock/ray/gcs/store_client/redis_store_client.h b/src/mock/ray/gcs/store_client/redis_store_client.h index 5401625e8807..a473c695f937 100644 --- a/src/mock/ray/gcs/store_client/redis_store_client.h +++ b/src/mock/ray/gcs/store_client/redis_store_client.h @@ -25,14 +25,7 @@ class MockRedisStoreClient : public RedisStoreClient { const std::string &data, const StatusCallback &callback), (override)); - MOCK_METHOD(Status, - AsyncPutWithIndex, - (const std::string &table_name, - const std::string &key, - const std::string &index_key, - const std::string &data, - const StatusCallback &callback), - (override)); + MOCK_METHOD(Status, AsyncGet, (const std::string &table_name, diff --git a/src/mock/ray/gcs/store_client/store_client.h b/src/mock/ray/gcs/store_client/store_client.h index 90a1935fbb11..e3feb1786f0b 100644 --- a/src/mock/ray/gcs/store_client/store_client.h +++ b/src/mock/ray/gcs/store_client/store_client.h @@ -23,50 +23,50 @@ class MockStoreClient : public StoreClient { const std::string &key, const std::string &data, bool overwrite, - std::function callback), + Postable callback), (override)); MOCK_METHOD(Status, AsyncGet, (const std::string &table_name, const std::string &key, - const OptionalItemCallback &callback), + ToPostable> callback), (override)); MOCK_METHOD(Status, AsyncGetAll, (const std::string &table_name, - (const MapCallback &callback)), + Postable)> callback), (override)); MOCK_METHOD(Status, AsyncMultiGet, (const std::string &table_name, const std::vector &key, - (const MapCallback &callback)), + Postable)> callback), (override)); MOCK_METHOD(Status, AsyncDelete, (const std::string &table_name, const std::string &key, - std::function callback), + Postable callback), (override)); MOCK_METHOD(Status, AsyncBatchDelete, (const std::string &table_name, const std::vector &keys, - std::function callback), + Postable callback), (override)); MOCK_METHOD(int, GetNextJobID, (), (override)); MOCK_METHOD(Status, AsyncGetKeys, (const std::string &table_name, const std::string &prefix, - std::function)> callback), + Postable)> callback), (override)); MOCK_METHOD(Status, AsyncExists, (const std::string &table_name, const std::string &key, - std::function callback), + Postable callback), (override)); }; diff --git a/src/ray/common/asio/postable.h b/src/ray/common/asio/postable.h index 406a404b8066..fd7a36532424 100644 --- a/src/ray/common/asio/postable.h +++ b/src/ray/common/asio/postable.h @@ -48,16 +48,22 @@ using ToPostable = typename internal::ToPostableHelper::type; /// provides type safety and prevents accidentally running the function on the wrong /// io_context. /// +/// A Postable can only be Post()ed or Dispatch()ed once. After that, it is moved-from and +/// a next invocation will fail. template class Postable { public: Postable(std::function func, instrumented_io_context &io_context) - : func_(std::move(func)), io_context_(io_context) {} + : func_(std::move(func)), io_context_(io_context) { + RAY_CHECK(func_ != nullptr) + << "Postable must be constructed with a non-null function."; + } template - void Post(const std::string &name, Args &&...args) const { + void Post(const std::string &name, Args &&...args) && { + RAY_CHECK(func_ != nullptr) << "Postable has already been invoked."; io_context_.post( - [func = func_, + [func = std::move(func_), args_tuple = std::make_tuple(std::forward(args)...)]() mutable { std::apply(func, std::move(args_tuple)); }, @@ -65,9 +71,10 @@ class Postable { } template - void Dispatch(const std::string &name, Args &&...args) const { + void Dispatch(const std::string &name, Args &&...args) && { + RAY_CHECK(func_ != nullptr) << "Postable has already been invoked."; io_context_.dispatch( - [func = func_, + [func = std::move(func_), args_tuple = std::make_tuple(std::forward(args)...)]() mutable { std::apply(func, std::move(args_tuple)); }, diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 581c6607c174..9794704ef69f 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -76,7 +76,7 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, auto &io_context = io_context_provider_.GetDefaultIOContext(); switch (storage_type_) { case StorageType::IN_MEMORY: - gcs_table_storage_ = std::make_shared(io_context); + gcs_table_storage_ = std::make_shared(); break; case StorageType::REDIS_PERSIST: { auto redis_client = CreateRedisClient(io_context); diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.h b/src/ray/gcs/gcs_server/gcs_table_storage.h index 43f50f62ad7d..6d1359647dac 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.h +++ b/src/ray/gcs/gcs_server/gcs_table_storage.h @@ -304,13 +304,9 @@ class RedisGcsTableStorage : public GcsTableStorage { /// that uses memory as storage. class InMemoryGcsTableStorage : public GcsTableStorage { public: - explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service) + explicit InMemoryGcsTableStorage() : GcsTableStorage(std::make_shared( - std::make_unique())), - io_service_(main_io_service) {} - - // All methods are posted to this io_service. - instrumented_io_context &io_service_; + std::make_unique())) {} }; } // namespace gcs diff --git a/src/ray/gcs/gcs_server/test/export_api/gcs_actor_manager_export_event_test.cc b/src/ray/gcs/gcs_server/test/export_api/gcs_actor_manager_export_event_test.cc index ce6955fed38f..409a24834d4a 100644 --- a/src/ray/gcs/gcs_server/test/export_api/gcs_actor_manager_export_event_test.cc +++ b/src/ray/gcs/gcs_server/test/export_api/gcs_actor_manager_export_event_test.cc @@ -148,8 +148,8 @@ class GcsActorManagerTest : public ::testing::Test { /*batch_size=*/100); gcs_publisher_ = std::make_shared(std::move(publisher)); - store_client_ = std::make_shared(io_service_); - gcs_table_storage_ = std::make_shared(io_service_); + store_client_ = std::make_shared(); + gcs_table_storage_ = std::make_shared(); kv_ = std::make_unique(); function_manager_ = std::make_unique(*kv_); gcs_actor_manager_ = std::make_unique( diff --git a/src/ray/gcs/gcs_server/test/export_api/gcs_job_manager_export_event_test.cc b/src/ray/gcs/gcs_server/test/export_api/gcs_job_manager_export_event_test.cc index 832490d2f81e..7fc60cd588ee 100644 --- a/src/ray/gcs/gcs_server/test/export_api/gcs_job_manager_export_event_test.cc +++ b/src/ray/gcs/gcs_server/test/export_api/gcs_job_manager_export_event_test.cc @@ -34,12 +34,6 @@ using json = nlohmann::json; namespace ray { -class MockInMemoryStoreClient : public gcs::InMemoryStoreClient { - public: - explicit MockInMemoryStoreClient(instrumented_io_context &main_io_service) - : gcs::InMemoryStoreClient(main_io_service) {} -}; - class GcsJobManagerTest : public ::testing::Test { public: GcsJobManagerTest() : runtime_env_manager_(nullptr) { @@ -54,11 +48,11 @@ class GcsJobManagerTest : public ::testing::Test { gcs_publisher_ = std::make_shared( std::make_unique()); - store_client_ = std::make_shared(io_service_); + store_client_ = std::make_shared(); gcs_table_storage_ = std::make_shared(store_client_); kv_ = std::make_unique(); fake_kv_ = std::make_unique(); - function_manager_ = std::make_unique(*kv_); + function_manager_ = std::make_unique(*kv_, io_service_); // Mock client factory which abuses the "address" argument to return a // CoreWorkerClient whose number of running tasks equal to the address port. This is @@ -112,10 +106,11 @@ TEST_F(GcsJobManagerTest, TestExportDriverJobEvents) { runtime_env_manager_, *function_manager_, *fake_kv_, + io_service_, client_factory_); gcs::GcsInitData gcs_init_data(gcs_table_storage_); - gcs_job_manager.Initialize(/*init_data=*/gcs_init_data); + gcs_job_manager.Initialize(/*gcs_init_data=*/gcs_init_data); auto job_api_job_id = JobID::FromInt(100); std::string submission_id = "submission_id_100"; diff --git a/src/ray/gcs/gcs_server/test/export_api/gcs_node_manager_export_event_test.cc b/src/ray/gcs/gcs_server/test/export_api/gcs_node_manager_export_event_test.cc index 61d2d0e8b932..78c72e18c0ab 100644 --- a/src/ray/gcs/gcs_server/test/export_api/gcs_node_manager_export_event_test.cc +++ b/src/ray/gcs/gcs_server/test/export_api/gcs_node_manager_export_event_test.cc @@ -45,7 +45,7 @@ class GcsNodeManagerExportAPITest : public ::testing::Test { [this](const rpc::Address &) { return raylet_client_; }); gcs_publisher_ = std::make_shared( std::make_unique()); - gcs_table_storage_ = std::make_shared(io_service_); + gcs_table_storage_ = std::make_shared(); RayConfig::instance().initialize( R"( diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc index 9bb274af97bd..9ab7380708f0 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc @@ -143,10 +143,10 @@ class GcsActorManagerTest : public ::testing::Test { /*batch_size=*/100); gcs_publisher_ = std::make_shared(std::move(publisher)); - store_client_ = std::make_shared(io_service_); - gcs_table_storage_ = std::make_shared(io_service_); + store_client_ = std::make_shared(); + gcs_table_storage_ = std::make_shared(); kv_ = std::make_unique(); - function_manager_ = std::make_unique(*kv_); + function_manager_ = std::make_unique(*kv_, io_service_); gcs_actor_manager_ = std::make_unique( mock_actor_scheduler_, gcs_table_storage_, @@ -154,6 +154,7 @@ class GcsActorManagerTest : public ::testing::Test { *runtime_env_mgr_, *function_manager_, [](const ActorID &actor_id) {}, + io_service_, [this](const rpc::Address &addr) { return worker_client_; }); for (int i = 1; i <= 10; i++) { diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc index aca66ca39c09..009b3765dad6 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_mock_test.cc @@ -38,8 +38,8 @@ class GcsActorSchedulerMockTest : public Test { void SetUp() override { store_client = std::make_shared(); actor_table = std::make_unique(store_client); - gcs_node_manager = - std::make_unique(nullptr, nullptr, nullptr, ClusterID::Nil()); + gcs_node_manager = std::make_unique( + nullptr, nullptr, nullptr, io_context, ClusterID::Nil()); raylet_client = std::make_shared(); core_worker_client = std::make_shared(); client_pool = std::make_unique( @@ -146,10 +146,16 @@ TEST_F(GcsActorSchedulerMockTest, KillWorkerLeak2) { EXPECT_CALL(*raylet_client, RequestWorkerLease(An(), _, _, _, _)) .WillOnce(testing::SaveArg<2>(&request_worker_lease_cb)); - std::function async_put_with_index_cb; + // Init a unique_ptr to receive the callback + std::unique_ptr> async_put_with_index_cb; // Leasing successfully EXPECT_CALL(*store_client, AsyncPut(_, _, _, _, _)) - .WillOnce(DoAll(SaveArg<4>(&async_put_with_index_cb), Return(Status::OK()))); + .WillOnce(DoAll(Invoke([&async_put_with_index_cb]( + auto, auto, auto, auto, Postable cb) { + async_put_with_index_cb = + std::make_unique>(std::move(cb)); + }), + Return(Status::OK()))); actor_scheduler->ScheduleByRaylet(actor); rpc::RequestWorkerLeaseReply reply; reply.mutable_worker_address()->set_raylet_id(node_id.Binary()); @@ -160,7 +166,8 @@ TEST_F(GcsActorSchedulerMockTest, KillWorkerLeak2) { // Worker start to run task EXPECT_CALL(*core_worker_client, PushNormalTask(_, _)) .WillOnce(testing::SaveArg<1>(&push_normal_task_cb)); - async_put_with_index_cb(true); + std::move(*async_put_with_index_cb) + .Post("GcsActorSchedulerMockTest::KillWorkerLeak2", true); actor->GetMutableActorTableData()->set_state(rpc::ActorTableData::DEAD); actor_scheduler->CancelOnWorker(node_id, worker_id); push_normal_task_cb(Status::OK(), rpc::PushTaskReply()); diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc index 6302ee02ed63..f24f54670e55 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc @@ -37,10 +37,13 @@ class GcsActorSchedulerTest : public ::testing::Test { worker_client_ = std::make_shared(); gcs_publisher_ = std::make_shared( std::make_unique()); - store_client_ = std::make_shared(io_service_); - gcs_table_storage_ = std::make_shared(io_service_); - gcs_node_manager_ = std::make_shared( - gcs_publisher_, gcs_table_storage_, raylet_client_pool_.get(), ClusterID::Nil()); + store_client_ = std::make_shared(); + gcs_table_storage_ = std::make_shared(); + gcs_node_manager_ = std::make_shared(gcs_publisher_, + gcs_table_storage_, + raylet_client_pool_.get(), + io_service_, + ClusterID::Nil()); gcs_actor_table_ = std::make_shared(store_client_); local_node_id_ = NodeID::FromRandom(); diff --git a/src/ray/gcs/gcs_server/test/gcs_autoscaler_state_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_autoscaler_state_manager_test.cc index 2f281fa31844..5ea3b58e12af 100644 --- a/src/ray/gcs/gcs_server/test/gcs_autoscaler_state_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_autoscaler_state_manager_test.cc @@ -64,15 +64,17 @@ class GcsAutoscalerStateManagerTest : public ::testing::Test { client_pool_ = std::make_unique( [this](const rpc::Address &) { return raylet_client_; }); cluster_resource_manager_ = std::make_unique(io_service_); - gcs_node_manager_ = std::make_shared(); + gcs_node_manager_ = std::make_shared(io_service_); kv_manager_ = std::make_unique( std::make_unique(std::make_unique()), - kRayletConfig); - function_manager_ = std::make_unique(kv_manager_->GetInstance()); + kRayletConfig, + io_service_); + function_manager_ = + std::make_unique(kv_manager_->GetInstance(), io_service_); runtime_env_manager_ = std::make_unique( [](const std::string &, std::function) {}); - gcs_actor_manager_ = - std::make_unique(*runtime_env_manager_, *function_manager_); + gcs_actor_manager_ = std::make_unique( + *runtime_env_manager_, *function_manager_, io_service_); gcs_resource_manager_ = std::make_shared(io_service_, *cluster_resource_manager_, diff --git a/src/ray/gcs/gcs_server/test/gcs_function_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_function_manager_test.cc index 775720c398a3..50e7ab6ea019 100644 --- a/src/ray/gcs/gcs_server/test/gcs_function_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_function_manager_test.cc @@ -18,6 +18,8 @@ #include "ray/gcs/gcs_server/gcs_function_manager.h" #include "ray/gcs/gcs_server/gcs_kv_manager.h" #include "mock/ray/gcs/gcs_server/gcs_kv_manager.h" +#include "ray/common/asio/asio_util.h" + // clang-format on using namespace ::testing; using namespace ray::gcs; @@ -27,10 +29,12 @@ class GcsFunctionManagerTest : public Test { public: void SetUp() override { kv = std::make_unique(); - function_manager = std::make_unique(*kv); + function_manager = + std::make_unique(*kv, io_context_.GetIoService()); } std::unique_ptr function_manager; std::unique_ptr kv; + InstrumentedIOContextWithThread io_context_{"GcsFunctionManagerTest"}; }; TEST_F(GcsFunctionManagerTest, TestFunctionManagerGC) { diff --git a/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc index b18658dffc95..c5e9d4c13648 100644 --- a/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc @@ -32,12 +32,6 @@ namespace ray { -class MockInMemoryStoreClient : public gcs::InMemoryStoreClient { - public: - explicit MockInMemoryStoreClient(instrumented_io_context &main_io_service) - : gcs::InMemoryStoreClient(main_io_service) {} -}; - class GcsJobManagerTest : public ::testing::Test { public: GcsJobManagerTest() : runtime_env_manager_(nullptr) { @@ -52,11 +46,11 @@ class GcsJobManagerTest : public ::testing::Test { gcs_publisher_ = std::make_shared( std::make_unique()); - store_client_ = std::make_shared(io_service_); + store_client_ = std::make_shared(); gcs_table_storage_ = std::make_shared(store_client_); kv_ = std::make_unique(); fake_kv_ = std::make_unique(); - function_manager_ = std::make_unique(*kv_); + function_manager_ = std::make_unique(*kv_, io_service_); // Mock client factory which abuses the "address" argument to return a // CoreWorkerClient whose number of running tasks equal to the address port. This is @@ -87,18 +81,23 @@ class GcsJobManagerTest : public ::testing::Test { }; TEST_F(GcsJobManagerTest, TestFakeInternalKV) { - fake_kv_->Put("ns", "key", "value", /*overwrite=*/true, /*callback=*/[](auto) {}); + fake_kv_->Put("ns", "key", "value", /*overwrite=*/true, {[](auto) {}, io_service_}); + fake_kv_->Get( - "ns", "key", [](std::optional v) { ASSERT_EQ(v.value(), "value"); }); - fake_kv_->Put("ns", "key2", "value2", /*overwrite=*/true, /*callback=*/[](auto) {}); + "ns", + "key", + {[](std::optional v) { ASSERT_EQ(v.value(), "value"); }, io_service_}); + + fake_kv_->Put("ns", "key2", "value2", /*overwrite=*/true, {[](auto) {}, io_service_}); fake_kv_->MultiGet("ns", {"key", "key2"}, - [](const std::unordered_map &result) { - ASSERT_EQ(result.size(), 2); - ASSERT_EQ(result.at("key"), "value"); - ASSERT_EQ(result.at("key2"), "value2"); - }); + {[](const std::unordered_map &result) { + ASSERT_EQ(result.size(), 2); + ASSERT_EQ(result.at("key"), "value"); + ASSERT_EQ(result.at("key2"), "value2"); + }, + io_service_}); } TEST_F(GcsJobManagerTest, TestIsRunningTasks) { @@ -107,10 +106,11 @@ TEST_F(GcsJobManagerTest, TestIsRunningTasks) { runtime_env_manager_, *function_manager_, *fake_kv_, + io_service_, client_factory_); gcs::GcsInitData gcs_init_data(gcs_table_storage_); - gcs_job_manager.Initialize(/*init_data=*/gcs_init_data); + gcs_job_manager.Initialize(gcs_init_data); // Add 100 jobs. Job i should have i running tasks. int num_jobs = 100; @@ -171,10 +171,11 @@ TEST_F(GcsJobManagerTest, TestGetAllJobInfo) { runtime_env_manager_, *function_manager_, *fake_kv_, + io_service_, client_factory_); gcs::GcsInitData gcs_init_data(gcs_table_storage_); - gcs_job_manager.Initialize(/*init_data=*/gcs_init_data); + gcs_job_manager.Initialize(gcs_init_data); // Add 100 jobs. for (int i = 0; i < 100; ++i) { @@ -243,7 +244,7 @@ TEST_F(GcsJobManagerTest, TestGetAllJobInfo) { gcs::JobDataKey(submission_id), job_info_json, /*overwrite=*/true, - [&kv_promise](auto) { kv_promise.set_value(true); }); + {[&kv_promise](auto) { kv_promise.set_value(true); }, io_service_}); kv_promise.get_future().get(); // Get all job info again. @@ -294,7 +295,7 @@ TEST_F(GcsJobManagerTest, TestGetAllJobInfo) { gcs::JobDataKey(submission_id), job_info_json, /*overwrite=*/true, - [&kv_promise2](auto) { kv_promise2.set_value(true); }); + {[&kv_promise2](auto) { kv_promise2.set_value(true); }, io_service_}); kv_promise2.get_future().get(); // Get all job info again. @@ -348,12 +349,13 @@ TEST_F(GcsJobManagerTest, TestGetAllJobInfoWithFilter) { runtime_env_manager_, *function_manager_, *fake_kv_, + io_service_, client_factory_); auto job_id1 = JobID::FromInt(1); auto job_id2 = JobID::FromInt(2); gcs::GcsInitData gcs_init_data(gcs_table_storage_); - gcs_job_manager.Initialize(/*init_data=*/gcs_init_data); + gcs_job_manager.Initialize(gcs_init_data); rpc::AddJobReply empty_reply; std::promise promise1; @@ -433,6 +435,7 @@ TEST_F(GcsJobManagerTest, TestGetAllJobInfoWithLimit) { runtime_env_manager_, *function_manager_, *fake_kv_, + io_service_, client_factory_); auto job_id1 = JobID::FromInt(1); @@ -536,12 +539,13 @@ TEST_F(GcsJobManagerTest, TestGetJobConfig) { runtime_env_manager_, *function_manager_, *kv_, + io_service_, client_factory_); auto job_id1 = JobID::FromInt(1); auto job_id2 = JobID::FromInt(2); gcs::GcsInitData gcs_init_data(gcs_table_storage_); - gcs_job_manager.Initialize(/*init_data=*/gcs_init_data); + gcs_job_manager.Initialize(gcs_init_data); rpc::AddJobReply empty_reply; std::promise promise1; @@ -578,11 +582,12 @@ TEST_F(GcsJobManagerTest, TestPreserveDriverInfo) { runtime_env_manager_, *function_manager_, *fake_kv_, + io_service_, client_factory_); auto job_id = JobID::FromInt(1); gcs::GcsInitData gcs_init_data(gcs_table_storage_); - gcs_job_manager.Initialize(/*init_data=*/gcs_init_data); + gcs_job_manager.Initialize(gcs_init_data); auto add_job_request = Mocker::GenAddJobRequest(job_id, "namespace"); rpc::Address address; @@ -645,6 +650,7 @@ TEST_F(GcsJobManagerTest, TestNodeFailure) { runtime_env_manager_, *function_manager_, *fake_kv_, + io_service_, client_factory_); auto job_id1 = JobID::FromInt(1); diff --git a/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc index 9c09658e16c6..6b97502c3dd0 100644 --- a/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_kv_manager_test.cc @@ -40,7 +40,7 @@ class GcsKVManagerTest : public ::testing::TestWithParam { std::make_unique(client)); } else if (GetParam() == "memory") { kv_instance = std::make_unique( - std::make_unique(io_service)); + std::make_unique()); } } @@ -58,62 +58,82 @@ class GcsKVManagerTest : public ::testing::TestWithParam { }; TEST_P(GcsKVManagerTest, TestInternalKV) { - kv_instance->Get("N1", "A", [](auto b) { ASSERT_FALSE(b.has_value()); }); - kv_instance->Put("N1", "A", "B", false, [](auto b) { ASSERT_TRUE(b); }); - kv_instance->Put("N1", "A", "C", false, [](auto b) { ASSERT_FALSE(b); }); - kv_instance->Get("N1", "A", [](auto b) { ASSERT_EQ("B", *b); }); - kv_instance->Put("N1", "A", "C", true, [](auto b) { ASSERT_FALSE(b); }); - kv_instance->Get("N1", "A", [](auto b) { ASSERT_EQ("C", *b); }); - kv_instance->Put("N1", "A_1", "B", false, [](auto b) { ASSERT_TRUE(b); }); - kv_instance->Put("N1", "A_2", "C", false, [](auto b) { ASSERT_TRUE(b); }); - kv_instance->Put("N1", "A_3", "C", false, [](auto b) { ASSERT_TRUE(b); }); - kv_instance->Keys("N1", "A_", [](std::vector keys) { - auto expected = std::set{"A_1", "A_2", "A_3"}; - ASSERT_EQ(expected, std::set(keys.begin(), keys.end())); - }); - kv_instance->Get("N2", "A_1", [](auto b) { ASSERT_FALSE(b.has_value()); }); - kv_instance->Get("N1", "A_1", [](auto b) { ASSERT_TRUE(b.has_value()); }); - kv_instance->MultiGet("N1", {"A_1", "A_2", "A_3"}, [](auto b) { - ASSERT_EQ(3, b.size()); - ASSERT_EQ("B", b["A_1"]); - ASSERT_EQ("C", b["A_2"]); - ASSERT_EQ("C", b["A_3"]); - }); + kv_instance->Get("N1", "A", {[](auto b) { ASSERT_FALSE(b.has_value()); }, io_service}); + kv_instance->Put("N1", "A", "B", false, {[](auto b) { ASSERT_TRUE(b); }, io_service}); + kv_instance->Put("N1", "A", "C", false, {[](auto b) { ASSERT_FALSE(b); }, io_service}); + kv_instance->Get("N1", "A", {[](auto b) { ASSERT_EQ("B", *b); }, io_service}); + kv_instance->Put("N1", "A", "C", true, {[](auto b) { ASSERT_FALSE(b); }, io_service}); + kv_instance->Get("N1", "A", {[](auto b) { ASSERT_EQ("C", *b); }, io_service}); + kv_instance->Put("N1", "A_1", "B", false, {[](auto b) { ASSERT_TRUE(b); }, io_service}); + kv_instance->Put("N1", "A_2", "C", false, {[](auto b) { ASSERT_TRUE(b); }, io_service}); + kv_instance->Put("N1", "A_3", "C", false, {[](auto b) { ASSERT_TRUE(b); }, io_service}); + kv_instance->Keys("N1", + "A_", + {[](std::vector keys) { + auto expected = std::set{"A_1", "A_2", "A_3"}; + ASSERT_EQ(expected, + std::set(keys.begin(), keys.end())); + }, + io_service}); + kv_instance->Get( + "N2", "A_1", {[](auto b) { ASSERT_FALSE(b.has_value()); }, io_service}); + kv_instance->Get("N1", "A_1", {[](auto b) { ASSERT_TRUE(b.has_value()); }, io_service}); + kv_instance->MultiGet("N1", + {"A_1", "A_2", "A_3"}, + {[](auto b) { + ASSERT_EQ(3, b.size()); + ASSERT_EQ("B", b["A_1"]); + ASSERT_EQ("C", b["A_2"]); + ASSERT_EQ("C", b["A_3"]); + }, + io_service}); // MultiGet with empty keys. - kv_instance->MultiGet("N1", {}, [](auto b) { ASSERT_EQ(0, b.size()); }); + kv_instance->MultiGet("N1", {}, {[](auto b) { ASSERT_EQ(0, b.size()); }, io_service}); // MultiGet with non-existent keys. - kv_instance->MultiGet("N1", {"A_4", "A_5"}, [](auto b) { ASSERT_EQ(0, b.size()); }); + kv_instance->MultiGet( + "N1", {"A_4", "A_5"}, {[](auto b) { ASSERT_EQ(0, b.size()); }, io_service}); { // Delete by prefix are two steps in redis mode, so we need sync here. std::promise p; - kv_instance->Del("N1", "A_", true, [&p](auto b) { - ASSERT_EQ(3, b); - p.set_value(); - }); + kv_instance->Del("N1", + "A_", + true, + {[&p](auto b) { + ASSERT_EQ(3, b); + p.set_value(); + }, + io_service}); p.get_future().get(); } { // Delete by prefix are two steps in redis mode, so we need sync here. std::promise p; - kv_instance->Del("NX", "A_", true, [&p](auto b) { - ASSERT_EQ(0, b); - p.set_value(); - }); + kv_instance->Del("NX", + "A_", + true, + {[&p](auto b) { + ASSERT_EQ(0, b); + p.set_value(); + }, + io_service}); p.get_future().get(); } { // Make sure the last cb is called. std::promise p; - kv_instance->Get("N1", "A_1", [&p](auto b) { - ASSERT_FALSE(b.has_value()); - p.set_value(); - }); + kv_instance->Get("N1", + "A_1", + {[&p](auto b) { + ASSERT_FALSE(b.has_value()); + p.set_value(); + }, + io_service}); p.get_future().get(); } // Check the keys are deleted. kv_instance->MultiGet( - "N1", {"A_1", "A_2", "A_3"}, [](auto b) { ASSERT_EQ(0, b.size()); }); + "N1", {"A_1", "A_2", "A_3"}, {[](auto b) { ASSERT_EQ(0, b.size()); }, io_service}); } INSTANTIATE_TEST_SUITE_P(GcsKVManagerTestFixture, diff --git a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc index eb12d59dbdb3..fc894e35ff6c 100644 --- a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc @@ -21,6 +21,7 @@ #include "ray/rpc/node_manager/node_manager_client.h" #include "ray/rpc/node_manager/node_manager_client_pool.h" #include "mock/ray/pubsub/publisher.h" +#include "ray/common/asio/asio_util.h" // clang-format on namespace ray { @@ -39,11 +40,15 @@ class GcsNodeManagerTest : public ::testing::Test { std::shared_ptr raylet_client_; std::unique_ptr client_pool_; std::shared_ptr gcs_publisher_; + InstrumentedIOContextWithThread io_context_{"GcsNodeManagerTest"}; }; TEST_F(GcsNodeManagerTest, TestManagement) { - gcs::GcsNodeManager node_manager( - gcs_publisher_, gcs_table_storage_, client_pool_.get(), ClusterID::Nil()); + gcs::GcsNodeManager node_manager(gcs_publisher_, + gcs_table_storage_, + client_pool_.get(), + io_context_.GetIoService(), + ClusterID::Nil()); // Test Add/Get/Remove functionality. auto node = Mocker::GenNodeInfo(); auto node_id = NodeID::FromBinary(node->node_id()); @@ -57,8 +62,11 @@ TEST_F(GcsNodeManagerTest, TestManagement) { } TEST_F(GcsNodeManagerTest, TestListener) { - gcs::GcsNodeManager node_manager( - gcs_publisher_, gcs_table_storage_, client_pool_.get(), ClusterID::Nil()); + gcs::GcsNodeManager node_manager(gcs_publisher_, + gcs_table_storage_, + client_pool_.get(), + io_context_.GetIoService(), + ClusterID::Nil()); // Test AddNodeAddedListener. int node_count = 1000; std::vector> added_nodes; diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc index 1e3ef61060c8..15370474f71c 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_mock_test.cc @@ -41,7 +41,7 @@ class GcsPlacementGroupManagerMockTest : public Test { gcs_table_storage_ = std::make_shared(store_client_); gcs_placement_group_scheduler_ = std::make_shared(); - node_manager_ = std::make_unique(); + node_manager_ = std::make_unique(io_context_); resource_manager_ = std::make_shared( io_context_, cluster_resource_manager_, *node_manager_, NodeID::FromRandom()); @@ -73,9 +73,12 @@ TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityReschedule) { auto pg = std::make_shared(req, "", counter_); auto cb = [](Status s) {}; SchedulePgRequest request; - std::function put_cb; + std::unique_ptr> put_cb; EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _, _)) - .WillOnce(DoAll(SaveArg<4>(&put_cb), Return(Status::OK()))); + .WillOnce(DoAll(Invoke([&put_cb](auto, auto, auto, auto, Postable cb) { + put_cb = std::make_unique>(std::move(cb)); + }), + Return(Status::OK()))); EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_)) .WillOnce(DoAll(SaveArg<0>(&request))); auto now = absl::GetCurrentTimeNanos(); @@ -84,7 +87,7 @@ TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityReschedule) { ASSERT_EQ(1, pending_queue.size()); ASSERT_LE(now, pending_queue.begin()->first); ASSERT_GE(absl::GetCurrentTimeNanos(), pending_queue.begin()->first); - put_cb(true); + std::move(*put_cb).Post("GcsPlacementGroupManagerMockTest", true); pg->UpdateState(rpc::PlacementGroupTableData::RESCHEDULING); request.failure_callback(pg, true); ASSERT_EQ(1, pending_queue.size()); @@ -99,9 +102,14 @@ TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityFailed) { auto pg = std::make_shared(req, "", counter_); auto cb = [](Status s) {}; SchedulePgRequest request; - std::function put_cb; + std::unique_ptr> put_cb; EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _, _)) - .WillOnce(DoAll(SaveArg<4>(&put_cb), Return(Status::OK()))); + .Times(2) + .WillRepeatedly( + DoAll(Invoke([&put_cb](auto, auto, auto, auto, Postable cb) { + put_cb = std::make_unique>(std::move(cb)); + }), + Return(Status::OK()))); EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_)) .Times(2) .WillRepeatedly(DoAll(SaveArg<0>(&request))); @@ -111,7 +119,7 @@ TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityFailed) { ASSERT_EQ(1, pending_queue.size()); ASSERT_LE(now, pending_queue.begin()->first); ASSERT_GE(absl::GetCurrentTimeNanos(), pending_queue.begin()->first); - put_cb(true); + std::move(*put_cb).Post("GcsPlacementGroupManagerMockTest", true); pg->UpdateState(rpc::PlacementGroupTableData::PENDING); now = absl::GetCurrentTimeNanos(); request.failure_callback(pg, true); @@ -155,10 +163,14 @@ TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityOrder) { auto pg2 = std::make_shared(req2, "", counter_); auto cb = [](Status s) {}; SchedulePgRequest request; - std::function put_cb; + std::unique_ptr> put_cb; EXPECT_CALL(*store_client_, AsyncPut(_, _, _, _, _)) .Times(2) - .WillRepeatedly(DoAll(SaveArg<4>(&put_cb), Return(Status::OK()))); + .WillRepeatedly( + DoAll(Invoke([&put_cb](auto, auto, auto, auto, Postable cb) { + put_cb = std::make_unique>(std::move(cb)); + }), + Return(Status::OK()))); EXPECT_CALL(*gcs_placement_group_scheduler_, ScheduleUnplacedBundles(_)) .Times(2) .WillRepeatedly(DoAll(SaveArg<0>(&request))); @@ -166,7 +178,7 @@ TEST_F(GcsPlacementGroupManagerMockTest, PendingQueuePriorityOrder) { gcs_placement_group_manager_->RegisterPlacementGroup(pg2, cb); auto &pending_queue = gcs_placement_group_manager_->pending_placement_groups_; ASSERT_EQ(2, pending_queue.size()); - put_cb(true); + std::move(*put_cb).Post("GcsPlacementGroupManagerMockTest", true); ASSERT_EQ(1, pending_queue.size()); // PG1 is scheduled first, so PG2 is in pending queue ASSERT_EQ(pg2, pending_queue.begin()->second.second); diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc index ad808b644b67..6d3e211b5ce9 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc @@ -83,8 +83,8 @@ class GcsPlacementGroupManagerTest : public ::testing::Test { cluster_resource_manager_(io_service_) { gcs_publisher_ = std::make_shared(std::make_unique()); - gcs_table_storage_ = std::make_shared(io_service_); - gcs_node_manager_ = std::make_shared(); + gcs_table_storage_ = std::make_shared(); + gcs_node_manager_ = std::make_shared(io_service_); gcs_resource_manager_ = std::make_shared( io_service_, cluster_resource_manager_, *gcs_node_manager_, NodeID::FromRandom()); gcs_placement_group_manager_.reset(new gcs::GcsPlacementGroupManager( @@ -150,7 +150,7 @@ class GcsPlacementGroupManagerTest : public ::testing::Test { std::shared_ptr LoadDataFromDataStorage() { auto gcs_init_data = std::make_shared(gcs_table_storage_); std::promise promise; - gcs_init_data->AsyncLoad([&promise] { promise.set_value(); }); + gcs_init_data->AsyncLoad([&promise] { promise.set_value(); }, io_service_); RunIOService(); promise.get_future().get(); return gcs_init_data; @@ -167,9 +167,9 @@ class GcsPlacementGroupManagerTest : public ::testing::Test { protected: std::shared_ptr gcs_table_storage_; + instrumented_io_context io_service_; private: - instrumented_io_context io_service_; ClusterResourceManager cluster_resource_manager_; std::shared_ptr gcs_node_manager_; std::shared_ptr gcs_resource_manager_; @@ -554,7 +554,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestSchedulerReinitializeAfterGcsRestart) { /* cpu_num */ 1.0, /* job_id */ job_id); auto job_table_data = Mocker::GenJobTableData(job_id); - RAY_CHECK_OK(gcs_table_storage_->JobTable().Put(job_id, *job_table_data, nullptr)); + RAY_CHECK_OK(gcs_table_storage_->JobTable().Put( + job_id, *job_table_data, {[](auto) {}, io_service_})); std::atomic registered_placement_group_count(0); RegisterPlacementGroup(request, [®istered_placement_group_count](Status status) { ++registered_placement_group_count; @@ -981,7 +982,8 @@ TEST_F(GcsPlacementGroupManagerTest, TestCheckCreatorJobIsDeadWhenGcsRestart) { /* job_id */ job_id); auto job_table_data = Mocker::GenJobTableData(job_id); job_table_data->set_is_dead(true); - RAY_CHECK_OK(gcs_table_storage_->JobTable().Put(job_id, *job_table_data, nullptr)); + RAY_CHECK_OK(gcs_table_storage_->JobTable().Put( + job_id, *job_table_data, {[](auto) {}, io_service_})); std::atomic registered_placement_group_count(0); RegisterPlacementGroup(request, [®istered_placement_group_count](Status status) { diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc index 093bdaf13fcc..3b5e5c027910 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc @@ -43,7 +43,7 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { for (int index = 0; index < 3; ++index) { raylet_clients_.push_back(std::make_shared()); } - gcs_table_storage_ = std::make_shared(io_service_); + gcs_table_storage_ = std::make_shared(); gcs_publisher_ = std::make_shared( std::make_unique()); auto local_node_id = NodeID::FromRandom(); @@ -54,14 +54,17 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { /*is_node_available_fn=*/ [](auto) { return true; }, /*is_local_node_with_raylet=*/false); - gcs_node_manager_ = std::make_shared( - gcs_publisher_, gcs_table_storage_, raylet_client_pool_.get(), ClusterID::Nil()); + gcs_node_manager_ = std::make_shared(gcs_publisher_, + gcs_table_storage_, + raylet_client_pool_.get(), + io_service_, + ClusterID::Nil()); gcs_resource_manager_ = std::make_shared( io_service_, cluster_resource_scheduler_->GetClusterResourceManager(), *gcs_node_manager_, local_node_id); - store_client_ = std::make_shared(io_service_); + store_client_ = std::make_shared(); raylet_client_pool_ = std::make_unique( [this](const rpc::Address &addr) { return raylet_clients_[addr.port()]; }); scheduler_ = std::make_shared( diff --git a/src/ray/gcs/gcs_server/test/gcs_resource_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_resource_manager_test.cc index cc2d3dec33a8..5757fee270e2 100644 --- a/src/ray/gcs/gcs_server/test/gcs_resource_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_resource_manager_test.cc @@ -30,7 +30,7 @@ class GcsResourceManagerTest : public ::testing::Test { public: GcsResourceManagerTest() : cluster_resource_manager_(io_service_), - gcs_node_manager_(std::make_unique()) { + gcs_node_manager_(std::make_unique(io_service_)) { gcs_resource_manager_ = std::make_shared( io_service_, cluster_resource_manager_, *gcs_node_manager_, NodeID::FromRandom()); } diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h index ce8c685f706d..2d8294bd9f06 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h +++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h @@ -309,7 +309,7 @@ struct GcsServerMocker { void ShutdownRaylet( const NodeID &node_id, bool graceful, - const rpc::ClientCallback &callback) override{}; + const rpc::ClientCallback &callback) override {}; void DrainRaylet( const rpc::autoscaler::DrainNodeReason &reason, @@ -322,7 +322,7 @@ struct GcsServerMocker { }; void NotifyGCSRestart( - const rpc::ClientCallback &callback) override{}; + const rpc::ClientCallback &callback) override {}; ~MockRayletClient() {} @@ -404,16 +404,16 @@ struct GcsServerMocker { Status Put(const ActorID &key, const rpc::ActorTableData &value, - const gcs::StatusCallback &callback) override { + Postable callback) override { auto status = Status::OK(); - callback(status); + std::move(callback).Post("MockedGcsActorTable::Put", status); return status; } private: instrumented_io_context main_io_service_; std::shared_ptr store_client_ = - std::make_shared(main_io_service_); + std::make_shared(); }; class MockedNodeInfoAccessor : public gcs::NodeInfoAccessor { diff --git a/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h b/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h index ad6e860c1c53..e7f4da800905 100644 --- a/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h +++ b/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h @@ -102,7 +102,7 @@ class GcsTableStorageTestBase : public ::testing::Test { void Put(TABLE &table, const KEY &key, const VALUE &value) { auto on_done = [this](const Status &status) { --pending_count_; }; ++pending_count_; - RAY_CHECK_OK(table.Put(key, value, on_done)); + RAY_CHECK_OK(table.Put(key, value, {on_done, *io_service_pool_->Get()})); WaitPendingDone(); } @@ -121,7 +121,7 @@ class GcsTableStorageTestBase : public ::testing::Test { --pending_count_; }; ++pending_count_; - RAY_CHECK_OK(table.Get(key, on_done)); + RAY_CHECK_OK(table.Get(key, {on_done, *io_service_pool_->Get()})); WaitPendingDone(); return values.size(); } @@ -144,7 +144,7 @@ class GcsTableStorageTestBase : public ::testing::Test { --pending_count_; }; ++pending_count_; - RAY_CHECK_OK(table.GetByJobId(job_id, on_done)); + RAY_CHECK_OK(table.GetByJobId(job_id, {on_done, *io_service_pool_->Get()})); WaitPendingDone(); return values.size(); } @@ -156,7 +156,7 @@ class GcsTableStorageTestBase : public ::testing::Test { --pending_count_; }; ++pending_count_; - RAY_CHECK_OK(table.Delete(key, on_done)); + RAY_CHECK_OK(table.Delete(key, {on_done, *io_service_pool_->Get()})); WaitPendingDone(); } @@ -167,7 +167,7 @@ class GcsTableStorageTestBase : public ::testing::Test { --pending_count_; }; ++pending_count_; - RAY_CHECK_OK(table.BatchDelete(keys, on_done)); + RAY_CHECK_OK(table.BatchDelete(keys, {on_done, *io_service_pool_->Get()})); WaitPendingDone(); } diff --git a/src/ray/gcs/gcs_server/test/gcs_worker_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_worker_manager_test.cc index 37d6a67b7b0d..bb0d39112247 100644 --- a/src/ray/gcs/gcs_server/test/gcs_worker_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_worker_manager_test.cc @@ -33,7 +33,7 @@ class GcsWorkerManagerTest : public Test { GcsWorkerManagerTest() { gcs_publisher_ = std::make_shared(std::make_unique()); - gcs_table_storage_ = std::make_shared(io_service_); + gcs_table_storage_ = std::make_shared(); } void SetUp() override { @@ -45,8 +45,8 @@ class GcsWorkerManagerTest : public Test { new boost::asio::io_service::work(io_service_)); io_service_.run(); }); - worker_manager_ = - std::make_shared(gcs_table_storage_, gcs_publisher_); + worker_manager_ = std::make_shared( + gcs_table_storage_, gcs_publisher_, io_service_); } void TearDown() override { diff --git a/src/ray/gcs/gcs_server/test/in_memory_gcs_table_storage_test.cc b/src/ray/gcs/gcs_server/test/in_memory_gcs_table_storage_test.cc index dba6ddce5922..277657f225ca 100644 --- a/src/ray/gcs/gcs_server/test/in_memory_gcs_table_storage_test.cc +++ b/src/ray/gcs/gcs_server/test/in_memory_gcs_table_storage_test.cc @@ -23,8 +23,7 @@ namespace ray { class InMemoryGcsTableStorageTest : public gcs::GcsTableStorageTestBase { public: void SetUp() override { - gcs_table_storage_ = - std::make_shared(*(io_service_pool_->Get())); + gcs_table_storage_ = std::make_shared(); } }; diff --git a/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc b/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc index c6858022b3e3..2b4ec5b4b744 100644 --- a/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc +++ b/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc @@ -31,7 +31,8 @@ class RedisGcsTableStorageTest : public gcs::GcsTableStorageTestBase { redis_client_ = std::make_shared(options); RAY_CHECK_OK(redis_client_->Connect(*io_service_pool_->Get())); - gcs_table_storage_ = std::make_shared(redis_client_); + gcs_table_storage_ = std::make_shared( + redis_client_, *io_service_pool_->Get()); } void TearDown() override { redis_client_->Disconnect(); } diff --git a/src/ray/gcs/gcs_server/test/usage_stats_client_test.cc b/src/ray/gcs/gcs_server/test/usage_stats_client_test.cc index 9448f0000b9f..4c765b274526 100644 --- a/src/ray/gcs/gcs_server/test/usage_stats_client_test.cc +++ b/src/ray/gcs/gcs_server/test/usage_stats_client_test.cc @@ -27,21 +27,27 @@ class UsageStatsClientTest : public ::testing::Test { void SetUp() override { fake_kv_ = std::make_unique(); } void TearDown() override { fake_kv_.reset(); } std::unique_ptr fake_kv_; + InstrumentedIOContextWithThread io_context_{"UsageStatsClientTest"}; + instrumented_io_context &io_service_ = io_context_.GetIoService(); }; TEST_F(UsageStatsClientTest, TestRecordExtraUsageTag) { - gcs::UsageStatsClient usage_stats_client(*fake_kv_); + gcs::UsageStatsClient usage_stats_client(*fake_kv_, io_service_); usage_stats_client.RecordExtraUsageTag(usage::TagKey::_TEST1, "value1"); - fake_kv_->Get( - "usage_stats", "extra_usage_tag__test1", [](std::optional value) { - ASSERT_TRUE(value.has_value()); - ASSERT_EQ(value.value(), "value1"); - }); + fake_kv_->Get("usage_stats", + "extra_usage_tag__test1", + {[](std::optional value) { + ASSERT_TRUE(value.has_value()); + ASSERT_EQ(value.value(), "value1"); + }, + io_service_}); // Make sure the value is overriden for the same key. usage_stats_client.RecordExtraUsageTag(usage::TagKey::_TEST2, "value2"); - fake_kv_->Get( - "usage_stats", "extra_usage_tag__test2", [](std::optional value) { - ASSERT_TRUE(value.has_value()); - ASSERT_EQ(value.value(), "value2"); - }); + fake_kv_->Get("usage_stats", + "extra_usage_tag__test2", + {[](std::optional value) { + ASSERT_TRUE(value.has_value()); + ASSERT_EQ(value.value(), "value2"); + }, + io_service_}); } diff --git a/src/ray/gcs/store_client/in_memory_store_client.cc b/src/ray/gcs/store_client/in_memory_store_client.cc index 2138760d7e2d..06a6ef5b31c6 100644 --- a/src/ray/gcs/store_client/in_memory_store_client.cc +++ b/src/ray/gcs/store_client/in_memory_store_client.cc @@ -33,7 +33,7 @@ Status InMemoryStoreClient::AsyncPut(const std::string &table_name, table->records_[key] = data; inserted = true; } - callback.Post("GcsInMemoryStore.Put", inserted); + std::move(callback).Post("GcsInMemoryStore.Put", inserted); return Status::OK(); } @@ -48,7 +48,7 @@ Status InMemoryStoreClient::AsyncGet( if (iter != table->records_.end()) { data = iter->second; } - callback.Post("GcsInMemoryStore.Get", Status::OK(), std::move(data)); + std::move(callback).Post("GcsInMemoryStore.Get", Status::OK(), std::move(data)); return Status::OK(); } @@ -57,9 +57,10 @@ Status InMemoryStoreClient::AsyncGetAll( Postable)> callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); + absl::flat_hash_map result; result.reserve(table->records_.size()); result.insert(table->records_.begin(), table->records_.end()); - callback.Post("GcsInMemoryStore.GetAll", std::move(result)); + std::move(callback).Post("GcsInMemoryStore.GetAll", std::move(result)); return Status::OK(); } @@ -69,6 +70,8 @@ Status InMemoryStoreClient::AsyncMultiGet( Postable)> callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); + absl::flat_hash_map result; + result.reserve(keys.size()); for (const auto &key : keys) { auto it = table->records_.find(key); if (it == table->records_.end()) { @@ -76,7 +79,7 @@ Status InMemoryStoreClient::AsyncMultiGet( } result[key] = it->second; } - callback.Post("GcsInMemoryStore.GetAll", std::move(result)); + std::move(callback).Post("GcsInMemoryStore.GetAll", std::move(result)); return Status::OK(); } @@ -86,7 +89,7 @@ Status InMemoryStoreClient::AsyncDelete(const std::string &table_name, auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); auto num = table->records_.erase(key); - callback.Post("GcsInMemoryStore.Delete", num > 0); + std::move(callback).Post("GcsInMemoryStore.Delete", num > 0); return Status::OK(); } @@ -99,7 +102,7 @@ Status InMemoryStoreClient::AsyncBatchDelete(const std::string &table_name, for (auto &key : keys) { num += table->records_.erase(key); } - callback.Post("GcsInMemoryStore.BatchDelete", num); + std::move(callback).Post("GcsInMemoryStore.BatchDelete", num); return Status::OK(); } @@ -127,12 +130,14 @@ Status InMemoryStoreClient::AsyncGetKeys( Postable)> callback) { auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); + std::vector result; + result.reserve(table->records_.size()); for (const auto &[key, _] : table->records_) { if (key.find(prefix) == 0) { result.emplace_back(key); } } - callback.Post("GcsInMemoryStore.Keys", std::move(result)); + std::move(callback).Post("GcsInMemoryStore.Keys", std::move(result)); return Status::OK(); } @@ -142,7 +147,7 @@ Status InMemoryStoreClient::AsyncExists(const std::string &table_name, auto table = GetOrCreateTable(table_name); absl::MutexLock lock(&(table->mutex_)); bool result = table->records_.contains(key); - callback.Post("GcsInMemoryStore.Exists", result); + std::move(callback).Post("GcsInMemoryStore.Exists", result); return Status::OK(); } diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index e2961c43540d..503f39d936ab 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -81,6 +81,11 @@ void RedisStoreClient::MGetValues( auto total_count = batched_commands.size(); auto finished_count = std::make_shared(0); auto key_value_map = std::make_shared>(); + // `Postable` can only be invoked once, but here we have several Redis callbacks, the + // last of which will trigger the `callback`. So we need to use a shared `Postable`. + auto shared_callback = + std::make_shared)>>( + std::move(callback)); for (auto &command : batched_commands) { auto mget_callback = [finished_count, @@ -88,7 +93,7 @@ void RedisStoreClient::MGetValues( // Copies! args = command.args, // Copies! - callback, + shared_callback, key_value_map](const std::shared_ptr &reply) { if (!reply->IsNil()) { auto value = reply->ReadAsStringArray(); @@ -101,12 +106,12 @@ void RedisStoreClient::MGetValues( ++(*finished_count); if (*finished_count == total_count) { - callback.Dispatch("RedisStoreClient.AsyncMultiGet", std::move(*key_value_map)); + std::move(*shared_callback) + .Dispatch("RedisStoreClient.AsyncMultiGet", std::move(*key_value_map)); } }; SendRedisCmdArgsAsKeys(std::move(command), std::move(mget_callback)); } - return Status::OK(); } RedisStoreClient::RedisStoreClient(std::shared_ptr redis_client) @@ -126,9 +131,10 @@ Status RedisStoreClient::AsyncPut(const std::string &table_name, RedisKey{external_storage_namespace_, table_name}, /*args=*/{key, data}}; RedisCallback write_callback = - [callback = std::move(callback)](const std::shared_ptr &reply) { + [callback = + std::move(callback)](const std::shared_ptr &reply) mutable { auto added_num = reply->ReadAsInteger(); - callback.Dispatch("RedisStoreClient.AsyncPut", added_num != 0); + std::move(callback).Dispatch("RedisStoreClient.AsyncPut", added_num != 0); }; SendRedisCmdWithKeys({key}, std::move(command), std::move(write_callback)); return Status::OK(); @@ -139,14 +145,16 @@ Status RedisStoreClient::AsyncGet( const std::string &key, ToPostable> callback) { auto redis_callback = - [callback = std::move(callback)](const std::shared_ptr &reply) { + [callback = + std::move(callback)](const std::shared_ptr &reply) mutable { std::optional result; if (!reply->IsNil()) { result = reply->ReadAsString(); } RAY_CHECK(!reply->IsError()) << "Failed to get from Redis with status: " << reply->ReadAsStatus(); - callback.Dispatch("RedisStoreClient.AsyncGet", Status::OK(), std::move(result)); + std::move(callback).Dispatch( + "RedisStoreClient.AsyncGet", Status::OK(), std::move(result)); }; RedisCommand command{/*command=*/"HGET", @@ -179,7 +187,7 @@ Status RedisStoreClient::AsyncBatchDelete(const std::string &table_name, const std::vector &keys, Postable callback) { if (keys.empty()) { - callback.Dispatch("RedisStoreClient.AsyncBatchDelete", 0); + std::move(callback).Dispatch("RedisStoreClient.AsyncBatchDelete", 0); return Status::OK(); } return DeleteByKeys(table_name, keys, std::move(callback)); @@ -190,8 +198,8 @@ Status RedisStoreClient::AsyncMultiGet( const std::vector &keys, Postable)> callback) { if (keys.empty()) { - callback.Dispatch("RedisStoreClient.AsyncMultiGet", - absl::flat_hash_map{}); + std::move(callback).Dispatch("RedisStoreClient.AsyncMultiGet", + absl::flat_hash_map{}); return Status::OK(); } MGetValues(table_name, keys, std::move(callback)); @@ -321,14 +329,17 @@ Status RedisStoreClient::DeleteByKeys(const std::string &table, auto total_count = del_cmds.size(); auto finished_count = std::make_shared(0); auto num_deleted = std::make_shared(0); + auto shared_callback = std::make_shared>(std::move(callback)); + for (auto &command : del_cmds) { // `callback` is copied to each `delete_callback` lambda. Don't move. - auto delete_callback = [num_deleted, finished_count, total_count, callback]( + auto delete_callback = [num_deleted, finished_count, total_count, shared_callback]( const std::shared_ptr &reply) { (*num_deleted) += reply->ReadAsInteger(); ++(*finished_count); if (*finished_count == total_count) { - callback.Dispatch("RedisStoreClient.AsyncBatchDelete", *num_deleted); + std::move(*shared_callback) + .Dispatch("RedisStoreClient.AsyncBatchDelete", *num_deleted); } }; SendRedisCmdArgsAsKeys(std::move(command), std::move(delete_callback)); @@ -370,7 +381,8 @@ void RedisStoreClient::RedisScanner::Scan() { // we should consider using a reader-writer lock. absl::MutexLock lock(&mutex_); if (!cursor_.has_value()) { - callback_.Dispatch("RedisStoreClient.RedisScanner.Scan", std::move(results_)); + std::move(callback_).Dispatch("RedisStoreClient.RedisScanner.Scan", + std::move(results_)); self_ref_.reset(); return; } @@ -466,9 +478,10 @@ Status RedisStoreClient::AsyncExists(const std::string &table_name, "HEXISTS", RedisKey{external_storage_namespace_, table_name}, {key}}; SendRedisCmdArgsAsKeys( std::move(command), - [callback = std::move(callback)](const std::shared_ptr &reply) { + [callback = + std::move(callback)](const std::shared_ptr &reply) mutable { bool exists = reply->ReadAsInteger() > 0; - callback.Dispatch("RedisStoreClient.AsyncExists", exists); + std::move(callback).Dispatch("RedisStoreClient.AsyncExists", exists); }); return Status::OK(); } diff --git a/src/ray/gcs/store_client/test/in_memory_store_client_test.cc b/src/ray/gcs/store_client/test/in_memory_store_client_test.cc index 76feed6eb984..2a919ee18693 100644 --- a/src/ray/gcs/store_client/test/in_memory_store_client_test.cc +++ b/src/ray/gcs/store_client/test/in_memory_store_client_test.cc @@ -23,7 +23,7 @@ namespace gcs { class InMemoryStoreClientTest : public StoreClientTestBase { public: void InitStoreClient() override { - store_client_ = std::make_shared(*(io_service_pool_->Get())); + store_client_ = std::make_shared(); } void DisconnectStoreClient() override {} diff --git a/src/ray/gcs/store_client/test/observable_store_client_test.cc b/src/ray/gcs/store_client/test/observable_store_client_test.cc index d3d0f5c27368..b95d76b95595 100644 --- a/src/ray/gcs/store_client/test/observable_store_client_test.cc +++ b/src/ray/gcs/store_client/test/observable_store_client_test.cc @@ -24,8 +24,8 @@ namespace gcs { class ObservableStoreClientTest : public StoreClientTestBase { public: void InitStoreClient() override { - store_client_ = std::make_shared( - std::make_unique(*(io_service_pool_->Get()))); + store_client_ = + std::make_shared(std::make_unique()); } void DisconnectStoreClient() override {} diff --git a/src/ray/gcs/store_client/test/redis_store_client_test.cc b/src/ray/gcs/store_client/test/redis_store_client_test.cc index 1a881fbfc9c8..bfbba4b1809d 100644 --- a/src/ray/gcs/store_client/test/redis_store_client_test.cc +++ b/src/ray/gcs/store_client/test/redis_store_client_test.cc @@ -69,7 +69,7 @@ class RedisStoreClientTest : public StoreClientTestBase { void InitStoreClient() override { RedisClientOptions options("127.0.0.1", TEST_REDIS_SERVER_PORTS.front(), "", ""); redis_client_ = std::make_shared(options); - RAY_CHECK_OK(redis_client_->Connect(*io_service_pool_->Get())); + RAY_CHECK_OK(redis_client_->Connect(GetIoService())); store_client_ = std::make_shared(redis_client_); } @@ -99,10 +99,11 @@ TEST_F(RedisStoreClientTest, BasicSimple) { absl::StrCat("A", std::to_string(j)), std::to_string(i), true, - [i, cnt](auto r) { - --*cnt; - ASSERT_TRUE((i == 0 && r) || (i != 0 && !r)); - }) + {[i, cnt](auto r) { + --*cnt; + ASSERT_TRUE((i == 0 && r) || (i != 0 && !r)); + }, + GetIoService()}) .ok()); } } @@ -111,11 +112,12 @@ TEST_F(RedisStoreClientTest, BasicSimple) { ASSERT_TRUE(store_client_ ->AsyncGet("T", absl::StrCat("A", std::to_string(j)), - [cnt](auto s, auto r) { - --*cnt; - ASSERT_TRUE(r.has_value()); - ASSERT_EQ(*r, "99"); - }) + {[cnt](auto s, auto r) { + --*cnt; + ASSERT_TRUE(r.has_value()); + ASSERT_EQ(*r, "99"); + }, + GetIoService()}) .ok()); } ASSERT_TRUE(WaitForCondition([cnt]() { return *cnt == 0; }, 5000)); @@ -136,12 +138,13 @@ TEST_F(RedisStoreClientTest, Complicated) { "P_" + std::to_string(j), std::to_string(j), true, - [&finished, j](auto r) mutable { - RAY_LOG(INFO) - << "F AsyncPut: " << ("P_" + std::to_string(j)); - ++finished; - ASSERT_TRUE(r); - }) + {[&finished, j](auto r) mutable { + RAY_LOG(INFO) + << "F AsyncPut: " << ("P_" + std::to_string(j)); + ++finished; + ASSERT_TRUE(r); + }, + GetIoService()}) .ok()); keys.push_back(std::to_string(j)); } @@ -163,77 +166,83 @@ TEST_F(RedisStoreClientTest, Complicated) { ->AsyncMultiGet( "N", p_keys, - [&finished, i, keys, window, &sent, p_keys, n_keys, this]( - auto m) mutable { - RAY_LOG(INFO) << "F SendAsyncMultiGet: " << absl::StrJoin(p_keys, ","); - ++finished; - ASSERT_EQ(keys.size(), m.size()); - for (auto &key : keys) { - ASSERT_EQ(m["P_" + key], key); - } - - if ((i / window) % 2 == 0) { - // Delete non exist keys - for (size_t i = 0; i < keys.size(); ++i) { - ++sent; - RAY_LOG(INFO) << "S AsyncDelete: " << n_keys[i]; - ASSERT_TRUE( - store_client_ - ->AsyncDelete("N", - n_keys[i], - [&finished, n_keys, i](auto b) mutable { - RAY_LOG(INFO) - << "F AsyncDelete: " << n_keys[i]; - ++finished; - ASSERT_FALSE(b); - }) - .ok()); - - ++sent; - RAY_LOG(INFO) << "S AsyncExists: " << p_keys[i]; - ASSERT_TRUE( - store_client_ - ->AsyncExists("N", - p_keys[i], - [&finished, p_keys, i](auto b) mutable { - RAY_LOG(INFO) - << "F AsyncExists: " << p_keys[i]; - ++finished; - ASSERT_TRUE(b); - }) - .ok()); - } - } else { - ++sent; - RAY_LOG(INFO) << "S AsyncBatchDelete: " << absl::StrJoin(p_keys, ","); - ASSERT_TRUE(store_client_ - ->AsyncBatchDelete( - "N", - p_keys, - [&finished, p_keys, keys](auto n) mutable { - RAY_LOG(INFO) << "F AsyncBatchDelete: " - << absl::StrJoin(p_keys, ","); - ++finished; - ASSERT_EQ(n, keys.size()); - }) - .ok()); - - for (auto p_key : p_keys) { - ++sent; - RAY_LOG(INFO) << "S AsyncExists: " << p_key; - ASSERT_TRUE(store_client_ - ->AsyncExists("N", - p_key, - [&finished, p_key](auto b) mutable { - RAY_LOG(INFO) - << "F AsyncExists: " << p_key; - ++finished; - ASSERT_FALSE(false); - }) - .ok()); - } - } - }) + {[&finished, i, keys, window, &sent, p_keys, n_keys, this]( + auto m) mutable { + RAY_LOG(INFO) << "F SendAsyncMultiGet: " << absl::StrJoin(p_keys, ","); + ++finished; + ASSERT_EQ(keys.size(), m.size()); + for (auto &key : keys) { + ASSERT_EQ(m["P_" + key], key); + } + + if ((i / window) % 2 == 0) { + // Delete non exist keys + for (size_t i = 0; i < keys.size(); ++i) { + ++sent; + RAY_LOG(INFO) << "S AsyncDelete: " << n_keys[i]; + ASSERT_TRUE( + store_client_ + ->AsyncDelete("N", + n_keys[i], + {[&finished, n_keys, i](auto b) mutable { + RAY_LOG(INFO) + << "F AsyncDelete: " << n_keys[i]; + ++finished; + ASSERT_FALSE(b); + }, + GetIoService()}) + .ok()); + + ++sent; + RAY_LOG(INFO) << "S AsyncExists: " << p_keys[i]; + ASSERT_TRUE( + store_client_ + ->AsyncExists("N", + p_keys[i], + {[&finished, p_keys, i](auto b) mutable { + RAY_LOG(INFO) + << "F AsyncExists: " << p_keys[i]; + ++finished; + ASSERT_TRUE(b); + }, + GetIoService()}) + .ok()); + } + } else { + ++sent; + RAY_LOG(INFO) + << "S AsyncBatchDelete: " << absl::StrJoin(p_keys, ","); + ASSERT_TRUE(store_client_ + ->AsyncBatchDelete( + "N", + p_keys, + {[&finished, p_keys, keys](auto n) mutable { + RAY_LOG(INFO) << "F AsyncBatchDelete: " + << absl::StrJoin(p_keys, ","); + ++finished; + ASSERT_EQ(n, keys.size()); + }, + GetIoService()}) + .ok()); + + for (auto p_key : p_keys) { + ++sent; + RAY_LOG(INFO) << "S AsyncExists: " << p_key; + ASSERT_TRUE(store_client_ + ->AsyncExists("N", + p_key, + {[&finished, p_key](auto b) mutable { + RAY_LOG(INFO) + << "F AsyncExists: " << p_key; + ++finished; + ASSERT_FALSE(false); + }, + GetIoService()}) + .ok()); + } + } + }, + GetIoService()}) .ok()); } ASSERT_TRUE(WaitForCondition( @@ -268,12 +277,16 @@ TEST_F(RedisStoreClientTest, Random) { } RAY_LOG(INFO) << "m_multi_get Sending: " << idx; *counter += 1; - RAY_CHECK_OK( - store_client_->AsyncMultiGet("N", keys, [result, idx, counter](auto m) mutable { - RAY_LOG(INFO) << "m_multi_get Finished: " << idx << " " << m.size(); - *counter -= 1; - ASSERT_TRUE(m == result); - })); + RAY_CHECK_OK(store_client_->AsyncMultiGet("N", + keys, + {[result, idx, counter](auto m) mutable { + RAY_LOG(INFO) + << "m_multi_get Finished: " << idx + << " " << m.size(); + *counter -= 1; + ASSERT_TRUE(m == result); + }, + GetIoService()})); }; auto m_batch_delete = [&, counter, this](size_t idx) mutable { @@ -285,11 +298,14 @@ TEST_F(RedisStoreClientTest, Random) { RAY_LOG(INFO) << "m_batch_delete Sending: " << idx; *counter += 1; RAY_CHECK_OK(store_client_->AsyncBatchDelete( - "N", keys, [&counter, deleted_num, idx](auto v) mutable { - RAY_LOG(INFO) << "m_batch_delete Finished: " << idx << " " << v; - *counter -= 1; - ASSERT_EQ(v, deleted_num); - })); + "N", + keys, + {[counter, deleted_num, idx](auto v) mutable { + RAY_LOG(INFO) << "m_batch_delete Finished: " << idx << " " << v; + *counter -= 1; + ASSERT_EQ(v, deleted_num); + }, + GetIoService()})); }; auto m_delete = [&, this](size_t idx) mutable { @@ -297,11 +313,16 @@ TEST_F(RedisStoreClientTest, Random) { bool deleted = dict.erase(k) > 0; RAY_LOG(INFO) << "m_delete Sending: " << idx << " " << k; *counter += 1; - RAY_CHECK_OK(store_client_->AsyncDelete("N", k, [counter, k, idx, deleted](auto r) { - RAY_LOG(INFO) << "m_delete Finished: " << idx << " " << k << " " << deleted; - *counter -= 1; - ASSERT_EQ(deleted, r); - })); + RAY_CHECK_OK(store_client_->AsyncDelete("N", + k, + {[counter, k, idx, deleted](auto r) { + RAY_LOG(INFO) + << "m_delete Finished: " << idx << " " + << k << " " << deleted; + *counter -= 1; + ASSERT_EQ(deleted, r); + }, + GetIoService()})); }; auto m_get = [&, counter, this](size_t idx) { @@ -312,11 +333,16 @@ TEST_F(RedisStoreClientTest, Random) { } RAY_LOG(INFO) << "m_get Sending: " << idx; *counter += 1; - RAY_CHECK_OK(store_client_->AsyncGet("N", k, [counter, idx, v](auto, auto r) { - RAY_LOG(INFO) << "m_get Finished: " << idx << " " << (r ? *r : std::string("-")); - *counter -= 1; - ASSERT_EQ(v, r); - })); + RAY_CHECK_OK(store_client_->AsyncGet("N", + k, + {[counter, idx, v](auto, auto r) { + RAY_LOG(INFO) + << "m_get Finished: " << idx << " " + << (r ? *r : std::string("-")); + *counter -= 1; + ASSERT_EQ(v, r); + }, + GetIoService()})); }; auto m_exists = [&, counter, this](size_t idx) { @@ -324,12 +350,15 @@ TEST_F(RedisStoreClientTest, Random) { bool existed = dict.count(k); RAY_LOG(INFO) << "m_exists Sending: " << idx; *counter += 1; - RAY_CHECK_OK( - store_client_->AsyncExists("N", k, [k, existed, counter, idx](auto r) mutable { - RAY_LOG(INFO) << "m_exists Finished: " << idx << " " << k << " " << r; - *counter -= 1; - ASSERT_EQ(existed, r) << " exists check " << k; - })); + RAY_CHECK_OK(store_client_->AsyncExists( + "N", + k, + {[k, existed, counter, idx](auto r) mutable { + RAY_LOG(INFO) << "m_exists Finished: " << idx << " " << k << " " << r; + *counter -= 1; + ASSERT_EQ(existed, r) << " exists check " << k; + }, + GetIoService()})); }; auto m_puts = [&, counter, this](size_t idx) mutable { @@ -342,13 +371,17 @@ TEST_F(RedisStoreClientTest, Random) { dict[k] = v; RAY_LOG(INFO) << "m_put Sending: " << idx << " " << k << " " << v; *counter += 1; - RAY_CHECK_OK(store_client_->AsyncPut( - "N", k, v, true, [idx, added, k, counter](bool r) mutable { - RAY_LOG(INFO) << "m_put Finished: " - << " " << idx << " " << k << " " << r; - *counter -= 1; - ASSERT_EQ(r, added); - })); + RAY_CHECK_OK(store_client_->AsyncPut("N", + k, + v, + true, + {[idx, added, k, counter](bool r) mutable { + RAY_LOG(INFO) << "m_put Finished: " << " " + << idx << " " << k << " " << r; + *counter -= 1; + ASSERT_EQ(r, added); + }, + GetIoService()})); }; std::vector> ops{ diff --git a/src/ray/gcs/store_client/test/store_client_test_base.h b/src/ray/gcs/store_client/test/store_client_test_base.h index ca373604e30e..ebc2fd184edc 100644 --- a/src/ray/gcs/store_client/test/store_client_test_base.h +++ b/src/ray/gcs/store_client/test/store_client_test_base.h @@ -59,27 +59,37 @@ class StoreClientTestBase : public ::testing::Test { virtual void DisconnectStoreClient() = 0; + auto &GetIoService() { return *io_service_pool_->Get(); } + protected: void Put() { - auto put_calllback = [this](auto) { --pending_count_; }; + auto put_callback = [this](auto) { --pending_count_; }; for (const auto &[key, value] : key_to_value_) { ++pending_count_; - RAY_CHECK_OK(store_client_->AsyncPut( - table_name_, key.Hex(), value.SerializeAsString(), true, put_calllback)); + RAY_CHECK_OK(store_client_->AsyncPut(table_name_, + key.Hex(), + value.SerializeAsString(), + true, + {put_callback, GetIoService()})); // Make sure no-op callback is handled well - RAY_CHECK_OK(store_client_->AsyncPut( - table_name_, key.Hex(), value.SerializeAsString(), true, nullptr)); + RAY_CHECK_OK(store_client_->AsyncPut(table_name_, + key.Hex(), + value.SerializeAsString(), + true, + {[](auto) {}, GetIoService()})); } WaitPendingDone(); } void Delete() { - auto delete_calllback = [this](auto) { --pending_count_; }; + auto delete_callback = [this](auto) { --pending_count_; }; for (const auto &[key, _] : key_to_value_) { ++pending_count_; - RAY_CHECK_OK(store_client_->AsyncDelete(table_name_, key.Hex(), delete_calllback)); + RAY_CHECK_OK(store_client_->AsyncDelete( + table_name_, key.Hex(), {delete_callback, GetIoService()})); // Make sure no-op callback is handled well - RAY_CHECK_OK(store_client_->AsyncDelete(table_name_, key.Hex(), nullptr)); + RAY_CHECK_OK(store_client_->AsyncDelete( + table_name_, key.Hex(), {[](auto) {}, GetIoService()})); } WaitPendingDone(); } @@ -98,7 +108,8 @@ class StoreClientTestBase : public ::testing::Test { }; for (const auto &[key, _] : key_to_value_) { ++pending_count_; - RAY_CHECK_OK(store_client_->AsyncGet(table_name_, key.Hex(), get_callback)); + RAY_CHECK_OK(store_client_->AsyncGet( + table_name_, key.Hex(), {get_callback, GetIoService()})); } WaitPendingDone(); } @@ -114,7 +125,8 @@ class StoreClientTestBase : public ::testing::Test { }; ++pending_count_; - RAY_CHECK_OK(store_client_->AsyncGet(table_name_, key, get_callback)); + RAY_CHECK_OK( + store_client_->AsyncGet(table_name_, key, {get_callback, GetIoService()})); } WaitPendingDone(); } @@ -138,7 +150,8 @@ class StoreClientTestBase : public ::testing::Test { }; pending_count_ += key_to_value_.size(); - RAY_CHECK_OK(store_client_->AsyncGetAll(table_name_, get_all_callback)); + RAY_CHECK_OK( + store_client_->AsyncGetAll(table_name_, {get_all_callback, GetIoService()})); WaitPendingDone(); } @@ -164,7 +177,8 @@ class StoreClientTestBase : public ::testing::Test { pending_count_ += result_set.size(); - RAY_CHECK_OK(store_client_->AsyncGetKeys(table_name_, prefix, get_keys_callback)); + RAY_CHECK_OK(store_client_->AsyncGetKeys( + table_name_, prefix, {get_keys_callback, GetIoService()})); WaitPendingDone(); } } @@ -177,22 +191,24 @@ class StoreClientTestBase : public ::testing::Test { pending_count_ += key_to_value_.size(); for (const auto &item : key_to_value_) { - RAY_CHECK_OK( - store_client_->AsyncExists(table_name_, item.first.Hex(), exists_callback)); + RAY_CHECK_OK(store_client_->AsyncExists( + table_name_, item.first.Hex(), {exists_callback, GetIoService()})); } WaitPendingDone(); } void BatchDelete() { - auto delete_calllback = [this](auto) { --pending_count_; }; + auto delete_callback = [this](auto) { --pending_count_; }; ++pending_count_; std::vector keys; for (auto &[key, _] : key_to_value_) { keys.push_back(key.Hex()); } - RAY_CHECK_OK(store_client_->AsyncBatchDelete(table_name_, keys, delete_calllback)); + RAY_CHECK_OK(store_client_->AsyncBatchDelete( + table_name_, keys, {delete_callback, GetIoService()})); // Make sure no-op callback is handled well - RAY_CHECK_OK(store_client_->AsyncBatchDelete(table_name_, keys, nullptr)); + RAY_CHECK_OK(store_client_->AsyncBatchDelete( + table_name_, keys, {[](auto) {}, GetIoService()})); WaitPendingDone(); } From 418d9d3abd2253f95088b8aaec419a212e6fe3fb Mon Sep 17 00:00:00 2001 From: Ruiyang Wang Date: Mon, 2 Dec 2024 23:23:09 -0800 Subject: [PATCH 26/26] lint Signed-off-by: Ruiyang Wang --- .../gcs_server/test/gcs_server_test_util.h | 4 ++-- .../gcs/store_client/redis_store_client.cc | 23 +++++++++---------- .../test/redis_store_client_test.cc | 5 ++-- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h index 2d8294bd9f06..7245fed4265a 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h +++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h @@ -309,7 +309,7 @@ struct GcsServerMocker { void ShutdownRaylet( const NodeID &node_id, bool graceful, - const rpc::ClientCallback &callback) override {}; + const rpc::ClientCallback &callback) override{}; void DrainRaylet( const rpc::autoscaler::DrainNodeReason &reason, @@ -322,7 +322,7 @@ struct GcsServerMocker { }; void NotifyGCSRestart( - const rpc::ClientCallback &callback) override {}; + const rpc::ClientCallback &callback) override{}; ~MockRayletClient() {} diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index 503f39d936ab..28ca2a4ae0d9 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -144,18 +144,17 @@ Status RedisStoreClient::AsyncGet( const std::string &table_name, const std::string &key, ToPostable> callback) { - auto redis_callback = - [callback = - std::move(callback)](const std::shared_ptr &reply) mutable { - std::optional result; - if (!reply->IsNil()) { - result = reply->ReadAsString(); - } - RAY_CHECK(!reply->IsError()) - << "Failed to get from Redis with status: " << reply->ReadAsStatus(); - std::move(callback).Dispatch( - "RedisStoreClient.AsyncGet", Status::OK(), std::move(result)); - }; + auto redis_callback = [callback = std::move(callback)]( + const std::shared_ptr &reply) mutable { + std::optional result; + if (!reply->IsNil()) { + result = reply->ReadAsString(); + } + RAY_CHECK(!reply->IsError()) + << "Failed to get from Redis with status: " << reply->ReadAsStatus(); + std::move(callback).Dispatch( + "RedisStoreClient.AsyncGet", Status::OK(), std::move(result)); + }; RedisCommand command{/*command=*/"HGET", RedisKey{external_storage_namespace_, table_name}, diff --git a/src/ray/gcs/store_client/test/redis_store_client_test.cc b/src/ray/gcs/store_client/test/redis_store_client_test.cc index bfbba4b1809d..4e82dd941065 100644 --- a/src/ray/gcs/store_client/test/redis_store_client_test.cc +++ b/src/ray/gcs/store_client/test/redis_store_client_test.cc @@ -376,8 +376,9 @@ TEST_F(RedisStoreClientTest, Random) { v, true, {[idx, added, k, counter](bool r) mutable { - RAY_LOG(INFO) << "m_put Finished: " << " " - << idx << " " << k << " " << r; + RAY_LOG(INFO) + << "m_put Finished: " + << " " << idx << " " << k << " " << r; *counter -= 1; ASSERT_EQ(r, added); },