Skip to content

Commit

Permalink
rtm: watchdog becomes a method, fut gen becomes a method
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodong-Yang committed Jun 23, 2024
1 parent fe82b0a commit 32bae00
Showing 1 changed file with 56 additions and 41 deletions.
97 changes: 56 additions & 41 deletions src/otaclient_common/retry_task_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(
thread_name_prefix: str = "",
watchdog_func: Optional[Callable] = None,
watchdog_check_interval: int = 3, # seconds
ensure_tasks_pull_interval: int = 1, # second
initializer: Callable[..., Any] | None = None,
initargs: tuple = (),
) -> None:
Expand All @@ -59,11 +58,11 @@ def __init__(
watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when
this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None.
watchdog_check_interval (int): Defaults to 3(seconds).
ensure_tasks_pull_interval (int): Defaults to 1(second).
initializer (Callable[..., Any] | None): The same <initializer> param passed through to ThreadPoolExecutor.
Defaults to None.
initargs (tuple): The same <initargs> param passed through to ThreadPoolExecutor.
Defaults to ().
"""
self.max_total_retry = max_total_retry
self.ensure_tasks_pull_interval = ensure_tasks_pull_interval

self._start_lock, self._started = threading.Lock(), False
self._total_task_num = 0
self._finished_task_counter = itertools.count(start=1)
Expand All @@ -80,23 +79,32 @@ def __init__(
initargs=initargs,
)

def _watchdog() -> None:
"""Watchdog will shutdown the threadpool on certain conditions being met."""
while not self._shutdown and not concurrent_fut_thread._shutdown:
if self.max_total_retry and self._retry_count > self.max_total_retry:
logger.warning(f"exceed {self.max_total_retry=}, abort")
return self.shutdown(wait=True)
if max_total_retry or callable(watchdog_func):
threading.Thread(
target=self._watchdog,
args=(max_total_retry, watchdog_func, watchdog_check_interval),
daemon=True,
).start()

if callable(watchdog_func):
try:
watchdog_func()
except Exception as e:
logger.warning(f"custom watchdog func failed: {e!r}, abort")
return self.shutdown(wait=True)
time.sleep(watchdog_check_interval)
def _watchdog(
self,
max_retry: int | None,
watchdog_func: Callable[..., Any] | None,
interval: int,
) -> None:
"""Watchdog will shutdown the threadpool on certain conditions being met."""
while not self._shutdown and not concurrent_fut_thread._shutdown:
if max_retry and self._retry_count > max_retry:
logger.warning(f"exceed {max_retry=}, abort")
return self.shutdown(wait=True)

if self.max_total_retry or callable(watchdog_func):
threading.Thread(target=_watchdog, daemon=True).start()
if callable(watchdog_func):
try:
watchdog_func()
except Exception as e:
logger.warning(f"custom watchdog func failed: {e!r}, abort")
return self.shutdown(wait=True)
time.sleep(interval)

def _task_done_cb(
self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any]
Expand All @@ -120,8 +128,31 @@ def _task_done_cb(
partial(self._task_done_cb, item=item, func=func)
)

def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]:
while True:
if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
logger.warning(
f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}"
)
raise TasksEnsureFailed

try:
yield self._fut_queue.get_nowait()
except Empty:
if (
self._total_task_num == 0
or self._finished_task != self._total_task_num
):
time.sleep(interval)
continue
return

def ensure_tasks(
self, func: Callable[[T], RT], iterable: Iterable[T]
self,
func: Callable[[T], RT],
iterable: Iterable[T],
*,
ensure_tasks_pull_interval: int = 1,
) -> Generator[Future[RT], None, None]:
"""Ensure all the items in <iterable> are processed by <func> in the pool.
Expand Down Expand Up @@ -165,23 +196,7 @@ def _dispatcher() -> None:
threading.Thread(target=_dispatcher, daemon=True).start()

# ------ ensure all tasks are finished ------ #
while True:
if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
logger.warning(
f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}"
)
try:
raise TasksEnsureFailed
finally:
del self, func, iterable

try:
yield self._fut_queue.get_nowait()
except Empty:
if (
self._total_task_num == 0
or self._finished_task != self._total_task_num
):
time.sleep(self.ensure_tasks_pull_interval)
continue
return
# NOTE: also see base.Executor.map method, let the yield hidden in
# a generator so that the first fut will be dispatched before
# we start to get from fut_queue.
return self._fut_gen(ensure_tasks_pull_interval)

0 comments on commit 32bae00

Please sign in to comment.