diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py index 221fc9521..5ce569fdc 100644 --- a/src/otaclient_common/retry_task_map.py +++ b/src/otaclient_common/retry_task_map.py @@ -65,8 +65,6 @@ def __init__( """ self._start_lock, self._started = threading.Lock(), False self._total_task_num = 0 - self._finished_task_counter = itertools.count(start=1) - self._finished_task = 0 self._retry_counter = itertools.count(start=1) self._retry_count = 0 self._concurrent_semaphore = threading.Semaphore(max_concurrent) @@ -112,40 +110,33 @@ def _task_done_cb( self._concurrent_semaphore.release() # always release se first self._fut_queue.put_nowait(fut) - # ------ on task succeeded ------ # - if not fut.exception(): - self._finished_task = next(self._finished_task_counter) - return - - # ------ on threadpool shutdown(by watchdog) ------ # - if self._shutdown or self._broken: - return - # ------ on task failed ------ # - self._retry_count = next(self._retry_counter) - with contextlib.suppress(Exception): # on threadpool shutdown - self.submit(func, item).add_done_callback( - partial(self._task_done_cb, item=item, func=func) - ) + if fut.exception(): + self._retry_count = next(self._retry_counter) + with contextlib.suppress(Exception): # on threadpool shutdown + self.submit(func, item).add_done_callback( + partial(self._task_done_cb, item=item, func=func) + ) def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]: + finished_tasks = 0 while True: if self._shutdown or self._broken or concurrent_fut_thread._shutdown: logger.warning( - f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}" + f"failed to ensure all tasks, {finished_tasks=}, {self._total_task_num=}" ) raise TasksEnsureFailed try: - yield self._fut_queue.get_nowait() + done_fut = self._fut_queue.get_nowait() + if not done_fut.exception(): + finished_tasks += 1 + yield done_fut except Empty: - if ( - self._total_task_num == 0 - or self._finished_task != self._total_task_num - ): + if self._total_task_num == 0 or finished_tasks != self._total_task_num: time.sleep(interval) continue - return + return # all tasks finished and futs are yielded def ensure_tasks( self,