refactor: retry_task_map now takes initializer and initargs params #324

Merged · 6 commits · Jun 24, 2024
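As background for reviewers, here is a minimal usage sketch of the two new constructor parameters, which are simply forwarded to the underlying `ThreadPoolExecutor`. The import path, workload function, and log message below are illustrative assumptions, not taken from this PR:

```python
import logging
import threading

# assumed import path, based on src/otaclient_common/retry_task_map.py
from otaclient_common.retry_task_map import ThreadPoolExecutorWithRetry

logger = logging.getLogger(__name__)


def _worker_init(msg: str) -> None:
    # Hypothetical per-worker setup: runs once in every worker thread,
    # mirroring ThreadPoolExecutor's initializer/initargs semantics.
    logger.info("worker %s started: %s", threading.get_native_id(), msg)


def _process(idx: int) -> int:
    # Illustrative workload: any callable taking one item from the iterable.
    return idx * 2


with ThreadPoolExecutorWithRetry(
    max_concurrent=64,
    initializer=_worker_init,       # new: forwarded to ThreadPoolExecutor
    initargs=("hello from main",),  # new: forwarded to ThreadPoolExecutor
) as executor:
    for fut in executor.ensure_tasks(_process, range(100)):
        if not fut.exception():
            fut.result()
```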
141 changes: 76 additions & 65 deletions src/otaclient_common/retry_task_map.py
@@ -19,7 +19,6 @@
import contextlib
import itertools
import logging
import os
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor
@@ -46,7 +45,8 @@ def __init__(
thread_name_prefix: str = "",
watchdog_func: Optional[Callable] = None,
watchdog_check_interval: int = 3, # seconds
ensure_tasks_pull_interval: int = 1, # second
initializer: Callable[..., Any] | None = None,
initargs: tuple = (),
) -> None:
"""Initialize a ThreadPoolExecutorWithRetry instance.

@@ -58,69 +58,92 @@ def __init__(
watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when
this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None.
watchdog_check_interval (int): Defaults to 3(seconds).
ensure_tasks_pull_interval (int): Defaults to 1(second).
initializer (Callable[..., Any] | None): The same <initializer> param passed through to ThreadPoolExecutor.
Defaults to None.
initargs (tuple): The same <initargs> param passed through to ThreadPoolExecutor.
Defaults to ().
"""
self.max_total_retry = max_total_retry
self.ensure_tasks_pull_interval = ensure_tasks_pull_interval

self._start_lock, self._started = threading.Lock(), False
self._total_task_num = 0
self._finished_task_counter = itertools.count(start=1)
self._finished_task = 0
self._retry_counter = itertools.count(start=1)
self._retry_count = 0
self._concurrent_semaphore = threading.Semaphore(max_concurrent)
self._fut_queue: SimpleQueue[Future[Any]] = SimpleQueue()

# NOTE: leave two threads each for watchdog and dispatcher
max_workers = (
max_workers + 2 if max_workers else min(32, (os.cpu_count() or 1) + 4)
super().__init__(
max_workers=max_workers,
thread_name_prefix=thread_name_prefix,
initializer=initializer,
initargs=initargs,
)
super().__init__(max_workers=max_workers, thread_name_prefix=thread_name_prefix)

def _watchdog() -> None:
"""Watchdog will shutdown the threadpool on certain conditions being met."""
while not self._shutdown and not concurrent_fut_thread._shutdown:
if self.max_total_retry and self._retry_count > self.max_total_retry:
logger.warning(f"exceed {self.max_total_retry=}, abort")
return self.shutdown(wait=True)

if callable(watchdog_func):
try:
watchdog_func()
except Exception as e:
logger.warning(f"custom watchdog func failed: {e!r}, abort")
return self.shutdown(wait=True)
time.sleep(watchdog_check_interval)
if max_total_retry or callable(watchdog_func):
threading.Thread(
target=self._watchdog,
args=(max_total_retry, watchdog_func, watchdog_check_interval),
daemon=True,
).start()

self.submit(_watchdog)
def _watchdog(
self,
max_retry: int | None,
watchdog_func: Callable[..., Any] | None,
interval: int,
) -> None:
"""Watchdog will shutdown the threadpool on certain conditions being met."""
while not self._shutdown and not concurrent_fut_thread._shutdown:
if max_retry and self._retry_count > max_retry:
logger.warning(f"exceed {max_retry=}, abort")
return self.shutdown(wait=True)

if callable(watchdog_func):
try:
watchdog_func()
except Exception as e:
logger.warning(f"custom watchdog func failed: {e!r}, abort")
return self.shutdown(wait=True)
time.sleep(interval)

def _task_done_cb(
self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any]
) -> None:
self._concurrent_semaphore.release() # always release se first
self._fut_queue.put_nowait(fut)

# ------ on task succeeded ------ #
if not fut.exception():
self._concurrent_semaphore.release()
self._finished_task = next(self._finished_task_counter)
return
# ------ on task failed ------ #
if fut.exception():
self._retry_count = next(self._retry_counter)
with contextlib.suppress(Exception): # on threadpool shutdown
self.submit(func, item).add_done_callback(
partial(self._task_done_cb, item=item, func=func)
)

# ------ on threadpool shutdown(by watchdog) ------ #
if self._shutdown or self._broken:
# wakeup dispatcher
self._concurrent_semaphore.release()
return
def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]:
finished_tasks = 0
while True:
if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
logger.warning(
f"failed to ensure all tasks, {finished_tasks=}, {self._total_task_num=}"
)
raise TasksEnsureFailed

# ------ on task failed ------ #
self._retry_count = next(self._retry_counter)
with contextlib.suppress(Exception): # on threadpool shutdown
self.submit(func, item).add_done_callback(
partial(self._task_done_cb, item=item, func=func)
)
try:
done_fut = self._fut_queue.get_nowait()
if not done_fut.exception():
finished_tasks += 1
yield done_fut
except Empty:
if self._total_task_num == 0 or finished_tasks != self._total_task_num:
time.sleep(interval)
continue
return

def ensure_tasks(
self, func: Callable[[T], RT], iterable: Iterable[T]
self,
func: Callable[[T], RT],
iterable: Iterable[T],
*,
ensure_tasks_pull_interval: int = 1,
) -> Generator[Future[RT], None, None]:
"""Ensure all the items in <iterable> are processed by <func> in the pool.

@@ -138,9 +161,10 @@ def ensure_tasks(
"""
with self._start_lock:
if self._started:
raise ValueError("ensure_tasks cannot be started more than once")
if self._shutdown or self._broken:
raise ValueError("threadpool is shutdown or broken, abort")
try:
raise ValueError("ensure_tasks cannot be started more than once")
finally: # do not hold refs to input params
del self, func, iterable
self._started = True

# ------ dispatch tasks from iterable ------ #
@@ -160,23 +184,10 @@ def _dispatcher() -> None:
self._total_task_num = _tasks_count
logger.info(f"finish dispatch {_tasks_count} tasks")

self.submit(_dispatcher)
threading.Thread(target=_dispatcher, daemon=True).start()

# ------ ensure all tasks are finished ------ #
while True:
if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
logger.warning(
f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}"
)
raise TasksEnsureFailed

try:
yield self._fut_queue.get_nowait()
except Empty:
if (
self._total_task_num == 0
or self._finished_task != self._total_task_num
):
time.sleep(self.ensure_tasks_pull_interval)
continue
return
# NOTE: also see base.Executor.map method, let the yield hidden in
# a generator so that the first fut will be dispatched before
# we start to get from fut_queue.
return self._fut_gen(ensure_tasks_pull_interval)
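This hunk also moves `ensure_tasks_pull_interval` from the constructor to a keyword-only parameter of `ensure_tasks()`. A minimal call-site sketch of that change, assuming the module's import path and an illustrative workload:

```python
from otaclient_common.retry_task_map import ThreadPoolExecutorWithRetry


def workload(item: int) -> int:
    # Illustrative workload; replace with the real per-item callable.
    return item + 1


items = range(10)

# Before this PR (sketch), the pull interval was a constructor parameter:
#   ThreadPoolExecutorWithRetry(max_concurrent=8, ensure_tasks_pull_interval=1)
# After this PR it is keyword-only on ensure_tasks():
with ThreadPoolExecutorWithRetry(max_concurrent=8) as executor:
    for fut in executor.ensure_tasks(
        workload,
        items,
        ensure_tasks_pull_interval=1,  # seconds between polls of the future queue
    ):
        if not fut.exception():
            fut.result()
```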
68 changes: 46 additions & 22 deletions tests/test_otaclient_common/test_retry_task_map.py
@@ -17,6 +17,7 @@

import logging
import random
import threading
import time

import pytest
@@ -25,39 +26,48 @@

logger = logging.getLogger(__name__)

# ------ test setup ------ #
WAIT_CONST = 100_000_000
TASKS_COUNT = 2000
MAX_CONCURRENT = 128
MAX_WAIT_BEFORE_SUCCESS = 10
THREAD_INIT_MSG = "thread init message"


class _RetryTaskMapTestErr(Exception):
""""""


def _thread_initializer(msg: str) -> None:
"""For testing thread worker initializer."""
thread_native_id = threading.get_native_id()
logger.info(f"thread worker #{thread_native_id} initialized: {msg}")


class TestRetryTaskMap:
WAIT_CONST = 100_000_000
TASKS_COUNT = 2000
MAX_CONCURRENT = 128
MAX_WAIT_BEFORE_SUCCESS = 10

@pytest.fixture(autouse=True)
def setup(self):
self._start_time = time.time()
self._success_wait_dict = {
idx: random.randint(0, self.MAX_WAIT_BEFORE_SUCCESS)
for idx in range(self.TASKS_COUNT)
idx: random.randint(0, MAX_WAIT_BEFORE_SUCCESS)
for idx in range(TASKS_COUNT)
}
self._succeeded_tasks = [False for _ in range(self.TASKS_COUNT)]
self._succeeded_tasks = [False for _ in range(TASKS_COUNT)]

def workload_aways_failed(self, idx: int) -> int:
time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
raise _RetryTaskMapTestErr

def workload_failed_and_then_succeed(self, idx: int) -> int:
time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
if time.time() > self._start_time + self._success_wait_dict[idx]:
self._succeeded_tasks[idx] = True
return idx
raise _RetryTaskMapTestErr

def workload_succeed(self, idx: int) -> int:
time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
self._succeeded_tasks[idx] = True
return idx

@@ -69,48 +79,62 @@ def _exit_on_exceed_max_count():
raise ValueError(f"{failure_count=} > {MAX_RETRY=}")

with retry_task_map.ThreadPoolExecutorWithRetry(
max_concurrent=self.MAX_CONCURRENT,
max_concurrent=MAX_CONCURRENT,
watchdog_func=_exit_on_exceed_max_count,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
) as executor:
with pytest.raises(retry_task_map.TasksEnsureFailed):
for _fut in executor.ensure_tasks(
self.workload_aways_failed, range(self.TASKS_COUNT)
self.workload_aways_failed, range(TASKS_COUNT)
):
if _fut.exception():
failure_count += 1

def test_retry_exceed_retry_limit(self):
MAX_TOTAL_RETRY = 200
failure_count = 0
with retry_task_map.ThreadPoolExecutorWithRetry(
max_concurrent=self.MAX_CONCURRENT, max_total_retry=200
max_concurrent=MAX_CONCURRENT,
max_total_retry=MAX_TOTAL_RETRY,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
) as executor:
with pytest.raises(retry_task_map.TasksEnsureFailed):
for _ in executor.ensure_tasks(
self.workload_aways_failed, range(self.TASKS_COUNT)
for _fut in executor.ensure_tasks(
self.workload_aways_failed, range(TASKS_COUNT)
):
pass
if _fut.exception():
failure_count += 1

assert failure_count >= MAX_TOTAL_RETRY

def test_retry_finally_succeeded(self):
count = 0
with retry_task_map.ThreadPoolExecutorWithRetry(
max_concurrent=self.MAX_CONCURRENT
max_concurrent=MAX_CONCURRENT,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
) as executor:
for _fut in executor.ensure_tasks(
self.workload_failed_and_then_succeed, range(self.TASKS_COUNT)
self.workload_failed_and_then_succeed, range(TASKS_COUNT)
):
if not _fut.exception():
count += 1
assert all(self._succeeded_tasks)
assert self.TASKS_COUNT == count
assert TASKS_COUNT == count

def test_succeeded_in_one_try(self):
count = 0
with retry_task_map.ThreadPoolExecutorWithRetry(
max_concurrent=self.MAX_CONCURRENT
max_concurrent=MAX_CONCURRENT,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
) as executor:
for _fut in executor.ensure_tasks(
self.workload_succeed, range(self.TASKS_COUNT)
self.workload_succeed, range(TASKS_COUNT)
):
if not _fut.exception():
count += 1
assert all(self._succeeded_tasks)
assert self.TASKS_COUNT == count
assert TASKS_COUNT == count
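The tests above drive the watchdog with a failure counter; the same hook can abort on any external condition. A hedged sketch (the stop flag, workload, and import path are illustrative assumptions, not from this PR):

```python
import threading

from otaclient_common.retry_task_map import (
    TasksEnsureFailed,
    ThreadPoolExecutorWithRetry,
)

stop_event = threading.Event()  # illustrative external abort signal


def _watchdog() -> None:
    # Called periodically on the watchdog thread; raising shuts the pool down.
    if stop_event.is_set():
        raise RuntimeError("external stop requested")


def workload(item: int) -> int:
    # Illustrative workload.
    return item


try:
    with ThreadPoolExecutorWithRetry(
        max_concurrent=8,
        watchdog_func=_watchdog,
        watchdog_check_interval=3,
    ) as executor:
        for fut in executor.ensure_tasks(workload, range(100)):
            pass
except TasksEnsureFailed:
    # Raised by ensure_tasks when the pool is shut down (e.g., by the watchdog).
    pass
```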