[Core] Retryable grpc client (ray-project#47981)
Signed-off-by: Jiajun Yao <[email protected]>
Signed-off-by: Connor Sanders <[email protected]>
jjyao authored and jecsand838 committed Dec 4, 2024
1 parent 35d6a43 commit 11152e6
Showing 26 changed files with 772 additions and 311 deletions.
2 changes: 2 additions & 0 deletions BUILD.bazel
@@ -116,12 +116,14 @@ ray_cc_library(
"src/ray/rpc/grpc_server.cc",
"src/ray/rpc/server_call.cc",
"src/ray/rpc/rpc_chaos.cc",
"src/ray/rpc/retryable_grpc_client.cc",
],
hdrs = glob([
"src/ray/rpc/rpc_chaos.h",
"src/ray/rpc/client_call.h",
"src/ray/rpc/common.h",
"src/ray/rpc/grpc_client.h",
"src/ray/rpc/retryable_grpc_client.h",
"src/ray/rpc/grpc_server.h",
"src/ray/rpc/metrics_agent_client.h",
"src/ray/rpc/server_call.h",
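
The new src/ray/rpc/retryable_grpc_client.{h,cc} sources registered above are in a part of the diff that is not loaded on this page. As a rough idea of what a retryable client does, the self-contained sketch below queues calls that fail while the peer is unreachable and replays them once the channel recovers; every name in it is illustrative, not Ray's actual API.

// retryable_client_sketch.cc -- illustrative only; not Ray's retryable_grpc_client.
#include <cstddef>
#include <deque>
#include <functional>
#include <iostream>
#include <utility>

// Wraps a "send" primitive that can fail transiently. Failed calls are queued
// and replayed once the peer is reachable again, which is the core idea
// behind a retryable RPC client.
class RetryableClientSketch {
 public:
  explicit RetryableClientSketch(std::function<bool(int)> send)
      : send_(std::move(send)) {}

  // Try to send now; on failure keep the request for a later retry.
  void Call(int request_id) {
    if (!send_(request_id)) {
      pending_.push_back(request_id);
    }
  }

  // Called when the channel is healthy again (e.g. after a reconnect):
  // replay everything that failed while the peer was unreachable.
  void RetryPending() {
    std::deque<int> to_retry;
    to_retry.swap(pending_);
    for (int request_id : to_retry) {
      Call(request_id);
    }
  }

  std::size_t num_pending() const { return pending_.size(); }

 private:
  std::function<bool(int)> send_;
  std::deque<int> pending_;
};

int main() {
  bool server_up = false;
  RetryableClientSketch client([&](int id) {
    if (!server_up) return false;
    std::cout << "sent request " << id << "\n";
    return true;
  });

  client.Call(1);
  client.Call(2);
  std::cout << "pending while server is down: " << client.num_pending() << "\n";  // 2

  server_up = true;
  client.RetryPending();
  std::cout << "pending after retry: " << client.num_pending() << "\n";  // 0
  return 0;
}

The real client additionally bounds how long it keeps waiting before declaring the peer unavailable; see the unavailable-timeout callback wired up in core_worker.cc further down in this diff.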
42 changes: 40 additions & 2 deletions python/ray/tests/test_streaming_generator_4.py
@@ -3,6 +3,8 @@
import sys
import time
import gc
import os
import signal
import random
import asyncio
from typing import Optional
@@ -34,6 +36,40 @@ def assert_no_leak():
assert core_worker.get_memory_store_size() == 0


@pytest.mark.skipif(
sys.platform == "win32", reason="SIGKILL is not available on Windows"
)
def test_caller_death(monkeypatch, shutdown_only):
"""
Test the case where the caller of a streaming generator actor task dies
while the streaming generator task is executing. The streaming generator
task should still finish and should not block other actor tasks. This means
the `ReportGeneratorItemReturns` RPC should fail and should not be retried
indefinitely.
"""
monkeypatch.setenv("RAY_core_worker_rpc_server_reconnect_timeout_s", "1")
ray.init()

@ray.remote
class Callee:
def gen(self, caller_pid):
os.kill(caller_pid, signal.SIGKILL)
yield [1] * 1024 * 1024

def ping(self):
pass

@ray.remote
def caller(callee):
ray.get(callee.gen.remote(os.getpid()))

callee = Callee.remote()
o = caller.remote(callee)
ray.wait([o])
# Make sure gen will finish and ping can run.
ray.get(callee.ping.remote())


@pytest.mark.parametrize("backpressure", [False, True])
@pytest.mark.parametrize("delay_latency", [0.1, 1])
@pytest.mark.parametrize("threshold", [1, 3])
@@ -54,6 +90,10 @@ def test_ray_datasetlike_mini_stress_test(
"RAY_testing_asio_delay_us",
"CoreWorkerService.grpc_server.ReportGeneratorItemReturns=10000:1000000",
)
m.setenv(
"RAY_testing_rpc_failure",
"CoreWorkerService.grpc_client.ReportGeneratorItemReturns=5",
)
cluster = ray_start_cluster
cluster.add_node(
num_cpus=1,
@@ -261,8 +301,6 @@ async def async_stream(self, signal):


if __name__ == "__main__":
import os

if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
else:
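
The RAY_testing_rpc_failure setting used in the stress test above is handled by the rpc_chaos sources added to the build; those are not shown in the loaded diff. Judging from the value `CoreWorkerService.grpc_client.ReportGeneratorItemReturns=5`, the spec names a client method and how many failures to inject for it. A minimal sketch of that kind of bookkeeping follows; the parsing and class name are assumptions for illustration only.

// rpc_failure_injection_sketch.cc -- illustrative only.
#include <iostream>
#include <map>
#include <string>

// Parses a spec like "Service.grpc_client.Method=5" and answers whether the
// next call to that method should be turned into an artificial failure.
class FailureInjector {
 public:
  explicit FailureInjector(const std::string &spec) {
    const auto eq = spec.find('=');
    if (eq != std::string::npos) {
      remaining_[spec.substr(0, eq)] = std::stoi(spec.substr(eq + 1));
    }
  }

  // Returns true if this call should fail artificially, and uses up one of
  // the remaining injected failures for the method.
  bool ShouldFail(const std::string &method) {
    auto it = remaining_.find(method);
    if (it == remaining_.end() || it->second == 0) return false;
    --it->second;
    return true;
  }

 private:
  std::map<std::string, int> remaining_;  // method -> failures left to inject
};

int main() {
  FailureInjector injector(
      "CoreWorkerService.grpc_client.ReportGeneratorItemReturns=2");
  for (int i = 0; i < 4; ++i) {
    std::cout << injector.ShouldFail(
                     "CoreWorkerService.grpc_client.ReportGeneratorItemReturns")
              << "\n";  // prints 1, 1, 0, 0
  }
  return 0;
}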
5 changes: 5 additions & 0 deletions src/mock/ray/raylet_client/raylet_client.h
@@ -133,6 +133,11 @@ class MockRayletClientInterface : public RayletClientInterface {
int64_t draining_deadline_timestamp_ms,
const rpc::ClientCallback<rpc::DrainRayletReply> &callback),
(override));
MOCK_METHOD(void,
IsLocalWorkerDead,
(const WorkerID &worker_id,
const rpc::ClientCallback<rpc::IsLocalWorkerDeadReply> &callback),
(override));
};

} // namespace ray
5 changes: 4 additions & 1 deletion src/ray/common/ray_config_def.h
@@ -447,7 +447,7 @@ RAY_CONFIG(int32_t, gcs_grpc_initial_reconnect_backoff_ms, 100)
RAY_CONFIG(uint64_t, gcs_grpc_max_request_queued_max_bytes, 1024UL * 1024 * 1024 * 5)

/// The duration between two checks for grpc status.
RAY_CONFIG(int32_t, gcs_client_check_connection_status_interval_milliseconds, 1000)
RAY_CONFIG(int32_t, grpc_client_check_connection_status_interval_milliseconds, 1000)

/// Due to the protocol drawback, raylet needs to refresh the message if
/// no message is received for a while.
@@ -693,6 +693,9 @@ RAY_CONFIG(int64_t, timeout_ms_task_wait_for_death_info, 1000)
/// report the loads to raylet.
RAY_CONFIG(int64_t, core_worker_internal_heartbeat_ms, 1000)

/// Timeout for core worker grpc server reconnection in seconds.
RAY_CONFIG(int32_t, core_worker_rpc_server_reconnect_timeout_s, 60)

/// Maximum amount of memory that will be used by running tasks' args.
RAY_CONFIG(float, max_task_args_memory_fraction, 0.7)

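
Values declared with RAY_CONFIG, such as the new core_worker_rpc_server_reconnect_timeout_s above, are normally read back through a generated accessor on a global config singleton. The sketch below shows how such a macro can expand into a typed getter; it illustrates the pattern only and is not the actual ray_config implementation.

// ray_config_sketch.cc -- illustrative only.
#include <cstdint>
#include <iostream>

// A RAY_CONFIG-style macro: declares a private field with a default value and
// a public accessor named after the config entry.
#define MY_CONFIG(type, name, default_value) \
 public:                                     \
  type name() const { return name##_; }      \
 private:                                    \
  type name##_ = default_value;

class DemoConfig {
  MY_CONFIG(int32_t, core_worker_rpc_server_reconnect_timeout_s, 60)
  MY_CONFIG(int32_t, grpc_client_check_connection_status_interval_milliseconds, 1000)

 public:
  static DemoConfig &instance() {
    static DemoConfig config;
    return config;
  }
};

int main() {
  std::cout << DemoConfig::instance().core_worker_rpc_server_reconnect_timeout_s()
            << "\n";  // prints the default, 60
  return 0;
}

The test above sets RAY_core_worker_rpc_server_reconnect_timeout_s to 1 via the environment so the failing ReportGeneratorItemReturns call gives up quickly instead of waiting out the 60-second default.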
53 changes: 41 additions & 12 deletions src/ray/core_worker/core_worker.cc
@@ -417,7 +417,38 @@ CoreWorker::CoreWorker(CoreWorkerOptions options, const WorkerID &worker_id)
}

core_worker_client_pool_ =
std::make_shared<rpc::CoreWorkerClientPool>(*client_call_manager_);
std::make_shared<rpc::CoreWorkerClientPool>([&](const rpc::Address &addr) {
return std::make_shared<rpc::CoreWorkerClient>(
addr,
*client_call_manager_,
/*core_worker_unavailable_timeout_callback=*/[this, addr]() {
const NodeID node_id = NodeID::FromBinary(addr.raylet_id());
const WorkerID worker_id = WorkerID::FromBinary(addr.worker_id());
const rpc::GcsNodeInfo *node_info =
gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/false);
if (node_info != nullptr && node_info->state() == rpc::GcsNodeInfo::DEAD) {
RAY_LOG(INFO).WithField(worker_id).WithField(node_id)
<< "Disconnect core worker client since its node is dead";
core_worker_client_pool_->Disconnect(worker_id);
return;
}

raylet::RayletClient raylet_client(
rpc::NodeManagerWorkerClient::make(node_info->node_manager_address(),
node_info->node_manager_port(),
*client_call_manager_));
raylet_client.IsLocalWorkerDead(
worker_id,
[this, worker_id](const Status &status,
rpc::IsLocalWorkerDeadReply &&reply) {
if (status.ok() && reply.is_dead()) {
RAY_LOG(INFO).WithField(worker_id)
<< "Disconnect core worker client since it is dead";
core_worker_client_pool_->Disconnect(worker_id);
}
});
});
});

object_info_publisher_ = std::make_unique<pubsub::Publisher>(
/*channels=*/std::vector<
@@ -866,14 +897,6 @@ void CoreWorker::Shutdown() {

task_event_buffer_->Stop();

if (gcs_client_) {
// We should disconnect gcs client first otherwise because it contains
// a blocking logic that can block the io service upon
// gcs shutdown.
// TODO(sang): Refactor GCS client to be more robust.
RAY_LOG(INFO) << "Disconnecting a GCS client.";
gcs_client_->Disconnect();
}
io_service_.stop();
RAY_LOG(INFO) << "Waiting for joining a core worker io thread. If it hangs here, there "
"might be deadlock or a high load in the core worker io service.";
@@ -886,7 +909,13 @@ void CoreWorker::Shutdown() {

// Now that gcs_client is not used within io service, we can reset the pointer and clean
// it up.
gcs_client_.reset();
if (gcs_client_) {
RAY_LOG(INFO) << "Disconnecting a GCS client.";
// TODO(hjiang): Move the Disconnect() logic
// to GcsClient destructor.
gcs_client_->Disconnect();
gcs_client_.reset();
}

RAY_LOG(INFO) << "Core worker ready to be deallocated.";
}
@@ -3454,13 +3483,13 @@ Status CoreWorker::ReportGeneratorItemReturns(
if (status.ok()) {
num_objects_consumed = reply.total_num_object_consumed();
} else {
// TODO(sang): Handle network error more gracefully.
// If the request fails, we should just resume until task finishes without
// backpressure.
num_objects_consumed = waiter->TotalObjectGenerated();
RAY_LOG(WARNING).WithField(return_id)
<< "Failed to report streaming generator return "
"to the caller. The yield'ed ObjectRef may not be usable.";
"to the caller. The yield'ed ObjectRef may not be usable. "
<< status;
}
waiter->HandleObjectReported(num_objects_consumed);
});
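
The core_worker_unavailable_timeout_callback wired up above is what ultimately disconnects a stale CoreWorkerClient: it consults the GCS for the node's state and, if the node is still alive, asks the raylet whether the worker is dead via the new IsLocalWorkerDead RPC. The retryable client that invokes this callback is not shown in the loaded diff; presumably it polls connection status at the grpc_client_check_connection_status_interval_milliseconds cadence and fires the callback once the peer has been unreachable for roughly core_worker_rpc_server_reconnect_timeout_s. A rough, self-contained sketch of that kind of watchdog, with illustrative names only:

// unavailable_watchdog_sketch.cc -- illustrative only.
#include <chrono>
#include <functional>
#include <iostream>
#include <thread>

// Polls `is_reachable` every `check_interval`; if the peer stays unreachable
// for `unavailable_timeout`, fires `on_unavailable` once and stops.
void WatchPeer(const std::function<bool()> &is_reachable,
               std::chrono::milliseconds check_interval,
               std::chrono::milliseconds unavailable_timeout,
               const std::function<void()> &on_unavailable) {
  auto unreachable_since = std::chrono::steady_clock::time_point::max();
  while (true) {
    const auto now = std::chrono::steady_clock::now();
    if (is_reachable()) {
      unreachable_since = std::chrono::steady_clock::time_point::max();
    } else if (unreachable_since == std::chrono::steady_clock::time_point::max()) {
      unreachable_since = now;  // First failed check; start the clock.
    } else if (now - unreachable_since >= unavailable_timeout) {
      on_unavailable();         // Peer has been down long enough; give up.
      return;
    }
    std::this_thread::sleep_for(check_interval);
  }
}

int main() {
  int checks = 0;
  WatchPeer([&] { return ++checks < 3; },  // Reachable twice, then down.
            std::chrono::milliseconds(5),
            std::chrono::milliseconds(20),
            [] { std::cout << "peer unavailable; run disconnect callback\n"; });
  return 0;
}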
3 changes: 1 addition & 2 deletions src/ray/gcs/gcs_client/gcs_client.cc
@@ -176,7 +176,6 @@ Status GcsClient::FetchClusterId(int64_t timeout_ms) {
Status s = gcs_rpc_client_->SyncGetClusterId(request, &reply, timeout_ms);
if (!s.ok()) {
RAY_LOG(WARNING) << "Failed to get cluster ID from GCS server: " << s;
gcs_rpc_client_->Shutdown();
gcs_rpc_client_.reset();
client_call_manager_.reset();
return s;
@@ -189,7 +188,7 @@

void GcsClient::Disconnect() {
if (gcs_rpc_client_) {
gcs_rpc_client_->Shutdown();
gcs_rpc_client_.reset();
}
}

72 changes: 43 additions & 29 deletions src/ray/gcs/gcs_server/gcs_server.cc
@@ -412,7 +412,9 @@ void GcsServer::InitClusterTaskManager() {

void GcsServer::InitGcsJobManager(const GcsInitData &gcs_init_data) {
auto client_factory = [this](const rpc::Address &address) {
return std::make_shared<rpc::CoreWorkerClient>(address, client_call_manager_);
return std::make_shared<rpc::CoreWorkerClient>(address, client_call_manager_, []() {
RAY_LOG(FATAL) << "GCS doesn't call any retryable core worker grpc methods.";
});
};
RAY_CHECK(gcs_table_storage_ && gcs_publisher_);
gcs_job_manager_ = std::make_unique<GcsJobManager>(gcs_table_storage_,
@@ -449,34 +451,46 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) {
};

RAY_CHECK(gcs_resource_manager_ && cluster_task_manager_);
scheduler = std::make_unique<GcsActorScheduler>(
io_context_provider_.GetDefaultIOContext(),
gcs_table_storage_->ActorTable(),
*gcs_node_manager_,
*cluster_task_manager_,
schedule_failure_handler,
schedule_success_handler,
*raylet_client_pool_,
/*factory=*/
[this](const rpc::Address &address) {
return std::make_shared<rpc::CoreWorkerClient>(address, client_call_manager_);
},
/*normal_task_resources_changed_callback=*/
[this](const NodeID &node_id, const rpc::ResourcesData &resources) {
gcs_resource_manager_->UpdateNodeNormalTaskResources(node_id, resources);
});
gcs_actor_manager_ = std::make_unique<GcsActorManager>(
std::move(scheduler),
gcs_table_storage_,
gcs_publisher_,
*runtime_env_manager_,
*function_manager_,
[this](const ActorID &actor_id) {
gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(actor_id);
},
[this](const rpc::Address &address) {
return std::make_shared<rpc::CoreWorkerClient>(address, client_call_manager_);
});
scheduler =
std::make_unique<GcsActorScheduler>(
io_context_provider_.GetDefaultIOContext(),
gcs_table_storage_->ActorTable(),
*gcs_node_manager_,
*cluster_task_manager_,
schedule_failure_handler,
schedule_success_handler,
*raylet_client_pool_,
/*factory=*/
[this](const rpc::Address &address) {
return std::make_shared<rpc::CoreWorkerClient>(
address, client_call_manager_, []() {
RAY_LOG(FATAL)
<< "GCS doesn't call any retryable core worker grpc methods.";
});
},
/*normal_task_resources_changed_callback=*/
[this](const NodeID &node_id, const rpc::ResourcesData &resources) {
gcs_resource_manager_->UpdateNodeNormalTaskResources(node_id, resources);
});

gcs_actor_manager_ =
std::make_unique<GcsActorManager>(
std::move(scheduler),
gcs_table_storage_,
gcs_publisher_,
*runtime_env_manager_,
*function_manager_,
[this](const ActorID &actor_id) {
gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(
actor_id);
},
[this](const rpc::Address &address) {
return std::make_shared<rpc::CoreWorkerClient>(
address, client_call_manager_, []() {
RAY_LOG(FATAL)
<< "GCS doesn't call any retryable core worker grpc methods.";
});
});

// Initialize by gcs tables data.
gcs_actor_manager_->Initialize(gcs_init_data);
4 changes: 4 additions & 0 deletions src/ray/gcs/gcs_server/test/gcs_server_test_util.h
@@ -327,6 +327,10 @@ struct GcsServerMocker {
drain_raylet_callbacks.push_back(callback);
};

void IsLocalWorkerDead(
const WorkerID &worker_id,
const rpc::ClientCallback<rpc::IsLocalWorkerDeadReply> &callback) override{};

void NotifyGCSRestart(
const rpc::ClientCallback<rpc::NotifyGCSRestartReply> &callback) override{};

11 changes: 11 additions & 0 deletions src/ray/protobuf/node_manager.proto
@@ -392,6 +392,16 @@ message PushMutableObjectReply {
bool done = 1;
}

message IsLocalWorkerDeadRequest {
// Binary worker id of the target worker.
bytes worker_id = 1;
}

message IsLocalWorkerDeadReply {
// Whether the target worker is dead or not.
bool is_dead = 1;
}

// Service for inter-node-manager communication.
service NodeManagerService {
// Handle the case when GCS restarted.
@@ -457,4 +467,5 @@
rpc RegisterMutableObject(RegisterMutableObjectRequest)
returns (RegisterMutableObjectReply);
rpc PushMutableObject(PushMutableObjectRequest) returns (PushMutableObjectReply);
rpc IsLocalWorkerDead(IsLocalWorkerDeadRequest) returns (IsLocalWorkerDeadReply);
}
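
Ray reaches this RPC through its RayletClient wrapper (see the caller in core_worker.cc and the handler in node_manager.cc elsewhere in this diff). For reference, a unary RPC declared like this is also callable through the stock protoc-generated synchronous stub; the snippet below is a generic gRPC usage sketch that assumes the generated NodeManagerService stub, a ray.rpc proto package, and a placeholder raylet address.

// is_local_worker_dead_call_sketch.cc -- generic gRPC usage, not Ray's RayletClient.
#include <grpcpp/grpcpp.h>

#include <iostream>
#include <memory>
#include <string>

#include "src/ray/protobuf/node_manager.grpc.pb.h"  // assumed path of the generated header

int main() {
  // Connect to a raylet's NodeManagerService (address/port are placeholders).
  auto channel = grpc::CreateChannel("127.0.0.1:44217",
                                     grpc::InsecureChannelCredentials());
  auto stub = ray::rpc::NodeManagerService::NewStub(channel);

  ray::rpc::IsLocalWorkerDeadRequest request;
  request.set_worker_id(/*binary worker id=*/std::string("\x01\x02", 2));

  ray::rpc::IsLocalWorkerDeadReply reply;
  grpc::ClientContext context;
  grpc::Status status = stub->IsLocalWorkerDead(&context, request, &reply);

  if (status.ok()) {
    std::cout << "is_dead=" << reply.is_dead() << "\n";
  } else {
    std::cout << "rpc failed: " << status.error_message() << "\n";
  }
  return 0;
}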
14 changes: 13 additions & 1 deletion src/ray/raylet/node_manager.cc
@@ -149,7 +149,11 @@ NodeManager::NodeManager(
config.ray_debugger_external,
/*get_time=*/[]() { return absl::Now(); }),
client_call_manager_(io_service),
worker_rpc_pool_(client_call_manager_),
worker_rpc_pool_([this](const rpc::Address &addr) {
return std::make_shared<rpc::CoreWorkerClient>(addr, client_call_manager_, []() {
RAY_LOG(FATAL) << "Raylet doesn't call any retryable core worker grpc methods.";
});
}),
core_worker_subscriber_(std::make_unique<pubsub::Subscriber>(
self_node_id_,
/*channels=*/
@@ -2027,6 +2031,14 @@ void NodeManager::HandleReturnWorker(rpc::ReturnWorkerRequest request,
send_reply_callback(status, nullptr, nullptr);
}

void NodeManager::HandleIsLocalWorkerDead(rpc::IsLocalWorkerDeadRequest request,
rpc::IsLocalWorkerDeadReply *reply,
rpc::SendReplyCallback send_reply_callback) {
reply->set_is_dead(worker_pool_.GetRegisteredWorker(
WorkerID::FromBinary(request.worker_id())) == nullptr);
send_reply_callback(Status::OK(), /*success=*/nullptr, /*failure=*/nullptr);
}

void NodeManager::HandleDrainRaylet(rpc::DrainRayletRequest request,
rpc::DrainRayletReply *reply,
rpc::SendReplyCallback send_reply_callback) {
4 changes: 4 additions & 0 deletions src/ray/raylet/node_manager.h
@@ -556,6 +556,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
rpc::DrainRayletReply *reply,
rpc::SendReplyCallback send_reply_callback) override;

void HandleIsLocalWorkerDead(rpc::IsLocalWorkerDeadRequest request,
rpc::IsLocalWorkerDeadReply *reply,
rpc::SendReplyCallback send_reply_callback) override;

/// Handle a `CancelWorkerLease` request.
void HandleCancelWorkerLease(rpc::CancelWorkerLeaseRequest request,
rpc::CancelWorkerLeaseReply *reply,
4 changes: 3 additions & 1 deletion src/ray/raylet/worker.cc
@@ -117,7 +117,9 @@ void Worker::Connect(int port) {
rpc::Address addr;
addr.set_ip_address(ip_address_);
addr.set_port(port_);
rpc_client_ = std::make_unique<rpc::CoreWorkerClient>(addr, client_call_manager_);
rpc_client_ = std::make_unique<rpc::CoreWorkerClient>(addr, client_call_manager_, []() {
RAY_LOG(FATAL) << "Raylet doesn't call any retryable core worker grpc methods.";
});
Connect(rpc_client_);
}
