diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py
index 1daf5294e..39f546f4d 100644
--- a/src/otaclient_common/retry_task_map.py
+++ b/src/otaclient_common/retry_task_map.py
@@ -15,7 +15,7 @@

 from __future__ import annotations

-import atexit
+import concurrent.futures.thread as concurrent_fut_thread
 import contextlib
 import itertools
 import logging
@@ -30,16 +30,6 @@

 logger = logging.getLogger(__name__)

-_retry_task_map_global_shutdown = False
-
-
-def _python_exit():
-    global _retry_task_map_global_shutdown
-    _retry_task_map_global_shutdown = True
-
-
-atexit.register(_python_exit)
-

 class TasksEnsureFailed(Exception):
     """Exception for tasks ensuring failed."""
@@ -55,7 +45,6 @@ def __init__(
         thread_name_prefix: str = "",
         watchdog_func: Optional[Callable] = None,
         watchdog_check_interval: int = 3,  # seconds
-        ensure_tasks_pull_interval: int = 1,  # second
         initializer: Callable[..., Any] | None = None,
         initargs: tuple = (),
     ) -> None:
@@ -69,11 +58,11 @@ def __init__(
             watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread,
                 when this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None.
             watchdog_check_interval (int): Defaults to 3(seconds).
-            ensure_tasks_pull_interval (int): Defaults to 1(second).
+            initializer (Callable[..., Any] | None): The same param passed through to ThreadPoolExecutor.
+                Defaults to None.
+            initargs (tuple): The same param passed through to ThreadPoolExecutor.
+                Defaults to ().
         """
-        self.max_total_retry = max_total_retry
-        self.ensure_tasks_pull_interval = ensure_tasks_pull_interval
-
         self._start_lock, self._started = threading.Lock(), False
         self._total_task_num = 0
         self._finished_task_counter = itertools.count(start=1)
@@ -90,22 +79,23 @@ def __init__(
             initargs=initargs,
         )

-        threading.Thread(
-            target=self._watchdog,
-            args=(watchdog_func, watchdog_check_interval),
-            daemon=True,
-        ).start()
+        if max_total_retry or callable(watchdog_func):
+            threading.Thread(
+                target=self._watchdog,
+                args=(max_total_retry, watchdog_func, watchdog_check_interval),
+                daemon=True,
+            ).start()

     def _watchdog(
-        self, watchdog_func: Callable[..., Any], watchdog_check_interval: int
+        self,
+        max_retry: int | None,
+        watchdog_func: Callable[..., Any] | None,
+        interval: int,
     ) -> None:
         """Watchdog will shutdown the threadpool on certain conditions being met."""
-        if not self.max_total_retry and not callable(watchdog_func):
-            return  # no need to run watchdog thread if not checks are performed
-
-        while not self._shutdown and not _retry_task_map_global_shutdown:
-            if self.max_total_retry and self._retry_count > self.max_total_retry:
-                logger.warning(f"exceed {self.max_total_retry=}, abort")
+        while not self._shutdown and not concurrent_fut_thread._shutdown:
+            if max_retry and self._retry_count > max_retry:
+                logger.warning(f"exceed {max_retry=}, abort")
                 return self.shutdown(wait=True)
             if callable(watchdog_func):
@@ -114,13 +104,13 @@
             except Exception as e:
                 logger.warning(f"custom watchdog func failed: {e!r}, abort")
                 return self.shutdown(wait=True)
-            time.sleep(watchdog_check_interval)
+            time.sleep(interval)

     def _task_done_cb(
         self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any]
     ) -> None:
-        self._concurrent_semaphore.release()  # NOTE: always release se first
         self._fut_queue.put_nowait(fut)
+        self._concurrent_semaphore.release()

         # ------ on task succeeded ------ #
         if not fut.exception():
@@ -128,7 +118,7 @@
             return

         # ------ on threadpool shutdown(by watchdog) ------ #
-        if self._shutdown or _retry_task_map_global_shutdown:
+        if self._shutdown or self._broken:
             return

         # ------ on task failed ------ #
@@ -138,23 +128,31 @@
             partial(self._task_done_cb, item=item, func=func)
         )

-    def _dispatcher(self, func: Callable[[T], RT], iterable: Iterable[T]) -> None:
-        try:
-            for _tasks_count, item in enumerate(iterable, start=1):
-                self._concurrent_semaphore.acquire()
-                # NOTE: on pool shutdown, the submit method will throw exceptions
-                fut = self.submit(func, item)
-                fut.add_done_callback(partial(self._task_done_cb, item=item, func=func))
-        except Exception as e:
-            logger.error(f"tasks dispatcher failed: {e!r}, abort")
-            self.shutdown(wait=True)
-            return
+    def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]:
+        while True:
+            if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
+                logger.warning(
+                    f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}"
+                )
+                raise TasksEnsureFailed

-        self._total_task_num = _tasks_count
-        logger.info(f"finish dispatch {_tasks_count} tasks")
+            try:
+                yield self._fut_queue.get_nowait()
+            except Empty:
+                if (
+                    self._total_task_num == 0
+                    or self._finished_task != self._total_task_num
+                ):
+                    time.sleep(interval)
+                    continue
+                return

     def ensure_tasks(
-        self, func: Callable[[T], RT], iterable: Iterable[T]
+        self,
+        func: Callable[[T], RT],
+        iterable: Iterable[T],
+        *,
+        ensure_tasks_pull_interval: int = 1,
     ) -> Generator[Future[RT], None, None]:
         """Ensure all the items in <iterable> are processed by <func> in the pool.

@@ -171,33 +169,34 @@ def ensure_tasks(
             The Future instance of each processed tasks.
         """
         with self._start_lock:
-            if self._started or self._shutdown or _retry_task_map_global_shutdown:
+            if self._started:
                 try:
-                    raise ValueError(
-                        "pool shutdowned or ensure_tasks cannot be started more than once"
-                    )
+                    raise ValueError("ensure_tasks cannot be started more than once")
                 finally:  # do not hold refs to input params
                     del self, func, iterable
             self._started = True

         # ------ dispatch tasks from iterable ------ #
-        threading.Thread(
-            target=self._dispatcher, args=(func, iterable), daemon=True
-        ).start()
+        def _dispatcher() -> None:
+            try:
+                for _tasks_count, item in enumerate(iterable, start=1):
+                    self._concurrent_semaphore.acquire()
+                    fut = self.submit(func, item)
+                    fut.add_done_callback(
+                        partial(self._task_done_cb, item=item, func=func)
+                    )
+            except Exception as e:
+                logger.error(f"tasks dispatcher failed: {e!r}, abort")
+                self.shutdown(wait=True)
+                return

-        # ------ ensure all tasks are finished ------ #
-        while self._total_task_num == 0 or self._finished_task != self._total_task_num:
-            # shutdown by upper caller or interpreter exits
-            if self._shutdown or _retry_task_map_global_shutdown:
-                _err_msg = f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}"
-                logger.warning(_err_msg)
+            self._total_task_num = _tasks_count
+            logger.info(f"finish dispatch {_tasks_count} tasks")

-                try:
-                    raise TasksEnsureFailed(_err_msg)
-                finally:
-                    del self, func, iterable
+        threading.Thread(target=_dispatcher, daemon=True).start()

-            try:
-                yield self._fut_queue.get_nowait()
-            except Empty:
-                time.sleep(self.ensure_tasks_pull_interval)
+        # ------ ensure all tasks are finished ------ #
+        # NOTE: also see base.Executor.map method, let the yield hidden in
+        # a generator so that the first fut will be dispatched before
+        # we start to get from fut_queue.
+        return self._fut_gen(ensure_tasks_pull_interval)
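For context, a minimal usage sketch of the API after this change. The executor class and its first parameter are not visible in these hunks, so `ThreadPoolExecutorWithRetry` and `max_concurrent` are assumptions based on the rest of the module; `ensure_tasks`, `max_total_retry`, `watchdog_check_interval`, `ensure_tasks_pull_interval`, and `TasksEnsureFailed` all appear in the diff itself. Because `ensure_tasks` now returns the generator produced by `_fut_gen`, `TasksEnsureFailed` surfaces while iterating rather than at the call site, so the `for` loop is what needs guarding:

```python
import random

from otaclient_common.retry_task_map import (  # assumed class name, see note above
    TasksEnsureFailed,
    ThreadPoolExecutorWithRetry,
)


def flaky_work(item: int) -> int:
    """Hypothetical workload that fails transiently to exercise the retry path."""
    if random.random() < 0.2:
        raise ValueError(f"transient failure on {item}")
    return item * 2


pool = ThreadPoolExecutorWithRetry(
    max_concurrent=8,        # assumed ctor param: bounds in-flight tasks via the semaphore
    max_total_retry=32,      # per the diff: watchdog aborts the pool past this budget
    watchdog_check_interval=3,
)
try:
    # The pull interval moved from __init__ into this keyword argument.
    for fut in pool.ensure_tasks(
        flaky_work, range(100), ensure_tasks_pull_interval=1
    ):
        # Futures of failed attempts are also yielded; _task_done_cb has
        # already resubmitted their items, so only report the successes here.
        if fut.exception() is None:
            print("ok:", fut.result())
except TasksEnsureFailed:
    print("aborted: retry budget exhausted or pool shut down")
finally:
    pool.shutdown(wait=True)
```

Hiding the `yield` inside `_fut_gen` (per the NOTE in the diff, mirroring `base.Executor.map`) means the dispatcher thread is already running by the time the caller pulls the first future, instead of dispatch being deferred until first iteration.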