diff --git a/eth_retry/eth_retry.py b/eth_retry/eth_retry.py index 7c5b512..0b80df4 100644 --- a/eth_retry/eth_retry.py +++ b/eth_retry/eth_retry.py @@ -7,7 +7,7 @@ from json import JSONDecodeError from random import randrange from time import sleep -from typing import Awaitable, Callable, Optional, TypeVar, Union, overload +from typing import Callable, Optional, TypeVar, Union, overload import requests @@ -28,16 +28,9 @@ T = TypeVar("T") P = ParamSpec("P") -Function = Callable[P, T] -CoroutineFunction = Function[P, Awaitable[T]] -Decoratee = Union[Function[P, T], CoroutineFunction[P, T]] -@overload -def auto_retry(func: CoroutineFunction[P, T]) -> CoroutineFunction[P, T]: ... - - -def auto_retry(func: Function[P, T]) -> Function[P, T]: +def auto_retry(func: Callable[P, T]) -> Callable[P, T]: """ Decorator that will retry the function on: - ConnectionError @@ -55,59 +48,64 @@ def auto_retry(func: Function[P, T]) -> Function[P, T]: On repeat errors, will retry in increasing intervals. """ - @functools.wraps(func) - def auto_retry_wrap(*args: P.args, **kwargs: P.kwargs) -> T: - sleep_time = randrange(ENVS.MIN_SLEEP_TIME, ENVS.MAX_SLEEP_TIME) - failures = 0 - while True: - # Attempt to execute `func` and return response - try: - return func(*args, **kwargs) # type: ignore - except Exception as e: - if not should_retry(e, failures): - raise - if failures > ENVS.ETH_RETRY_SUPPRESS_LOGS: - logger.warning(f"{str(e)} [{failures}]") - if ENVS.ETH_RETRY_DEBUG: - logger.exception(e) - - # Attempt failed, sleep time. - failures += 1 - if ENVS.ETH_RETRY_DEBUG: - logger.info(f"sleeping {round(failures * sleep_time, 2)} seconds.") - sleep(failures * sleep_time) - - @functools.wraps(func) - async def auto_retry_wrap_async(*args: P.args, **kwargs: P.kwargs) -> T: - sleep_time = randrange(ENVS.MIN_SLEEP_TIME, ENVS.MAX_SLEEP_TIME) - failures = 0 - while True: - try: - return await func(*args, **kwargs) # type: ignore - except asyncio.exceptions.TimeoutError as e: - logger.warning( - f"asyncio timeout [{failures}] {_get_caller_details_from_stack()}" - ) + + if asyncio.iscoroutinefunction(func): + + @functools.wraps(func) + async def auto_retry_wrap_async(*args: P.args, **kwargs: P.kwargs) -> T: + sleep_time = randrange(ENVS.MIN_SLEEP_TIME, ENVS.MAX_SLEEP_TIME) + failures = 0 + while True: + try: + return await func(*args, **kwargs) # type: ignore + except asyncio.exceptions.TimeoutError as e: + logger.warning( + f"asyncio timeout [{failures}] {_get_caller_details_from_stack()}" + ) + if ENVS.ETH_RETRY_DEBUG: + logger.exception(e) + continue + except Exception as e: + if not should_retry(e, failures): + raise + if failures > ENVS.ETH_RETRY_SUPPRESS_LOGS: + logger.warning(f"{str(e)} [{failures}]") + if ENVS.ETH_RETRY_DEBUG: + logger.exception(e) + + # Attempt failed, sleep time. + failures += 1 if ENVS.ETH_RETRY_DEBUG: - logger.exception(e) - continue - except Exception as e: - if not should_retry(e, failures): - raise - if failures > ENVS.ETH_RETRY_SUPPRESS_LOGS: - logger.warning(f"{str(e)} [{failures}]") + logger.info(f"sleeping {round(failures * sleep_time, 2)} seconds.") + await asyncio.sleep(failures * sleep_time) + + return auto_retry_wrap_async # type: ignore [return-value] + + else: + + @functools.wraps(func) + def auto_retry_wrap(*args: P.args, **kwargs: P.kwargs) -> T: + sleep_time = randrange(ENVS.MIN_SLEEP_TIME, ENVS.MAX_SLEEP_TIME) + failures = 0 + while True: + # Attempt to execute `func` and return response + try: + return func(*args, **kwargs) + except Exception as e: + if not should_retry(e, failures): + raise + if failures > ENVS.ETH_RETRY_SUPPRESS_LOGS: + logger.warning(f"{str(e)} [{failures}]") + if ENVS.ETH_RETRY_DEBUG: + logger.exception(e) + + # Attempt failed, sleep time. + failures += 1 if ENVS.ETH_RETRY_DEBUG: - logger.exception(e) - - # Attempt failed, sleep time. - failures += 1 - if ENVS.ETH_RETRY_DEBUG: - logger.info(f"sleeping {round(failures * sleep_time, 2)} seconds.") - await asyncio.sleep(failures * sleep_time) - - if asyncio.iscoroutinefunction(func): - return auto_retry_wrap_async - return auto_retry_wrap + logger.info(f"sleeping {round(failures * sleep_time, 2)} seconds.") + sleep(failures * sleep_time) + + return auto_retry_wrap def should_retry(e: Exception, failures: int) -> bool: