diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py index 920672382..f0112cffa 100644 --- a/src/otaclient_common/retry_task_map.py +++ b/src/otaclient_common/retry_task_map.py @@ -67,7 +67,7 @@ def __init__( self._retry_counter = itertools.count(start=1) self._retry_count = 0 self._concurrent_semaphore = threading.Semaphore(max_concurrent) - self._fut_queue: SimpleQueue[Future[Any] | None] = SimpleQueue() + self._fut_queue: SimpleQueue[Future[Any]] = SimpleQueue() # NOTE: leave two threads each for watchdog and dispatcher max_workers = ( @@ -77,35 +77,31 @@ def __init__( def _watchdog() -> None: """Watchdog watches exceeding of max_retry and max no_progress_timeout.""" - try: - while not self._shutdown: - if ( - self.max_total_retry - and self._retry_count > self.max_total_retry - ): - logger.warning(f"exceed {self.max_total_retry=}, abort") + while not self._shutdown: + if self.max_total_retry and self._retry_count > self.max_total_retry: + logger.warning(f"exceed {self.max_total_retry=}, abort") + return self.shutdown(wait=True) + + if callable(watchdog_func): + try: + watchdog_func() + except Exception as e: + logger.warning(f"custom watchdog func failed: {e!r}, abort") return self.shutdown(wait=True) - - if callable(watchdog_func): - try: - watchdog_func() - except Exception as e: - logger.warning(f"custom watchdog func failed: {e!r}, abort") - return self.shutdown(wait=True) - time.sleep(self.WATCH_DOG_CHECK_INTERVAL) - finally: - self._fut_queue.put_nowait(None) # always wakeup ensure_task + time.sleep(self.WATCH_DOG_CHECK_INTERVAL) self.submit(_watchdog) def _task_done_cb( self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any] ) -> None: + # ------ on task succeeded ------ # if not fut.exception(): self._concurrent_semaphore.release() self._finished_task = next(self._finished_task_counter) return + # ------ on threadpool shutdown(by watchdog) ------ # if self._shutdown or self._broken: # wakeup dispatcher self._concurrent_semaphore.release() @@ -113,7 +109,6 @@ def _task_done_cb( # ------ on task failed ------ # self._retry_count = next(self._retry_counter) - # NOTE: not return the new fut! self._fut_queue.put_nowait(fut) with contextlib.suppress(Exception): @@ -161,14 +156,13 @@ def _dispatcher(): # ------ ensure all tasks are finished ------ # while self._total_task_num == 0 or self._finished_task != self._total_task_num: try: - _fut = self._fut_queue.get_nowait() + yield self._fut_queue.get_nowait() except Empty: time.sleep(self.ENSURE_TASKS_PULL_INTERVAL) continue - if self._shutdown or self._broken or _fut is None: + if self._shutdown or self._broken: logger.warning( f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}" ) raise TasksEnsureFailed - yield _fut