diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py
index 5e0e964de..055cb4bbf 100644
--- a/src/otaclient_common/retry_task_map.py
+++ b/src/otaclient_common/retry_task_map.py
@@ -19,7 +19,6 @@
 import contextlib
 import itertools
 import logging
-import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -46,7 +45,8 @@ def __init__(
         thread_name_prefix: str = "",
         watchdog_func: Optional[Callable] = None,
         watchdog_check_interval: int = 3,  # seconds
-        ensure_tasks_pull_interval: int = 1,  # second
+        initializer: Callable[..., Any] | None = None,
+        initargs: tuple = (),
     ) -> None:
         """Initialize a ThreadPoolExecutorWithRetry instance.
 
@@ -58,69 +58,92 @@
             watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread,
                 when this func raises exception, the watchdog will interrupt the tasks execution.
                 Defaults to None.
             watchdog_check_interval (int): Defaults to 3(seconds).
-            ensure_tasks_pull_interval (int): Defaults to 1(second).
+            initializer (Callable[..., Any] | None): The same param passed through to ThreadPoolExecutor.
+                Defaults to None.
+            initargs (tuple): The same param passed through to ThreadPoolExecutor.
+                Defaults to ().
         """
-        self.max_total_retry = max_total_retry
-        self.ensure_tasks_pull_interval = ensure_tasks_pull_interval
-
         self._start_lock, self._started = threading.Lock(), False
         self._total_task_num = 0
-        self._finished_task_counter = itertools.count(start=1)
-        self._finished_task = 0
         self._retry_counter = itertools.count(start=1)
         self._retry_count = 0
         self._concurrent_semaphore = threading.Semaphore(max_concurrent)
         self._fut_queue: SimpleQueue[Future[Any]] = SimpleQueue()
 
-        # NOTE: leave two threads each for watchdog and dispatcher
-        max_workers = (
-            max_workers + 2 if max_workers else min(32, (os.cpu_count() or 1) + 4)
+        super().__init__(
+            max_workers=max_workers,
+            thread_name_prefix=thread_name_prefix,
+            initializer=initializer,
+            initargs=initargs,
         )
-        super().__init__(max_workers=max_workers, thread_name_prefix=thread_name_prefix)
-
-        def _watchdog() -> None:
-            """Watchdog will shutdown the threadpool on certain conditions being met."""
-            while not self._shutdown and not concurrent_fut_thread._shutdown:
-                if self.max_total_retry and self._retry_count > self.max_total_retry:
-                    logger.warning(f"exceed {self.max_total_retry=}, abort")
-                    return self.shutdown(wait=True)
-                if callable(watchdog_func):
-                    try:
-                        watchdog_func()
-                    except Exception as e:
-                        logger.warning(f"custom watchdog func failed: {e!r}, abort")
-                        return self.shutdown(wait=True)
-                time.sleep(watchdog_check_interval)
+        if max_total_retry or callable(watchdog_func):
+            threading.Thread(
+                target=self._watchdog,
+                args=(max_total_retry, watchdog_func, watchdog_check_interval),
+                daemon=True,
+            ).start()
 
-        self.submit(_watchdog)
+    def _watchdog(
+        self,
+        max_retry: int | None,
+        watchdog_func: Callable[..., Any] | None,
+        interval: int,
+    ) -> None:
+        """Watchdog will shutdown the threadpool on certain conditions being met."""
+        while not self._shutdown and not concurrent_fut_thread._shutdown:
+            if max_retry and self._retry_count > max_retry:
+                logger.warning(f"exceed {max_retry=}, abort")
+                return self.shutdown(wait=True)
+
+            if callable(watchdog_func):
+                try:
+                    watchdog_func()
+                except Exception as e:
+                    logger.warning(f"custom watchdog func failed: {e!r}, abort")
+                    return self.shutdown(wait=True)
+            time.sleep(interval)
 
     def _task_done_cb(
         self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any]
     ) -> None:
+        self._concurrent_semaphore.release()  # always release the semaphore first
         self._fut_queue.put_nowait(fut)
 
-        # ------ on task succeeded ------ #
-        if not fut.exception():
-            self._concurrent_semaphore.release()
-            self._finished_task = next(self._finished_task_counter)
-            return
+        # ------ on task failed ------ #
+        if fut.exception():
+            self._retry_count = next(self._retry_counter)
+            with contextlib.suppress(Exception):  # on threadpool shutdown
+                self.submit(func, item).add_done_callback(
+                    partial(self._task_done_cb, item=item, func=func)
+                )
 
-        # ------ on threadpool shutdown(by watchdog) ------ #
-        if self._shutdown or self._broken:
-            # wakeup dispatcher
-            self._concurrent_semaphore.release()
-            return
+    def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]:
+        finished_tasks = 0
+        while True:
+            if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
+                logger.warning(
+                    f"failed to ensure all tasks, {finished_tasks=}, {self._total_task_num=}"
+                )
+                raise TasksEnsureFailed
 
-        # ------ on task failed ------ #
-        self._retry_count = next(self._retry_counter)
-        with contextlib.suppress(Exception):  # on threadpool shutdown
-            self.submit(func, item).add_done_callback(
-                partial(self._task_done_cb, item=item, func=func)
-            )
+            try:
+                done_fut = self._fut_queue.get_nowait()
+                if not done_fut.exception():
+                    finished_tasks += 1
+                yield done_fut
+            except Empty:
+                if self._total_task_num == 0 or finished_tasks != self._total_task_num:
+                    time.sleep(interval)
+                    continue
+                return
 
     def ensure_tasks(
-        self, func: Callable[[T], RT], iterable: Iterable[T]
+        self,
+        func: Callable[[T], RT],
+        iterable: Iterable[T],
+        *,
+        ensure_tasks_pull_interval: int = 1,
     ) -> Generator[Future[RT], None, None]:
         """Ensure all the items in <iterable> are processed by <func> in the pool.
 
@@ -138,9 +161,10 @@ def ensure_tasks(
         """
         with self._start_lock:
             if self._started:
-                raise ValueError("ensure_tasks cannot be started more than once")
-            if self._shutdown or self._broken:
-                raise ValueError("threadpool is shutdown or broken, abort")
+                try:
+                    raise ValueError("ensure_tasks cannot be started more than once")
+                finally:  # do not hold refs to input params
+                    del self, func, iterable
             self._started = True
 
         # ------ dispatch tasks from iterable ------ #
@@ -160,23 +184,10 @@ def _dispatcher() -> None:
             self._total_task_num = _tasks_count
             logger.info(f"finish dispatch {_tasks_count} tasks")
 
-        self.submit(_dispatcher)
+        threading.Thread(target=_dispatcher, daemon=True).start()
 
         # ------ ensure all tasks are finished ------ #
-        while True:
-            if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
-                logger.warning(
-                    f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}"
-                )
-                raise TasksEnsureFailed
-
-            try:
-                yield self._fut_queue.get_nowait()
-            except Empty:
-                if (
-                    self._total_task_num == 0
-                    or self._finished_task != self._total_task_num
-                ):
-                    time.sleep(self.ensure_tasks_pull_interval)
-                    continue
-                return
+        # NOTE: as with base.Executor.map, the yield is hidden inside a
+        #   helper generator, so the tasks are already dispatched before
+        #   the caller starts to pull from fut_queue.
+        return self._fut_gen(ensure_tasks_pull_interval)
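
Reviewer note, not part of the patch: after this refactor the watchdog runs on a dedicated daemon thread instead of occupying a pool worker, and any exception raised by the custom watchdog_func shuts the pool down, which makes ensure_tasks raise TasksEnsureFailed in the caller. A minimal sketch of wiring such a func, assuming the `from otaclient_common import retry_task_map` import used by the tests; the deadline logic below is made up for illustration:

import time

from otaclient_common import retry_task_map  # assumed import path


def make_deadline_watchdog(timeout: float):
    """Return a watchdog_func that raises once <timeout> seconds have passed."""
    deadline = time.time() + timeout

    def _watchdog() -> None:
        # called every watchdog_check_interval seconds on the watchdog thread;
        # raising here aborts all in-flight tasks via pool shutdown
        if time.time() > deadline:
            raise TimeoutError("tasks did not finish before the deadline")

    return _watchdog


executor = retry_task_map.ThreadPoolExecutorWithRetry(
    max_concurrent=8, watchdog_func=make_deadline_watchdog(300)
)
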
diff --git a/tests/test_otaclient_common/test_retry_task_map.py b/tests/test_otaclient_common/test_retry_task_map.py
index dd4d436e2..bd5a8d4bb 100644
--- a/tests/test_otaclient_common/test_retry_task_map.py
+++ b/tests/test_otaclient_common/test_retry_task_map.py
@@ -17,6 +17,7 @@
 
 import logging
 import random
+import threading
 import time
 
 import pytest
@@ -25,39 +26,48 @@
 logger = logging.getLogger(__name__)
 
+# ------ test setup ------ #
+WAIT_CONST = 100_000_000
+TASKS_COUNT = 2000
+MAX_CONCURRENT = 128
+MAX_WAIT_BEFORE_SUCCESS = 10
+THREAD_INIT_MSG = "thread init message"
+
 
 class _RetryTaskMapTestErr(Exception):
     """"""
 
 
+def _thread_initializer(msg: str) -> None:
+    """For testing thread worker initializer."""
+    thread_native_id = threading.get_native_id()
+    logger.info(f"thread worker #{thread_native_id} initialized: {msg}")
+
+
 class TestRetryTaskMap:
-    WAIT_CONST = 100_000_000
-    TASKS_COUNT = 2000
-    MAX_CONCURRENT = 128
-    MAX_WAIT_BEFORE_SUCCESS = 10
 
     @pytest.fixture(autouse=True)
     def setup(self):
         self._start_time = time.time()
         self._success_wait_dict = {
-            idx: random.randint(0, self.MAX_WAIT_BEFORE_SUCCESS)
-            for idx in range(self.TASKS_COUNT)
+            idx: random.randint(0, MAX_WAIT_BEFORE_SUCCESS)
+            for idx in range(TASKS_COUNT)
         }
-        self._succeeded_tasks = [False for _ in range(self.TASKS_COUNT)]
+        self._succeeded_tasks = [False for _ in range(TASKS_COUNT)]
 
     def workload_aways_failed(self, idx: int) -> int:
-        time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
+        time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
         raise _RetryTaskMapTestErr
 
     def workload_failed_and_then_succeed(self, idx: int) -> int:
-        time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
+        time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
         if time.time() > self._start_time + self._success_wait_dict[idx]:
             self._succeeded_tasks[idx] = True
             return idx
         raise _RetryTaskMapTestErr
 
     def workload_succeed(self, idx: int) -> int:
-        time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
+        time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
         self._succeeded_tasks[idx] = True
         return idx
 
@@ -69,48 +79,62 @@ def _exit_on_exceed_max_count():
                 raise ValueError(f"{failure_count=} > {MAX_RETRY=}")
 
         with retry_task_map.ThreadPoolExecutorWithRetry(
-            max_concurrent=self.MAX_CONCURRENT,
+            max_concurrent=MAX_CONCURRENT,
             watchdog_func=_exit_on_exceed_max_count,
+            initializer=_thread_initializer,
+            initargs=(THREAD_INIT_MSG,),
         ) as executor:
             with pytest.raises(retry_task_map.TasksEnsureFailed):
                 for _fut in executor.ensure_tasks(
-                    self.workload_aways_failed, range(self.TASKS_COUNT)
+                    self.workload_aways_failed, range(TASKS_COUNT)
                 ):
                     if _fut.exception():
                         failure_count += 1
 
     def test_retry_exceed_retry_limit(self):
+        MAX_TOTAL_RETRY = 200
+        failure_count = 0
         with retry_task_map.ThreadPoolExecutorWithRetry(
-            max_concurrent=self.MAX_CONCURRENT, max_total_retry=200
+            max_concurrent=MAX_CONCURRENT,
+            max_total_retry=MAX_TOTAL_RETRY,
+            initializer=_thread_initializer,
+            initargs=(THREAD_INIT_MSG,),
        ) as executor:
             with pytest.raises(retry_task_map.TasksEnsureFailed):
-                for _ in executor.ensure_tasks(
-                    self.workload_aways_failed, range(self.TASKS_COUNT)
+                for _fut in executor.ensure_tasks(
+                    self.workload_aways_failed, range(TASKS_COUNT)
                 ):
-                    pass
+                    if _fut.exception():
+                        failure_count += 1
+
+        assert failure_count >= MAX_TOTAL_RETRY
 
     def test_retry_finally_succeeded(self):
         count = 0
         with retry_task_map.ThreadPoolExecutorWithRetry(
-            max_concurrent=self.MAX_CONCURRENT
+            max_concurrent=MAX_CONCURRENT,
+            initializer=_thread_initializer,
+            initargs=(THREAD_INIT_MSG,),
         ) as executor:
             for _fut in executor.ensure_tasks(
-                self.workload_failed_and_then_succeed, range(self.TASKS_COUNT)
+                self.workload_failed_and_then_succeed, range(TASKS_COUNT)
             ):
                 if not _fut.exception():
                     count += 1
         assert all(self._succeeded_tasks)
-        assert self.TASKS_COUNT == count
+        assert TASKS_COUNT == count
 
     def test_succeeded_in_one_try(self):
         count = 0
         with retry_task_map.ThreadPoolExecutorWithRetry(
-            max_concurrent=self.MAX_CONCURRENT
+            max_concurrent=MAX_CONCURRENT,
+            initializer=_thread_initializer,
+            initargs=(THREAD_INIT_MSG,),
         ) as executor:
             for _fut in executor.ensure_tasks(
-                self.workload_succeed, range(self.TASKS_COUNT)
+                self.workload_succeed, range(TASKS_COUNT)
             ):
                 if not _fut.exception():
                     count += 1
         assert all(self._succeeded_tasks)
-        assert self.TASKS_COUNT == count
+        assert TASKS_COUNT == count
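
Reviewer note, not part of the patch: a minimal end-to-end sketch of the updated API, exercising the new initializer/initargs parameters and the ensure_tasks_pull_interval keyword that moved from the constructor onto ensure_tasks. The import path follows the tests; the flaky workload is hypothetical:

import logging
import random

from otaclient_common import retry_task_map  # assumed import path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _init_worker(tag: str) -> None:
    # runs once in every worker thread before it picks up tasks,
    # exactly as ThreadPoolExecutor's initializer does
    logger.info(f"worker ready: {tag}")


def _flaky_double(x: int) -> int:
    # hypothetical workload: fails ~20% of the time to exercise retrying
    if random.random() < 0.2:
        raise ValueError("transient failure")
    return x * 2


succeeded = 0
with retry_task_map.ThreadPoolExecutorWithRetry(
    max_concurrent=32,
    max_total_retry=200,  # watchdog aborts the pool past this many retries
    initializer=_init_worker,
    initargs=("demo",),
) as executor:
    # ensure_tasks yields one future per finished attempt (failed attempts
    # included); failed items are resubmitted internally, and the generator
    # raises TasksEnsureFailed if the pool is shut down before all succeed.
    for fut in executor.ensure_tasks(
        _flaky_double, range(100), ensure_tasks_pull_interval=1
    ):
        if fut.exception() is None:
            succeeded += 1
print(f"{succeeded=}")

Note that ensure_tasks can only be started once per executor instance; a second call raises ValueError.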