From 1553e46d5616914190e751ae48309158f3717301 Mon Sep 17 00:00:00 2001 From: Bodong Yang Date: Wed, 5 Jun 2024 14:05:30 +0000 Subject: [PATCH] dispatcher: shutdown the threadpool on exception --- src/otaclient_common/retry_task_map.py | 44 ++++++++++++++++---------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py index f0112cffa..47dccb2c1 100644 --- a/src/otaclient_common/retry_task_map.py +++ b/src/otaclient_common/retry_task_map.py @@ -32,14 +32,11 @@ class TasksEnsureFailed(Exception): - pass + """Exception for tasks ensuring failed.""" class ThreadPoolExecutorWithRetry(ThreadPoolExecutor): - WATCH_DOG_CHECK_INTERVAL = 3 - ENSURE_TASKS_PULL_INTERVAL = 1 - def __init__( self, max_concurrent: int, @@ -47,18 +44,23 @@ def __init__( max_total_retry: Optional[int] = None, thread_name_prefix: str = "", watchdog_func: Optional[Callable] = None, + watchdog_check_interval: int = 3, # seconds + ensure_tasks_pull_interval: int = 1, # second ) -> None: """Initialize a ThreadPoolExecutorWithRetry instance. Args: - max_concurrent (int, optional): How many tasks should be kept in the memory. + max_concurrent (int): Limit the number pending scheduled tasks. max_workers (Optional[int], optional): Max number of worker threads in the pool. Defaults to None. max_total_retry (Optional[int], optional): Max total retry counts before abort. Defaults to None. thread_name_prefix (str, optional): Defaults to "". - watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, break threadpool when - this func raises exception. Defaults to None. + watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when + this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None. + watchdog_check_interval (int): Defaults to 3(seconds). + ensure_tasks_pull_interval (int): Defaults to 1(second). """ self.max_total_retry = max_total_retry + self.ensure_tasks_pull_interval = ensure_tasks_pull_interval self._start_lock, self._started = threading.Lock(), False self._total_task_num = 0 @@ -88,7 +90,7 @@ def _watchdog() -> None: except Exception as e: logger.warning(f"custom watchdog func failed: {e!r}, abort") return self.shutdown(wait=True) - time.sleep(self.WATCH_DOG_CHECK_INTERVAL) + time.sleep(watchdog_check_interval) self.submit(_watchdog) @@ -111,7 +113,7 @@ def _task_done_cb( self._retry_count = next(self._retry_counter) self._fut_queue.put_nowait(fut) - with contextlib.suppress(Exception): + with contextlib.suppress(Exception): # on threadpool shutdown self.submit(func, item).add_done_callback( partial(self._task_done_cb, item=item, func=func) ) @@ -141,15 +143,23 @@ def ensure_tasks( self._started = True # ------ dispatch tasks from iterable ------ # - def _dispatcher(): + def _dispatcher() -> None: _fut_queue = self._fut_queue - for _tasks_count, item in enumerate(iterable, start=1): - self._concurrent_semaphore.acquire() - fut = self.submit(func, item) - fut.add_done_callback(partial(self._task_done_cb, item=item, func=func)) - _fut_queue.put_nowait(fut) + try: + for _tasks_count, item in enumerate(iterable, start=1): + self._concurrent_semaphore.acquire() + fut = self.submit(func, item) + fut.add_done_callback( + partial(self._task_done_cb, item=item, func=func) + ) + _fut_queue.put_nowait(fut) + except Exception as e: + logger.error(f"tasks dispatcher failed: {e!r}, abort") + self.shutdown(wait=True) + return + self._total_task_num = _tasks_count - logger.info(f"finish dispatch {_tasks_count} of tasks") + logger.info(f"finish dispatch {_tasks_count} tasks") self.submit(_dispatcher) @@ -158,7 +168,7 @@ def _dispatcher(): try: yield self._fut_queue.get_nowait() except Empty: - time.sleep(self.ENSURE_TASKS_PULL_INTERVAL) + time.sleep(self.ensure_tasks_pull_interval) continue if self._shutdown or self._broken: