refactor: retry_task_map now takes initializer and initargs params (#324)

This PR allows retry_task_map.ThreadPoolExecutorWithRetry to take initializer and initargs params and pass them through to ThreadPoolExecutor.__init__ for thread worker initialization. ThreadPoolExecutorWithRetry can now take all the params that ThreadPoolExecutor takes.
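For example, a per-worker initializer can now be passed in with the same semantics as the stdlib ThreadPoolExecutor. A minimal sketch (the import path follows this repo's src layout; the thread-local setup is illustrative, not part of this PR):

import threading

from otaclient_common.retry_task_map import ThreadPoolExecutorWithRetry

_thread_local = threading.local()

def _init_worker(tag: str) -> None:
    # runs once in every worker thread, exactly like the
    # <initializer>/<initargs> params of ThreadPoolExecutor
    _thread_local.tag = tag

executor = ThreadPoolExecutorWithRetry(
    max_concurrent=128,
    initializer=_init_worker,
    initargs=("my-worker",),
)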

Other changes:
* internal refinement of retry_task_map: the watchdog and dispatcher no longer occupy thread pool workers; they now run in dedicated threads.
* internal refinement: fut_gen (executed by the main thread) now counts the finished tasks.
* do not launch the watchdog thread if no checks will be performed (see the sketch below).
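A sketch of the watchdog behavior described above (the disk-space check is a made-up example of a custom watchdog_func, not part of this PR):

import shutil

from otaclient_common.retry_task_map import ThreadPoolExecutorWithRetry

def _abort_when_disk_full() -> None:
    # an exception raised here makes the watchdog shut the pool
    # down and interrupt all running tasks
    if shutil.disk_usage("/").free < 64 * 1024**2:
        raise RuntimeError("free disk space too low, abort")

executor = ThreadPoolExecutorWithRetry(
    max_concurrent=128,
    max_total_retry=200,
    watchdog_func=_abort_when_disk_full,
    watchdog_check_interval=3,  # seconds
)
# with neither max_total_retry nor watchdog_func set,
# no watchdog thread is launched at all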
Bodong-Yang authored Jun 24, 2024
1 parent 77bde15 commit d156096
Showing 2 changed files with 122 additions and 87 deletions.
141 changes: 76 additions & 65 deletions src/otaclient_common/retry_task_map.py
@@ -19,7 +19,6 @@
import contextlib
import itertools
import logging
import os
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor
@@ -46,7 +45,8 @@ def __init__(
thread_name_prefix: str = "",
watchdog_func: Optional[Callable] = None,
watchdog_check_interval: int = 3, # seconds
ensure_tasks_pull_interval: int = 1, # second
initializer: Callable[..., Any] | None = None,
initargs: tuple = (),
) -> None:
"""Initialize a ThreadPoolExecutorWithRetry instance.
@@ -58,69 +58,92 @@ def __init__(
watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when
this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None.
watchdog_check_interval (int): Defaults to 3(seconds).
ensure_tasks_pull_interval (int): Defaults to 1(second).
initializer (Callable[..., Any] | None): The same <initializer> param passed through to ThreadPoolExecutor.
Defaults to None.
initargs (tuple): The same <initargs> param passed through to ThreadPoolExecutor.
Defaults to ().
"""
self.max_total_retry = max_total_retry
self.ensure_tasks_pull_interval = ensure_tasks_pull_interval

self._start_lock, self._started = threading.Lock(), False
self._total_task_num = 0
self._finished_task_counter = itertools.count(start=1)
self._finished_task = 0
self._retry_counter = itertools.count(start=1)
self._retry_count = 0
self._concurrent_semaphore = threading.Semaphore(max_concurrent)
self._fut_queue: SimpleQueue[Future[Any]] = SimpleQueue()

# NOTE: leave two threads each for watchdog and dispatcher
max_workers = (
max_workers + 2 if max_workers else min(32, (os.cpu_count() or 1) + 4)
super().__init__(
max_workers=max_workers,
thread_name_prefix=thread_name_prefix,
initializer=initializer,
initargs=initargs,
)
super().__init__(max_workers=max_workers, thread_name_prefix=thread_name_prefix)

def _watchdog() -> None:
"""Watchdog will shutdown the threadpool on certain conditions being met."""
while not self._shutdown and not concurrent_fut_thread._shutdown:
if self.max_total_retry and self._retry_count > self.max_total_retry:
logger.warning(f"exceed {self.max_total_retry=}, abort")
return self.shutdown(wait=True)

if callable(watchdog_func):
try:
watchdog_func()
except Exception as e:
logger.warning(f"custom watchdog func failed: {e!r}, abort")
return self.shutdown(wait=True)
time.sleep(watchdog_check_interval)
if max_total_retry or callable(watchdog_func):
threading.Thread(
target=self._watchdog,
args=(max_total_retry, watchdog_func, watchdog_check_interval),
daemon=True,
).start()

self.submit(_watchdog)
def _watchdog(
self,
max_retry: int | None,
watchdog_func: Callable[..., Any] | None,
interval: int,
) -> None:
"""Watchdog will shutdown the threadpool on certain conditions being met."""
while not self._shutdown and not concurrent_fut_thread._shutdown:
if max_retry and self._retry_count > max_retry:
logger.warning(f"exceed {max_retry=}, abort")
return self.shutdown(wait=True)

if callable(watchdog_func):
try:
watchdog_func()
except Exception as e:
logger.warning(f"custom watchdog func failed: {e!r}, abort")
return self.shutdown(wait=True)
time.sleep(interval)

def _task_done_cb(
self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any]
) -> None:
self._concurrent_semaphore.release()  # always release semaphore first
self._fut_queue.put_nowait(fut)

# ------ on task succeeded ------ #
if not fut.exception():
self._concurrent_semaphore.release()
self._finished_task = next(self._finished_task_counter)
return
# ------ on task failed ------ #
if fut.exception():
self._retry_count = next(self._retry_counter)
with contextlib.suppress(Exception): # on threadpool shutdown
self.submit(func, item).add_done_callback(
partial(self._task_done_cb, item=item, func=func)
)

# ------ on threadpool shutdown(by watchdog) ------ #
if self._shutdown or self._broken:
# wakeup dispatcher
self._concurrent_semaphore.release()
return
def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]:
finished_tasks = 0
while True:
if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
logger.warning(
f"failed to ensure all tasks, {finished_tasks=}, {self._total_task_num=}"
)
raise TasksEnsureFailed

# ------ on task failed ------ #
self._retry_count = next(self._retry_counter)
with contextlib.suppress(Exception): # on threadpool shutdown
self.submit(func, item).add_done_callback(
partial(self._task_done_cb, item=item, func=func)
)
try:
done_fut = self._fut_queue.get_nowait()
if not done_fut.exception():
finished_tasks += 1
yield done_fut
except Empty:
if self._total_task_num == 0 or finished_tasks != self._total_task_num:
time.sleep(interval)
continue
return

def ensure_tasks(
self, func: Callable[[T], RT], iterable: Iterable[T]
self,
func: Callable[[T], RT],
iterable: Iterable[T],
*,
ensure_tasks_pull_interval: int = 1,
) -> Generator[Future[RT], None, None]:
"""Ensure all the items in <iterable> are processed by <func> in the pool.
@@ -138,9 +161,10 @@ def ensure_tasks(
"""
with self._start_lock:
if self._started:
raise ValueError("ensure_tasks cannot be started more than once")
if self._shutdown or self._broken:
raise ValueError("threadpool is shutdown or broken, abort")
try:
raise ValueError("ensure_tasks cannot be started more than once")
finally: # do not hold refs to input params
del self, func, iterable
self._started = True

# ------ dispatch tasks from iterable ------ #
@@ -160,23 +184,10 @@ def _dispatcher() -> None:
self._total_task_num = _tasks_count
logger.info(f"finish dispatch {_tasks_count} tasks")

self.submit(_dispatcher)
threading.Thread(target=_dispatcher, daemon=True).start()

# ------ ensure all tasks are finished ------ #
while True:
if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
logger.warning(
f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}"
)
raise TasksEnsureFailed

try:
yield self._fut_queue.get_nowait()
except Empty:
if (
self._total_task_num == 0
or self._finished_task != self._total_task_num
):
time.sleep(self.ensure_tasks_pull_interval)
continue
return
# NOTE: also see base.Executor.map method, let the yield hidden in
# a generator so that the first fut will be dispatched before
# we start to get from fut_queue.
return self._fut_gen(ensure_tasks_pull_interval)
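As the diff above shows, the pull interval moved from the constructor (ensure_tasks_pull_interval used to be an __init__ param) to a keyword-only param of ensure_tasks. A minimal usage sketch (process_one is a hypothetical workload; the import path follows this repo's src layout):

from otaclient_common.retry_task_map import ThreadPoolExecutorWithRetry

def process_one(idx: int) -> int:  # hypothetical workload
    return idx * 2

with ThreadPoolExecutorWithRetry(max_concurrent=128) as executor:
    for fut in executor.ensure_tasks(
        process_one, range(2000), ensure_tasks_pull_interval=1
    ):
        # failed futs are also yielded; their tasks are retried internally
        if not fut.exception():
            result = fut.result()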
68 changes: 46 additions & 22 deletions tests/test_otaclient_common/test_retry_task_map.py
@@ -17,6 +17,7 @@

import logging
import random
import threading
import time

import pytest
@@ -25,39 +26,48 @@

logger = logging.getLogger(__name__)

# ------ test setup ------ #
WAIT_CONST = 100_000_000
TASKS_COUNT = 2000
MAX_CONCURRENT = 128
MAX_WAIT_BEFORE_SUCCESS = 10
THREAD_INIT_MSG = "thread init message"


class _RetryTaskMapTestErr(Exception):
""""""


def _thread_initializer(msg: str) -> None:
"""For testing thread worker initializer."""
thread_native_id = threading.get_native_id()
logger.info(f"thread worker #{thread_native_id} initialized: {msg}")


class TestRetryTaskMap:
WAIT_CONST = 100_000_000
TASKS_COUNT = 2000
MAX_CONCURRENT = 128
MAX_WAIT_BEFORE_SUCCESS = 10

@pytest.fixture(autouse=True)
def setup(self):
self._start_time = time.time()
self._success_wait_dict = {
idx: random.randint(0, self.MAX_WAIT_BEFORE_SUCCESS)
for idx in range(self.TASKS_COUNT)
idx: random.randint(0, MAX_WAIT_BEFORE_SUCCESS)
for idx in range(TASKS_COUNT)
}
self._succeeded_tasks = [False for _ in range(self.TASKS_COUNT)]
self._succeeded_tasks = [False for _ in range(TASKS_COUNT)]

def workload_aways_failed(self, idx: int) -> int:
time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
raise _RetryTaskMapTestErr

def workload_failed_and_then_succeed(self, idx: int) -> int:
time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
if time.time() > self._start_time + self._success_wait_dict[idx]:
self._succeeded_tasks[idx] = True
return idx
raise _RetryTaskMapTestErr

def workload_succeed(self, idx: int) -> int:
time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
self._succeeded_tasks[idx] = True
return idx

@@ -69,48 +79,62 @@ def _exit_on_exceed_max_count():
raise ValueError(f"{failure_count=} > {MAX_RETRY=}")

with retry_task_map.ThreadPoolExecutorWithRetry(
max_concurrent=self.MAX_CONCURRENT,
max_concurrent=MAX_CONCURRENT,
watchdog_func=_exit_on_exceed_max_count,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
) as executor:
with pytest.raises(retry_task_map.TasksEnsureFailed):
for _fut in executor.ensure_tasks(
self.workload_aways_failed, range(self.TASKS_COUNT)
self.workload_aways_failed, range(TASKS_COUNT)
):
if _fut.exception():
failure_count += 1

def test_retry_exceed_retry_limit(self):
MAX_TOTAL_RETRY = 200
failure_count = 0
with retry_task_map.ThreadPoolExecutorWithRetry(
max_concurrent=self.MAX_CONCURRENT, max_total_retry=200
max_concurrent=MAX_CONCURRENT,
max_total_retry=MAX_TOTAL_RETRY,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
) as executor:
with pytest.raises(retry_task_map.TasksEnsureFailed):
for _ in executor.ensure_tasks(
self.workload_aways_failed, range(self.TASKS_COUNT)
for _fut in executor.ensure_tasks(
self.workload_aways_failed, range(TASKS_COUNT)
):
pass
if _fut.exception():
failure_count += 1

assert failure_count >= MAX_TOTAL_RETRY

def test_retry_finally_succeeded(self):
count = 0
with retry_task_map.ThreadPoolExecutorWithRetry(
max_concurrent=self.MAX_CONCURRENT
max_concurrent=MAX_CONCURRENT,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
) as executor:
for _fut in executor.ensure_tasks(
self.workload_failed_and_then_succeed, range(self.TASKS_COUNT)
self.workload_failed_and_then_succeed, range(TASKS_COUNT)
):
if not _fut.exception():
count += 1
assert all(self._succeeded_tasks)
assert self.TASKS_COUNT == count
assert TASKS_COUNT == count

def test_succeeded_in_one_try(self):
count = 0
with retry_task_map.ThreadPoolExecutorWithRetry(
max_concurrent=self.MAX_CONCURRENT
max_concurrent=MAX_CONCURRENT,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
) as executor:
for _fut in executor.ensure_tasks(
self.workload_succeed, range(self.TASKS_COUNT)
self.workload_succeed, range(TASKS_COUNT)
):
if not _fut.exception():
count += 1
assert all(self._succeeded_tasks)
assert self.TASKS_COUNT == count
assert TASKS_COUNT == count

1 comment on commit d156096

@github-actions (Contributor) commented:

Coverage Report
File | Stmts | Miss | Cover | Missing
src/ota_metadata/legacy
  __init__.py | 11 | 0 | 100% |
  parser.py | 326 | 35 | 89% | 145, 150, 186–187, 197–198, 201, 213, 271, 281–284, 323–326, 406, 409, 417–419, 432, 441–442, 445–446, 720, 723, 735–740
  types.py | 84 | 13 | 84% | 37, 40–42, 112–116, 122–125
src/ota_proxy
  __init__.py | 36 | 10 | 72% | 59, 61, 63, 72, 81–82, 102, 104–106
  __main__.py | 7 | 7 | 0% | 16–18, 20, 22–23, 25
  _consts.py | 15 | 0 | 100% |
  cache_control.py | 68 | 4 | 94% | 71, 91, 113, 121
  config.py | 18 | 0 | 100% |
  db.py | 146 | 15 | 89% | 75, 81, 103, 113, 116, 145–147, 166, 199, 208–209, 229, 258, 300
  errors.py | 5 | 0 | 100% |
  orm.py | 112 | 10 | 91% | 92, 97, 102, 108, 114, 141–142, 155, 232, 236
  ota_cache.py | 401 | 86 | 78% | 98–99, 218, 229, 256–258, 278, 294–295, 297, 320–321, 327, 331, 333, 360–362, 378, 439–440, 482–483, 553, 566–569, 619, 638–639, 671–672, 683, 717–721, 725–727, 729, 731–738, 740–742, 745–746, 750–751, 755, 802, 810–812, 891–894, 898, 901–902, 916–917, 919–921, 925–926, 932–933, 964, 970, 997, 1026–1028
  server_app.py | 138 | 39 | 71% | 76, 79, 85, 101, 103, 162, 171, 213–214, 216–218, 221, 226–228, 231–232, 235, 238, 241, 244, 257–258, 261–262, 264, 267, 293–296, 299, 313–315, 321–323
  utils.py | 23 | 1 | 95% | 33
src/otaclient
  __init__.py | 5 | 2 | 60% | 17, 19
  __main__.py | 1 | 1 | 0% | 16
  log_setting.py | 52 | 5 | 90% | 53, 55, 64–66
src/otaclient/app
  __main__.py | 1 | 1 | 0% | 16
  configs.py | 75 | 0 | 100% |
  errors.py | 112 | 0 | 100% |
  interface.py | 5 | 0 | 100% |
  main.py | 46 | 5 | 89% | 52–53, 75–77
  ota_client.py | 392 | 131 | 66% | 67, 75, 96, 201–203, 214, 260–263, 275–278, 281–284, 294–297, 302–303, 305, 314, 317, 322–323, 326, 332, 334, 337, 379–382, 387, 391, 394, 410–413, 416–423, 426–433, 439–442, 471, 474–475, 477, 480–483, 485–486, 491–492, 495, 509–516, 523, 526–532, 579–582, 590, 626, 631–634, 639–641, 644–645, 647–648, 650–651, 653, 713–714, 717, 725–726, 729, 740–741, 744, 752–753, 756, 767, 786, 813, 832, 850
  ota_client_stub.py | 394 | 109 | 72% | 76–78, 80–81, 89–92, 95–97, 101, 106–107, 109–110, 113, 115–116, 119–121, 124–125, 128–130, 135–140, 144, 147–151, 153–154, 162–164, 167, 204–206, 211, 247, 272, 275, 278, 382, 406, 408, 432, 478, 535, 605–606, 645, 664–666, 672–675, 679–681, 688–690, 693, 697–700, 753, 842–844, 851, 881–882, 885–889, 898–907, 914, 920, 923–924, 928, 931
  update_stats.py | 106 | 2 | 98% | 162, 172
src/otaclient/app/boot_control
  __init__.py | 4 | 0 | 100% |
  _common.py | 248 | 112 | 54% | 74–75, 96–98, 114–115, 135–136, 155–156, 175–176, 195–196, 218–220, 235–236, 260–266, 287, 295, 313, 321, 340–341, 344–345, 368, 370–379, 381–390, 392–394, 413, 416, 424, 432, 448–450, 452–457, 550, 555, 560, 673, 677–678, 681, 689, 691–692, 718–719, 721–724, 729, 735–736, 739–740, 742, 749–750, 761–767, 775–777, 781–782, 785–786, 789, 795
  _grub.py | 417 | 128 | 69% | 216, 264–267, 273–277, 314–315, 322–327, 330–336, 339, 342–343, 348, 350–352, 361–367, 369–370, 372–374, 383–385, 387–389, 468–469, 473–474, 526, 532, 558, 580, 584–585, 600–602, 626–629, 641, 645–647, 649–651, 710–713, 738–741, 764–767, 779–780, 783–784, 819, 825, 845–846, 848, 860, 863, 866, 869, 873–875, 893–896, 924–927, 932–940, 945–953
  _jetson_cboot.py | 270 | 214 | 20% | 69–70, 77–78, 96–105, 117, 124–125, 137, 143–144, 154–156, 168–169, 180–181, 184–185, 188–189, 192–196, 199–200, 204–205, 210–211, 213–217, 219–225, 227–228, 233, 236, 239–240, 243, 247–248, 252–253, 257, 260, 263, 267–273, 275–277, 282, 285, 288, 292, 299, 301–304, 317, 320, 324, 326–328, 332, 339, 341, 344, 350–351, 356, 364, 372–374, 383–384, 386–388, 394, 397–399, 403–404, 406, 409, 418–420, 423, 426, 429–434, 436–438, 441, 444, 448–453, 457–459, 464–465, 469–470, 473, 476, 479–480, 483, 486, 491, 494, 497–498, 500, 502, 505, 508, 510–511, 514–518, 523–524, 526, 534–538, 540, 543, 546, 557–558, 563, 573, 576–584, 589–597, 602–610, 616–618, 621, 624
  _jetson_common.py | 141 | 66 | 53% | 50, 74, 129–134, 136, 141–143, 148–151, 159–160, 167–168, 173–174, 190–191, 193–195, 198–200, 203, 207, 211, 215–217, 223–224, 226, 259, 285–286, 288–290, 294–297, 299–300, 302–306, 308, 315–316, 319, 321, 331, 334–335, 338, 340
  _rpi_boot.py | 286 | 134 | 53% | 54, 57, 121–122, 126, 134–137, 151–154, 161–162, 164–165, 170–171, 174–175, 184–185, 223, 229–233, 236, 254–256, 260–262, 267–269, 273–275, 285–286, 289, 292, 294–295, 297–298, 300–302, 308, 311–312, 322–325, 333–337, 339, 341–342, 347–348, 355–361, 392, 394–397, 407–410, 414–415, 417–421, 449–452, 471–474, 479, 482, 500–503, 508–516, 521–529, 546–549, 555–557, 560, 563
  configs.py | 38 | 0 | 100% |
  protocol.py | 4 | 0 | 100% |
  selecter.py | 38 | 26 | 31% | 44–46, 49–50, 54–55, 58–60, 63, 65, 69, 77–79, 81–82, 84–85, 89, 91–93, 95, 97
src/otaclient/app/create_standby
  __init__.py | 12 | 5 | 58% | 28–30, 32, 34
  common.py | 219 | 43 | 80% | 63, 66–67, 71–73, 75, 79–80, 82, 128, 176–178, 180–182, 184, 187–190, 194, 205, 279–280, 282–287, 300, 355, 358–360, 376–377, 391, 395, 417–418
  interface.py | 5 | 0 | 100% |
  rebuild_mode.py | 99 | 9 | 90% | 94–96, 108–113
src/otaclient/configs
  _common.py | 8 | 0 | 100% |
  ecu_info.py | 57 | 1 | 98% | 107
  proxy_info.py | 52 | 2 | 96% | 88, 90
src/otaclient_api/v2
  __init__.py | 14 | 0 | 100% |
  api_caller.py | 39 | 6 | 84% | 45–47, 83–85
  api_stub.py | 17 | 0 | 100% |
  types.py | 256 | 23 | 91% | 86, 89–92, 131, 209–210, 212, 259, 262–263, 506–508, 512–513, 515, 518–519, 522–523, 586
src/otaclient_common
  __init__.py | 34 | 8 | 76% | 42–44, 59, 61, 67, 74–75
  common.py | 154 | 19 | 87% | 41, 45, 200, 203–205, 220, 227–229, 295–297, 307, 316–318, 364, 368
  downloader.py | 269 | 43 | 84% | 72, 85–86, 301, 306, 328–329, 379–383, 402–404, 407–408, 411–412, 433–436, 440–441, 445–446, 450–451, 460, 535–537, 553, 573–575, 579, 581, 584, 589–591
  linux.py | 61 | 15 | 75% | 51–53, 59, 69, 74, 76, 108–109, 133–134, 190, 195–196, 198
  logging.py | 29 | 1 | 96% | 55
  persist_file_handling.py | 113 | 18 | 84% | 112, 114, 146–148, 150, 176–179, 184, 188–192, 218–219
  proto_streamer.py | 42 | 8 | 80% | 33, 48, 66–67, 72, 81–82, 100
  proto_wrapper.py | 398 | 45 | 88% | 87, 165, 172, 184–186, 205, 210, 221, 257, 263, 268, 299, 303, 307, 402, 462, 469, 472, 492, 499, 501, 526, 532, 535, 537, 562, 568, 571, 573, 605, 609, 611, 625, 642, 669, 672, 676, 707, 713, 760–763, 765
  retry_task_map.py | 80 | 7 | 91% | 164–165, 167, 179–182
  typing.py | 25 | 0 | 100% |
TOTAL | 6008 | 1409 | 76% |

Tests | Skipped | Failures | Errors | Time
179 | 0 💤 | 0 ❌ | 0 🔥 | 5m 2s ⏱️
