From 19c4c5c0d393b0f5b2fec0811fdada48cc6b4047 Mon Sep 17 00:00:00 2001
From: Bodong Yang <86948717+Bodong-Yang@users.noreply.github.com>
Date: Tue, 3 Dec 2024 12:14:12 +0900
Subject: [PATCH] feat(retry_task_map): introduce worker-thread-scope backoff
 wait on failed tasks rescheduling. (#450)

This PR introduces a worker-thread-scope backoff wait on failed task
rescheduling for otaclient_common.retry_task_map.

Each worker thread maintains a continues_failed_count. When rescheduling a
failed task, the worker now waits with backoff before putting the task back
into the task queue. This significantly lowers CPU usage when the network is
fully disconnected.
---
 src/otaclient_common/retry_task_map.py | 77 ++++++++++++++++---
 .../test_retry_task_map.py             | 12 ++-
 2 files changed, 77 insertions(+), 12 deletions(-)

diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py
index 414076476..5d71b6623 100644
--- a/src/otaclient_common/retry_task_map.py
+++ b/src/otaclient_common/retry_task_map.py
@@ -26,8 +26,13 @@
 from queue import Empty, SimpleQueue
 from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Optional
 
+from typing_extensions import ParamSpec
+
+from otaclient_common.common import wait_with_backoff
 from otaclient_common.typing import RT, T
 
+P = ParamSpec("P")
+
 logger = logging.getLogger(__name__)
 
 
@@ -35,6 +40,9 @@
 class TasksEnsureFailed(Exception):
     """Exception for tasks ensuring failed."""
 
 
+CONTINUES_FAILURE_COUNT_ATTRNAME = "continues_failed_count"
+
+
 class _ThreadPoolExecutorWithRetry(ThreadPoolExecutor):
     def __init__(
@@ -47,6 +55,8 @@ def __init__(
         watchdog_check_interval: int = 3,  # seconds
         initializer: Callable[..., Any] | None = None,
         initargs: tuple = (),
+        backoff_factor: float = 0.01,
+        backoff_max: float = 1,
     ) -> None:
         self._start_lock, self._started = threading.Lock(), False
         self._total_task_num = 0
@@ -57,6 +67,8 @@ def __init__(
                 no tasks, the task execution gen should stop immediately.
             3. only value >=1 is valid.
""" + self.backoff_factor = backoff_factor + self.backoff_max = backoff_max self._retry_counter = itertools.count(start=1) self._retry_count = 0 @@ -70,13 +82,26 @@ def __init__( if callable(watchdog_func): self._checker_funcs.append(watchdog_func) + self._thread_local = threading.local() super().__init__( max_workers=max_workers, thread_name_prefix=thread_name_prefix, - initializer=initializer, + initializer=self._rtm_initializer_gen(initializer), initargs=initargs, ) + def _rtm_initializer_gen( + self, _input_initializer: Callable[P, RT] | None + ) -> Callable[P, None]: + def _real_initializer(*args: P.args, **kwargs: P.kwargs) -> None: + _thread_local = self._thread_local + setattr(_thread_local, CONTINUES_FAILURE_COUNT_ATTRNAME, 0) + + if callable(_input_initializer): + _input_initializer(*args, **kwargs) + + return _real_initializer + def _max_retry_check(self, max_total_retry: int) -> None: if self._retry_count > max_total_retry: raise TasksEnsureFailed("exceed max retry count, abort") @@ -110,16 +135,40 @@ def _task_done_cb( return # on shutdown, no need to put done fut into fut_queue self._fut_queue.put_nowait(fut) - # ------ on task failed ------ # - if fut.exception(): - self._retry_count = next(self._retry_counter) - try: # try to re-schedule the failed task - self.submit(func, item).add_done_callback( - partial(self._task_done_cb, item=item, func=func) - ) - except Exception: # if re-schedule doesn't happen, release se - self._concurrent_semaphore.release() - else: # release semaphore when succeeded + _thread_local = self._thread_local + + # release semaphore only on success + # reset continues failure count on success + if not fut.exception(): + self._concurrent_semaphore.release() + _thread_local.continues_failed_count = 0 + return + + # NOTE: when for some reason the continues_failed_count is gone, + # handle the AttributeError here and re-assign the count. + try: + _continues_failed_count = getattr( + _thread_local, CONTINUES_FAILURE_COUNT_ATTRNAME + ) + except AttributeError: + _continues_failed_count = 0 + + _continues_failed_count += 1 + setattr( + _thread_local, CONTINUES_FAILURE_COUNT_ATTRNAME, _continues_failed_count + ) + wait_with_backoff( + _continues_failed_count, + _backoff_factor=self.backoff_factor, + _backoff_max=self.backoff_max, + ) + + self._retry_count = next(self._retry_counter) + try: # try to re-schedule the failed task + self.submit(func, item).add_done_callback( + partial(self._task_done_cb, item=item, func=func) + ) + except Exception: # if re-schedule doesn't happen, release se self._concurrent_semaphore.release() def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]: @@ -221,6 +270,8 @@ def __init__( watchdog_check_interval: int = 3, # seconds initializer: Callable[..., Any] | None = None, initargs: tuple = (), + backoff_factor: float = 0.01, + backoff_max: float = 1, ) -> None: """Initialize a ThreadPoolExecutorWithRetry instance. @@ -236,6 +287,10 @@ def __init__( Defaults to None. initargs (tuple): The same param passed through to ThreadPoolExecutor. Defaults to (). + backoff_factor (float): The backoff factor on thread-scope backoff wait for failed tasks rescheduling. + Defaults to 0.01. + backoff_max (float): The backoff max on thread-scope backoff wait for failed tasks rescheduling. + Defaults to 1. 
""" raise NotImplementedError diff --git a/tests/test_otaclient_common/test_retry_task_map.py b/tests/test_otaclient_common/test_retry_task_map.py index c9c1aa24d..ca65e2304 100644 --- a/tests/test_otaclient_common/test_retry_task_map.py +++ b/tests/test_otaclient_common/test_retry_task_map.py @@ -32,6 +32,8 @@ MAX_CONCURRENT = 128 MAX_WAIT_BEFORE_SUCCESS = 10 THREAD_INIT_MSG = "thread init message" +BACKOFF_FACTOR = 0.001 # for faster test +BACKOFF_MAX = 0.1 class _RetryTaskMapTestErr(Exception): @@ -47,7 +49,7 @@ def _thread_initializer(msg: str) -> None: class TestRetryTaskMap: @pytest.fixture(autouse=True) - def setup(self): + def setup(self) -> None: self._start_time = time.time() self._success_wait_dict = { idx: random.randint(0, MAX_WAIT_BEFORE_SUCCESS) @@ -83,6 +85,8 @@ def _exit_on_exceed_max_count(): watchdog_func=_exit_on_exceed_max_count, initializer=_thread_initializer, initargs=(THREAD_INIT_MSG,), + backoff_factor=BACKOFF_FACTOR, + backoff_max=BACKOFF_MAX, ) as executor: with pytest.raises(retry_task_map.TasksEnsureFailed): for _fut in executor.ensure_tasks( @@ -99,6 +103,8 @@ def test_retry_exceed_retry_limit(self): max_total_retry=MAX_TOTAL_RETRY, initializer=_thread_initializer, initargs=(THREAD_INIT_MSG,), + backoff_factor=BACKOFF_FACTOR, + backoff_max=BACKOFF_MAX, ) as executor: with pytest.raises(retry_task_map.TasksEnsureFailed): for _fut in executor.ensure_tasks( @@ -115,6 +121,8 @@ def test_retry_finally_succeeded(self): max_concurrent=MAX_CONCURRENT, initializer=_thread_initializer, initargs=(THREAD_INIT_MSG,), + backoff_factor=BACKOFF_FACTOR, + backoff_max=BACKOFF_MAX, ) as executor: for _fut in executor.ensure_tasks( self.workload_failed_and_then_succeed, range(TASKS_COUNT) @@ -130,6 +138,8 @@ def test_succeeded_in_one_try(self): max_concurrent=MAX_CONCURRENT, initializer=_thread_initializer, initargs=(THREAD_INIT_MSG,), + backoff_factor=BACKOFF_FACTOR, + backoff_max=BACKOFF_MAX, ) as executor: for _fut in executor.ensure_tasks( self.workload_succeed, range(TASKS_COUNT)