From 32bae00cd20fff420fddbf91db75f8f851c790d0 Mon Sep 17 00:00:00 2001 From: "bodong.yang" Date: Sun, 23 Jun 2024 04:37:32 +0000 Subject: [PATCH] rtm: watchdog becomes a method, fut gen becomes a method --- src/otaclient_common/retry_task_map.py | 97 +++++++++++++++----------- 1 file changed, 56 insertions(+), 41 deletions(-) diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py index 3d413e6ec..39f546f4d 100644 --- a/src/otaclient_common/retry_task_map.py +++ b/src/otaclient_common/retry_task_map.py @@ -45,7 +45,6 @@ def __init__( thread_name_prefix: str = "", watchdog_func: Optional[Callable] = None, watchdog_check_interval: int = 3, # seconds - ensure_tasks_pull_interval: int = 1, # second initializer: Callable[..., Any] | None = None, initargs: tuple = (), ) -> None: @@ -59,11 +58,11 @@ def __init__( watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None. watchdog_check_interval (int): Defaults to 3(seconds). - ensure_tasks_pull_interval (int): Defaults to 1(second). + initializer (Callable[..., Any] | None): The same param passed through to ThreadPoolExecutor. + Defaults to None. + initargs (tuple): The same param passed through to ThreadPoolExecutor. + Defaults to (). """ - self.max_total_retry = max_total_retry - self.ensure_tasks_pull_interval = ensure_tasks_pull_interval - self._start_lock, self._started = threading.Lock(), False self._total_task_num = 0 self._finished_task_counter = itertools.count(start=1) @@ -80,23 +79,32 @@ def __init__( initargs=initargs, ) - def _watchdog() -> None: - """Watchdog will shutdown the threadpool on certain conditions being met.""" - while not self._shutdown and not concurrent_fut_thread._shutdown: - if self.max_total_retry and self._retry_count > self.max_total_retry: - logger.warning(f"exceed {self.max_total_retry=}, abort") - return self.shutdown(wait=True) + if max_total_retry or callable(watchdog_func): + threading.Thread( + target=self._watchdog, + args=(max_total_retry, watchdog_func, watchdog_check_interval), + daemon=True, + ).start() - if callable(watchdog_func): - try: - watchdog_func() - except Exception as e: - logger.warning(f"custom watchdog func failed: {e!r}, abort") - return self.shutdown(wait=True) - time.sleep(watchdog_check_interval) + def _watchdog( + self, + max_retry: int | None, + watchdog_func: Callable[..., Any] | None, + interval: int, + ) -> None: + """Watchdog will shutdown the threadpool on certain conditions being met.""" + while not self._shutdown and not concurrent_fut_thread._shutdown: + if max_retry and self._retry_count > max_retry: + logger.warning(f"exceed {max_retry=}, abort") + return self.shutdown(wait=True) - if self.max_total_retry or callable(watchdog_func): - threading.Thread(target=_watchdog, daemon=True).start() + if callable(watchdog_func): + try: + watchdog_func() + except Exception as e: + logger.warning(f"custom watchdog func failed: {e!r}, abort") + return self.shutdown(wait=True) + time.sleep(interval) def _task_done_cb( self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any] @@ -120,8 +128,31 @@ def _task_done_cb( partial(self._task_done_cb, item=item, func=func) ) + def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]: + while True: + if self._shutdown or self._broken or concurrent_fut_thread._shutdown: + logger.warning( + f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}" + ) + raise TasksEnsureFailed + + try: + yield self._fut_queue.get_nowait() + except Empty: + if ( + self._total_task_num == 0 + or self._finished_task != self._total_task_num + ): + time.sleep(interval) + continue + return + def ensure_tasks( - self, func: Callable[[T], RT], iterable: Iterable[T] + self, + func: Callable[[T], RT], + iterable: Iterable[T], + *, + ensure_tasks_pull_interval: int = 1, ) -> Generator[Future[RT], None, None]: """Ensure all the items in are processed by in the pool. @@ -165,23 +196,7 @@ def _dispatcher() -> None: threading.Thread(target=_dispatcher, daemon=True).start() # ------ ensure all tasks are finished ------ # - while True: - if self._shutdown or self._broken or concurrent_fut_thread._shutdown: - logger.warning( - f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}" - ) - try: - raise TasksEnsureFailed - finally: - del self, func, iterable - - try: - yield self._fut_queue.get_nowait() - except Empty: - if ( - self._total_task_num == 0 - or self._finished_task != self._total_task_num - ): - time.sleep(self.ensure_tasks_pull_interval) - continue - return + # NOTE: also see base.Executor.map method, let the yield hidden in + # a generator so that the first fut will be dispatched before + # we start to get from fut_queue. + return self._fut_gen(ensure_tasks_pull_interval)