From fe82b0a14ca10592596d4220e0516c2ace177f54 Mon Sep 17 00:00:00 2001 From: "bodong.yang" Date: Sun, 23 Jun 2024 04:00:07 +0000 Subject: [PATCH 1/5] Squashed commit of the following: commit 35a6110460a70002af276ef589ef0abc40a19e3b Author: bodong.yang Date: Sat Jun 22 20:47:36 2024 +0000 minor update commit 7f93d3898222d888465707c8e3563ff66cc0ef22 Author: bodong.yang Date: Sat Jun 22 19:23:03 2024 +0000 try thread.Thread commit 1714c3f3937225067faeb2b71f2f0bb7a1a24aeb Author: bodong.yang Date: Sat Jun 22 19:21:44 2024 +0000 try tdcb add release se commit f083830c3634fe34dfa438698ba5949d53109042 Author: bodong.yang Date: Sat Jun 22 19:19:12 2024 +0000 add initfunc --- src/otaclient_common/retry_task_map.py | 33 +++++++++++++++----------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py index 5e0e964de..3d413e6ec 100644 --- a/src/otaclient_common/retry_task_map.py +++ b/src/otaclient_common/retry_task_map.py @@ -19,7 +19,6 @@ import contextlib import itertools import logging -import os import threading import time from concurrent.futures import Future, ThreadPoolExecutor @@ -47,6 +46,8 @@ def __init__( watchdog_func: Optional[Callable] = None, watchdog_check_interval: int = 3, # seconds ensure_tasks_pull_interval: int = 1, # second + initializer: Callable[..., Any] | None = None, + initargs: tuple = (), ) -> None: """Initialize a ThreadPoolExecutorWithRetry instance. @@ -72,11 +73,12 @@ def __init__( self._concurrent_semaphore = threading.Semaphore(max_concurrent) self._fut_queue: SimpleQueue[Future[Any]] = SimpleQueue() - # NOTE: leave two threads each for watchdog and dispatcher - max_workers = ( - max_workers + 2 if max_workers else min(32, (os.cpu_count() or 1) + 4) + super().__init__( + max_workers=max_workers, + thread_name_prefix=thread_name_prefix, + initializer=initializer, + initargs=initargs, ) - super().__init__(max_workers=max_workers, thread_name_prefix=thread_name_prefix) def _watchdog() -> None: """Watchdog will shutdown the threadpool on certain conditions being met.""" @@ -93,23 +95,22 @@ def _watchdog() -> None: return self.shutdown(wait=True) time.sleep(watchdog_check_interval) - self.submit(_watchdog) + if self.max_total_retry or callable(watchdog_func): + threading.Thread(target=_watchdog, daemon=True).start() def _task_done_cb( self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any] ) -> None: self._fut_queue.put_nowait(fut) + self._concurrent_semaphore.release() # ------ on task succeeded ------ # if not fut.exception(): - self._concurrent_semaphore.release() self._finished_task = next(self._finished_task_counter) return # ------ on threadpool shutdown(by watchdog) ------ # if self._shutdown or self._broken: - # wakeup dispatcher - self._concurrent_semaphore.release() return # ------ on task failed ------ # @@ -138,9 +139,10 @@ def ensure_tasks( """ with self._start_lock: if self._started: - raise ValueError("ensure_tasks cannot be started more than once") - if self._shutdown or self._broken: - raise ValueError("threadpool is shutdown or broken, abort") + try: + raise ValueError("ensure_tasks cannot be started more than once") + finally: # do not hold refs to input params + del self, func, iterable self._started = True # ------ dispatch tasks from iterable ------ # @@ -160,7 +162,7 @@ def _dispatcher() -> None: self._total_task_num = _tasks_count logger.info(f"finish dispatch {_tasks_count} tasks") - self.submit(_dispatcher) + threading.Thread(target=_dispatcher, daemon=True).start() # ------ ensure all tasks are finished ------ # while True: @@ -168,7 +170,10 @@ def _dispatcher() -> None: logger.warning( f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}" ) - raise TasksEnsureFailed + try: + raise TasksEnsureFailed + finally: + del self, func, iterable try: yield self._fut_queue.get_nowait() From 32bae00cd20fff420fddbf91db75f8f851c790d0 Mon Sep 17 00:00:00 2001 From: "bodong.yang" Date: Sun, 23 Jun 2024 04:37:32 +0000 Subject: [PATCH 2/5] rtm: watchdog becomes a method, fut gen becomes a method --- src/otaclient_common/retry_task_map.py | 97 +++++++++++++++----------- 1 file changed, 56 insertions(+), 41 deletions(-) diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py index 3d413e6ec..39f546f4d 100644 --- a/src/otaclient_common/retry_task_map.py +++ b/src/otaclient_common/retry_task_map.py @@ -45,7 +45,6 @@ def __init__( thread_name_prefix: str = "", watchdog_func: Optional[Callable] = None, watchdog_check_interval: int = 3, # seconds - ensure_tasks_pull_interval: int = 1, # second initializer: Callable[..., Any] | None = None, initargs: tuple = (), ) -> None: @@ -59,11 +58,11 @@ def __init__( watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None. watchdog_check_interval (int): Defaults to 3(seconds). - ensure_tasks_pull_interval (int): Defaults to 1(second). + initializer (Callable[..., Any] | None): The same param passed through to ThreadPoolExecutor. + Defaults to None. + initargs (tuple): The same param passed through to ThreadPoolExecutor. + Defaults to (). """ - self.max_total_retry = max_total_retry - self.ensure_tasks_pull_interval = ensure_tasks_pull_interval - self._start_lock, self._started = threading.Lock(), False self._total_task_num = 0 self._finished_task_counter = itertools.count(start=1) @@ -80,23 +79,32 @@ def __init__( initargs=initargs, ) - def _watchdog() -> None: - """Watchdog will shutdown the threadpool on certain conditions being met.""" - while not self._shutdown and not concurrent_fut_thread._shutdown: - if self.max_total_retry and self._retry_count > self.max_total_retry: - logger.warning(f"exceed {self.max_total_retry=}, abort") - return self.shutdown(wait=True) + if max_total_retry or callable(watchdog_func): + threading.Thread( + target=self._watchdog, + args=(max_total_retry, watchdog_func, watchdog_check_interval), + daemon=True, + ).start() - if callable(watchdog_func): - try: - watchdog_func() - except Exception as e: - logger.warning(f"custom watchdog func failed: {e!r}, abort") - return self.shutdown(wait=True) - time.sleep(watchdog_check_interval) + def _watchdog( + self, + max_retry: int | None, + watchdog_func: Callable[..., Any] | None, + interval: int, + ) -> None: + """Watchdog will shutdown the threadpool on certain conditions being met.""" + while not self._shutdown and not concurrent_fut_thread._shutdown: + if max_retry and self._retry_count > max_retry: + logger.warning(f"exceed {max_retry=}, abort") + return self.shutdown(wait=True) - if self.max_total_retry or callable(watchdog_func): - threading.Thread(target=_watchdog, daemon=True).start() + if callable(watchdog_func): + try: + watchdog_func() + except Exception as e: + logger.warning(f"custom watchdog func failed: {e!r}, abort") + return self.shutdown(wait=True) + time.sleep(interval) def _task_done_cb( self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any] @@ -120,8 +128,31 @@ def _task_done_cb( partial(self._task_done_cb, item=item, func=func) ) + def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]: + while True: + if self._shutdown or self._broken or concurrent_fut_thread._shutdown: + logger.warning( + f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}" + ) + raise TasksEnsureFailed + + try: + yield self._fut_queue.get_nowait() + except Empty: + if ( + self._total_task_num == 0 + or self._finished_task != self._total_task_num + ): + time.sleep(interval) + continue + return + def ensure_tasks( - self, func: Callable[[T], RT], iterable: Iterable[T] + self, + func: Callable[[T], RT], + iterable: Iterable[T], + *, + ensure_tasks_pull_interval: int = 1, ) -> Generator[Future[RT], None, None]: """Ensure all the items in are processed by in the pool. @@ -165,23 +196,7 @@ def _dispatcher() -> None: threading.Thread(target=_dispatcher, daemon=True).start() # ------ ensure all tasks are finished ------ # - while True: - if self._shutdown or self._broken or concurrent_fut_thread._shutdown: - logger.warning( - f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}" - ) - try: - raise TasksEnsureFailed - finally: - del self, func, iterable - - try: - yield self._fut_queue.get_nowait() - except Empty: - if ( - self._total_task_num == 0 - or self._finished_task != self._total_task_num - ): - time.sleep(self.ensure_tasks_pull_interval) - continue - return + # NOTE: also see base.Executor.map method, let the yield hidden in + # a generator so that the first fut will be dispatched before + # we start to get from fut_queue. + return self._fut_gen(ensure_tasks_pull_interval) From fd4f1af36b3b27597a7aef953dc5da3ecfd5df2b Mon Sep 17 00:00:00 2001 From: "bodong.yang" Date: Mon, 24 Jun 2024 02:08:23 +0000 Subject: [PATCH 3/5] always release se first --- src/otaclient_common/retry_task_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py index 39f546f4d..221fc9521 100644 --- a/src/otaclient_common/retry_task_map.py +++ b/src/otaclient_common/retry_task_map.py @@ -109,8 +109,8 @@ def _watchdog( def _task_done_cb( self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any] ) -> None: + self._concurrent_semaphore.release() # always release se first self._fut_queue.put_nowait(fut) - self._concurrent_semaphore.release() # ------ on task succeeded ------ # if not fut.exception(): From 4da2b8e49483c6daf3185494e7c46f2ad4aad98a Mon Sep 17 00:00:00 2001 From: "bodong.yang" Date: Mon, 24 Jun 2024 02:15:06 +0000 Subject: [PATCH 4/5] test_retry_task_map: tests now cover thread worker initializer --- .../test_retry_task_map.py | 68 +++++++++++++------ 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/tests/test_otaclient_common/test_retry_task_map.py b/tests/test_otaclient_common/test_retry_task_map.py index dd4d436e2..bd5a8d4bb 100644 --- a/tests/test_otaclient_common/test_retry_task_map.py +++ b/tests/test_otaclient_common/test_retry_task_map.py @@ -17,6 +17,7 @@ import logging import random +import threading import time import pytest @@ -25,39 +26,48 @@ logger = logging.getLogger(__name__) +# ------ test setup ------ # +WAIT_CONST = 100_000_000 +TASKS_COUNT = 2000 +MAX_CONCURRENT = 128 +MAX_WAIT_BEFORE_SUCCESS = 10 +THREAD_INIT_MSG = "thread init message" + class _RetryTaskMapTestErr(Exception): """""" +def _thread_initializer(msg: str) -> None: + """For testing thread worker initializer.""" + thread_native_id = threading.get_native_id() + logger.info(f"thread worker #{thread_native_id} initialized: {msg}") + + class TestRetryTaskMap: - WAIT_CONST = 100_000_000 - TASKS_COUNT = 2000 - MAX_CONCURRENT = 128 - MAX_WAIT_BEFORE_SUCCESS = 10 @pytest.fixture(autouse=True) def setup(self): self._start_time = time.time() self._success_wait_dict = { - idx: random.randint(0, self.MAX_WAIT_BEFORE_SUCCESS) - for idx in range(self.TASKS_COUNT) + idx: random.randint(0, MAX_WAIT_BEFORE_SUCCESS) + for idx in range(TASKS_COUNT) } - self._succeeded_tasks = [False for _ in range(self.TASKS_COUNT)] + self._succeeded_tasks = [False for _ in range(TASKS_COUNT)] def workload_aways_failed(self, idx: int) -> int: - time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST) + time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST) raise _RetryTaskMapTestErr def workload_failed_and_then_succeed(self, idx: int) -> int: - time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST) + time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST) if time.time() > self._start_time + self._success_wait_dict[idx]: self._succeeded_tasks[idx] = True return idx raise _RetryTaskMapTestErr def workload_succeed(self, idx: int) -> int: - time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST) + time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST) self._succeeded_tasks[idx] = True return idx @@ -69,48 +79,62 @@ def _exit_on_exceed_max_count(): raise ValueError(f"{failure_count=} > {MAX_RETRY=}") with retry_task_map.ThreadPoolExecutorWithRetry( - max_concurrent=self.MAX_CONCURRENT, + max_concurrent=MAX_CONCURRENT, watchdog_func=_exit_on_exceed_max_count, + initializer=_thread_initializer, + initargs=(THREAD_INIT_MSG,), ) as executor: with pytest.raises(retry_task_map.TasksEnsureFailed): for _fut in executor.ensure_tasks( - self.workload_aways_failed, range(self.TASKS_COUNT) + self.workload_aways_failed, range(TASKS_COUNT) ): if _fut.exception(): failure_count += 1 def test_retry_exceed_retry_limit(self): + MAX_TOTAL_RETRY = 200 + failure_count = 0 with retry_task_map.ThreadPoolExecutorWithRetry( - max_concurrent=self.MAX_CONCURRENT, max_total_retry=200 + max_concurrent=MAX_CONCURRENT, + max_total_retry=MAX_TOTAL_RETRY, + initializer=_thread_initializer, + initargs=(THREAD_INIT_MSG,), ) as executor: with pytest.raises(retry_task_map.TasksEnsureFailed): - for _ in executor.ensure_tasks( - self.workload_aways_failed, range(self.TASKS_COUNT) + for _fut in executor.ensure_tasks( + self.workload_aways_failed, range(TASKS_COUNT) ): - pass + if _fut.exception(): + failure_count += 1 + + assert failure_count >= MAX_TOTAL_RETRY def test_retry_finally_succeeded(self): count = 0 with retry_task_map.ThreadPoolExecutorWithRetry( - max_concurrent=self.MAX_CONCURRENT + max_concurrent=MAX_CONCURRENT, + initializer=_thread_initializer, + initargs=(THREAD_INIT_MSG,), ) as executor: for _fut in executor.ensure_tasks( - self.workload_failed_and_then_succeed, range(self.TASKS_COUNT) + self.workload_failed_and_then_succeed, range(TASKS_COUNT) ): if not _fut.exception(): count += 1 assert all(self._succeeded_tasks) - assert self.TASKS_COUNT == count + assert TASKS_COUNT == count def test_succeeded_in_one_try(self): count = 0 with retry_task_map.ThreadPoolExecutorWithRetry( - max_concurrent=self.MAX_CONCURRENT + max_concurrent=MAX_CONCURRENT, + initializer=_thread_initializer, + initargs=(THREAD_INIT_MSG,), ) as executor: for _fut in executor.ensure_tasks( - self.workload_succeed, range(self.TASKS_COUNT) + self.workload_succeed, range(TASKS_COUNT) ): if not _fut.exception(): count += 1 assert all(self._succeeded_tasks) - assert self.TASKS_COUNT == count + assert TASKS_COUNT == count From 8572c09f87c3728507959437206716c6d30d4085 Mon Sep 17 00:00:00 2001 From: "bodong.yang" Date: Mon, 24 Jun 2024 04:43:33 +0000 Subject: [PATCH 5/5] finished_tasks become local var for fut_gen method --- src/otaclient_common/retry_task_map.py | 35 ++++++++++---------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/src/otaclient_common/retry_task_map.py b/src/otaclient_common/retry_task_map.py index 221fc9521..055cb4bbf 100644 --- a/src/otaclient_common/retry_task_map.py +++ b/src/otaclient_common/retry_task_map.py @@ -65,8 +65,6 @@ def __init__( """ self._start_lock, self._started = threading.Lock(), False self._total_task_num = 0 - self._finished_task_counter = itertools.count(start=1) - self._finished_task = 0 self._retry_counter = itertools.count(start=1) self._retry_count = 0 self._concurrent_semaphore = threading.Semaphore(max_concurrent) @@ -112,37 +110,30 @@ def _task_done_cb( self._concurrent_semaphore.release() # always release se first self._fut_queue.put_nowait(fut) - # ------ on task succeeded ------ # - if not fut.exception(): - self._finished_task = next(self._finished_task_counter) - return - - # ------ on threadpool shutdown(by watchdog) ------ # - if self._shutdown or self._broken: - return - # ------ on task failed ------ # - self._retry_count = next(self._retry_counter) - with contextlib.suppress(Exception): # on threadpool shutdown - self.submit(func, item).add_done_callback( - partial(self._task_done_cb, item=item, func=func) - ) + if fut.exception(): + self._retry_count = next(self._retry_counter) + with contextlib.suppress(Exception): # on threadpool shutdown + self.submit(func, item).add_done_callback( + partial(self._task_done_cb, item=item, func=func) + ) def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]: + finished_tasks = 0 while True: if self._shutdown or self._broken or concurrent_fut_thread._shutdown: logger.warning( - f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}" + f"failed to ensure all tasks, {finished_tasks=}, {self._total_task_num=}" ) raise TasksEnsureFailed try: - yield self._fut_queue.get_nowait() + done_fut = self._fut_queue.get_nowait() + if not done_fut.exception(): + finished_tasks += 1 + yield done_fut except Empty: - if ( - self._total_task_num == 0 - or self._finished_task != self._total_task_num - ): + if self._total_task_num == 0 or finished_tasks != self._total_task_num: time.sleep(interval) continue return