diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index a7b512e5d1a9..a60fff57e7bf 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -107,7 +107,7 @@ class RetryInvoker: Parameters ---------- - wait_strategy: Generator[float, None, None] + wait_factory: Callable[[], Generator[float, None, None]] A generator yielding successive wait times in seconds. If the generator is finite, the giveup event will be triggered when the generator raises `StopIteration`. @@ -129,11 +129,11 @@ class RetryInvoker: data class object detailing the invocation. on_giveup: Optional[Callable[[RetryState], None]] (default: None) A callable to be executed in the event that `max_tries` or `max_time` is - exceeded, `should_giveup` returns True, or `wait_strategy` generator raises + exceeded, `should_giveup` returns True, or `wait_factory()` generator raises `StopInteration`. The parameter is a data class object detailing the invocation. jitter: Optional[Callable[[float], float]] (default: full_jitter) - A function of the value yielded by `wait_strategy` returning the actual time + A function of the value yielded by `wait_factory()` returning the actual time to wait. This function helps distribute wait times stochastically to avoid timing collisions across concurrent clients. Wait times are jittered by default using the `full_jitter` function. To disable jittering, pass @@ -145,20 +145,20 @@ class RetryInvoker: Examples -------- - Initialize a `RetryInvoker` with exponential backoff and call a function: + Initialize a `RetryInvoker` with exponential backoff and invoke a function: >>> invoker = RetryInvoker( - >>> exponential(), - >>> grpc.RpcError, - >>> max_tries=3, - >>> max_time=None, - >>> ) + ... exponential, # Or use `lambda: exponential(3, 2)` to pass arguments + ... grpc.RpcError, + ... max_tries=3, + ... max_time=None, + ... ) >>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1) """ def __init__( self, - wait_strategy: Generator[float, None, None], + wait_factory: Callable[[], Generator[float, None, None]], recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]], max_tries: Optional[int], max_time: Optional[float], @@ -169,7 +169,7 @@ def __init__( jitter: Optional[Callable[[float], float]] = full_jitter, should_giveup: Optional[Callable[[Exception], bool]] = None, ) -> None: - self.wait_strategy = wait_strategy + self.wait_factory = wait_factory self.recoverable_exceptions = recoverable_exceptions self.max_tries = max_tries self.max_time = max_time @@ -183,8 +183,8 @@ def __init__( def invoke( self, target: Callable[..., Any], - *args: Tuple[Any, ...], - **kwargs: Dict[str, Any], + *args: Any, + **kwargs: Any, ) -> Any: """Safely invoke the provided callable with retry mechanisms. @@ -212,12 +212,12 @@ def invoke( ------ Exception If the number of tries exceeds `max_tries`, if the total time - exceeds `max_time`, if `wait_strategy` generator raises `StopInteration`, + exceeds `max_time`, if `wait_factory()` generator raises `StopInteration`, or if the `should_giveup` returns True for a raised exception. Notes ----- - The time between retries is determined by the provided `wait_strategy` + The time between retries is determined by the provided `wait_factory()` generator and can optionally be jittered using the `jitter` function. The recoverable exceptions that trigger a retry, as well as conditions to stop retries, are also determined by the class's initialization parameters. @@ -230,6 +230,7 @@ def try_call_event_handler( handler(cast(RetryState, ref_state[0])) try_cnt = 0 + wait_generator = self.wait_factory() start = time.time() ref_state: List[Optional[RetryState]] = [None] @@ -265,7 +266,7 @@ def giveup_check(_exception: Exception) -> bool: raise try: - wait_time = next(self.wait_strategy) + wait_time = next(wait_generator) if self.jitter is not None: wait_time = self.jitter(wait_time) if self.max_time is not None: diff --git a/src/py/flwr/common/retry_invoker_test.py b/src/py/flwr/common/retry_invoker_test.py index 5f6dab49ce1c..e67c0641e2ba 100644 --- a/src/py/flwr/common/retry_invoker_test.py +++ b/src/py/flwr/common/retry_invoker_test.py @@ -54,7 +54,7 @@ def test_successful_invocation() -> None: backoff_handler = Mock() giveup_handler = Mock() invoker = RetryInvoker( - constant(0.1), + lambda: constant(0.1), ValueError, max_tries=None, max_time=None, @@ -77,7 +77,7 @@ def test_failure() -> None: """Check termination when unexpected exception is raised.""" # Prepare # `constant([0.1])` generator will raise `StopIteration` after one iteration. - invoker = RetryInvoker(constant(0.1), TypeError, None, None) + invoker = RetryInvoker(lambda: constant(0.1), TypeError, None, None) # Execute and Assert with pytest.raises(ValueError): @@ -88,7 +88,11 @@ def test_failure_two_exceptions(mock_sleep: MagicMock) -> None: """Verify one retry on a specified iterable of exceptions.""" # Prepare invoker = RetryInvoker( - constant(0.1), (TypeError, ValueError), max_tries=2, max_time=None, jitter=None + lambda: constant(0.1), + (TypeError, ValueError), + max_tries=2, + max_time=None, + jitter=None, ) # Execute and Assert @@ -101,7 +105,7 @@ def test_backoff_on_failure(mock_sleep: MagicMock) -> None: """Verify one retry on specified exception.""" # Prepare # `constant([0.1])` generator will raise `StopIteration` after one iteration. - invoker = RetryInvoker(constant([0.1]), ValueError, None, None, jitter=None) + invoker = RetryInvoker(lambda: constant([0.1]), ValueError, None, None, jitter=None) # Execute and Assert with pytest.raises(ValueError): @@ -114,7 +118,7 @@ def test_max_tries(mock_sleep: MagicMock) -> None: # Prepare # Disable `jitter` to ensure 0.1s wait time. invoker = RetryInvoker( - constant(0.1), ValueError, max_tries=2, max_time=None, jitter=None + lambda: constant(0.1), ValueError, max_tries=2, max_time=None, jitter=None ) # Execute and Assert @@ -132,7 +136,9 @@ def test_max_time(mock_time: MagicMock, mock_sleep: MagicMock) -> None: 0.0, 3.0, ] - invoker = RetryInvoker(constant(2), ValueError, max_tries=None, max_time=2.5) + invoker = RetryInvoker( + lambda: constant(2), ValueError, max_tries=None, max_time=2.5 + ) # Execute and Assert with pytest.raises(ValueError): @@ -148,7 +154,7 @@ def test_event_handlers() -> None: backoff_handler = Mock() giveup_handler = Mock() invoker = RetryInvoker( - constant(0.1), + lambda: constant(0.1), ValueError, max_tries=2, max_time=None, @@ -173,7 +179,7 @@ def should_give_up(exc: Exception) -> bool: return isinstance(exc, ValueError) invoker = RetryInvoker( - constant(0.1), ValueError, None, None, should_giveup=should_give_up + lambda: constant(0.1), ValueError, None, None, should_giveup=should_give_up ) # Execute and Assert