From 3e6c0b48b66a71f4714ea58369c8f68ea2758ffa Mon Sep 17 00:00:00 2001 From: Javier Date: Thu, 12 Dec 2024 09:51:00 +0000 Subject: [PATCH] refactor(framework) Move gRPC stub wrapping with `RetryInvoker` to `flwr.common` (#4636) --- src/py/flwr/common/retry_invoker.py | 67 ++++++++++++++++++++++++ src/py/flwr/server/driver/grpc_driver.py | 65 ++--------------------- 2 files changed, 71 insertions(+), 61 deletions(-) diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index 9785b0fbd9b4..b942bb86a0ff 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -20,8 +20,16 @@ import time from collections.abc import Generator, Iterable from dataclasses import dataclass +from logging import INFO, WARN from typing import Any, Callable, Optional, Union, cast +import grpc + +from flwr.common.constant import MAX_RETRY_DELAY +from flwr.common.logger import log +from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub +from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub + def exponential( base_delay: float = 1, @@ -303,3 +311,62 @@ def giveup_check(_exception: Exception) -> bool: # Trigger success event try_call_event_handler(self.on_success) return ret + + +def _make_simple_grpc_retry_invoker() -> RetryInvoker: + """Create a simple gRPC retry invoker.""" + + def _on_sucess(retry_state: RetryState) -> None: + if retry_state.tries > 1: + log( + INFO, + "Connection successful after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + + def _on_backoff(retry_state: RetryState) -> None: + if retry_state.tries == 1: + log(WARN, "Connection attempt failed, retrying...") + else: + log( + WARN, + "Connection attempt failed, retrying in %.2f seconds", + retry_state.actual_wait, + ) + + def _on_giveup(retry_state: RetryState) -> None: + if retry_state.tries > 1: + log( + WARN, + "Giving up reconnection after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + + return RetryInvoker( + wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY), + recoverable_exceptions=grpc.RpcError, + max_tries=None, + max_time=None, + on_success=_on_sucess, + on_backoff=_on_backoff, + on_giveup=_on_giveup, + should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore + ) + + +def _wrap_stub( + stub: Union[ServerAppIoStub, ClientAppIoStub], retry_invoker: RetryInvoker +) -> None: + """Wrap a gRPC stub with a retry invoker.""" + + def make_lambda(original_method: Any) -> Any: + return lambda *args, **kwargs: retry_invoker.invoke( + original_method, *args, **kwargs + ) + + for method_name in vars(stub): + method = getattr(stub, method_name) + if callable(method): + setattr(stub, method_name, make_lambda(method)) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 09318c32b704..f844aee57191 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -17,16 +17,16 @@ import time import warnings from collections.abc import Iterable -from logging import DEBUG, INFO, WARN, WARNING -from typing import Any, Optional, cast +from logging import DEBUG, WARNING +from typing import Optional, cast import grpc from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet -from flwr.common.constant import MAX_RETRY_DELAY, SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS +from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential +from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto from flwr.common.typing import Run from flwr.proto.node_pb2 import Node # pylint: disable=E0611 @@ -262,60 +262,3 @@ def close(self) -> None: return # Disconnect self._disconnect() - - -def _make_simple_grpc_retry_invoker() -> RetryInvoker: - """Create a simple gRPC retry invoker.""" - - def _on_sucess(retry_state: RetryState) -> None: - if retry_state.tries > 1: - log( - INFO, - "Connection successful after %.2f seconds and %s tries.", - retry_state.elapsed_time, - retry_state.tries, - ) - - def _on_backoff(retry_state: RetryState) -> None: - if retry_state.tries == 1: - log(WARN, "Connection attempt failed, retrying...") - else: - log( - WARN, - "Connection attempt failed, retrying in %.2f seconds", - retry_state.actual_wait, - ) - - def _on_giveup(retry_state: RetryState) -> None: - if retry_state.tries > 1: - log( - WARN, - "Giving up reconnection after %.2f seconds and %s tries.", - retry_state.elapsed_time, - retry_state.tries, - ) - - return RetryInvoker( - wait_gen_factory=lambda: exponential(max_delay=MAX_RETRY_DELAY), - recoverable_exceptions=grpc.RpcError, - max_tries=None, - max_time=None, - on_success=_on_sucess, - on_backoff=_on_backoff, - on_giveup=_on_giveup, - should_giveup=lambda e: e.code() != grpc.StatusCode.UNAVAILABLE, # type: ignore - ) - - -def _wrap_stub(stub: ServerAppIoStub, retry_invoker: RetryInvoker) -> None: - """Wrap the gRPC stub with a retry invoker.""" - - def make_lambda(original_method: Any) -> Any: - return lambda *args, **kwargs: retry_invoker.invoke( - original_method, *args, **kwargs - ) - - for method_name in vars(stub): - method = getattr(stub, method_name) - if callable(method): - setattr(stub, method_name, make_lambda(method))