merge upstream retry_task_map
Bodong-Yang committed Jun 24, 2024
1 parent 570fa72 commit 9e3708e
Showing 1 changed file with 65 additions and 66 deletions.
131 changes: 65 additions & 66 deletions src/otaclient_common/retry_task_map.py
@@ -15,7 +15,7 @@
 
 from __future__ import annotations
 
-import atexit
+import concurrent.futures.thread as concurrent_fut_thread
 import contextlib
 import itertools
 import logging
@@ -30,16 +30,6 @@
 
 logger = logging.getLogger(__name__)
 
-_retry_task_map_global_shutdown = False
-
-
-def _python_exit():
-    global _retry_task_map_global_shutdown
-    _retry_task_map_global_shutdown = True
-
-
-atexit.register(_python_exit)
-
 
 class TasksEnsureFailed(Exception):
     """Exception for tasks ensuring failed."""
@@ -55,7 +45,6 @@ def __init__(
         thread_name_prefix: str = "",
         watchdog_func: Optional[Callable] = None,
         watchdog_check_interval: int = 3,  # seconds
-        ensure_tasks_pull_interval: int = 1,  # second
         initializer: Callable[..., Any] | None = None,
         initargs: tuple = (),
     ) -> None:
@@ -69,11 +58,11 @@ def __init__(
             watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when
                 this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None.
             watchdog_check_interval (int): Defaults to 3(seconds).
-            ensure_tasks_pull_interval (int): Defaults to 1(second).
             initializer (Callable[..., Any] | None): The same <initializer> param passed through to ThreadPoolExecutor.
                 Defaults to None.
             initargs (tuple): The same <initargs> param passed through to ThreadPoolExecutor.
                 Defaults to ().
         """
-        self.max_total_retry = max_total_retry
-        self.ensure_tasks_pull_interval = ensure_tasks_pull_interval
-
         self._start_lock, self._started = threading.Lock(), False
         self._total_task_num = 0
         self._finished_task_counter = itertools.count(start=1)
@@ -90,22 +79,23 @@ def __init__(
             initargs=initargs,
         )
 
-        threading.Thread(
-            target=self._watchdog,
-            args=(watchdog_func, watchdog_check_interval),
-            daemon=True,
-        ).start()
+        if max_total_retry or callable(watchdog_func):
+            threading.Thread(
+                target=self._watchdog,
+                args=(max_total_retry, watchdog_func, watchdog_check_interval),
+                daemon=True,
+            ).start()
 
     def _watchdog(
-        self, watchdog_func: Callable[..., Any], watchdog_check_interval: int
+        self,
+        max_retry: int | None,
+        watchdog_func: Callable[..., Any] | None,
+        interval: int,
     ) -> None:
         """Watchdog will shutdown the threadpool on certain conditions being met."""
-        if not self.max_total_retry and not callable(watchdog_func):
-            return  # no need to run watchdog thread if not checks are performed
-
-        while not self._shutdown and not _retry_task_map_global_shutdown:
-            if self.max_total_retry and self._retry_count > self.max_total_retry:
-                logger.warning(f"exceed {self.max_total_retry=}, abort")
+        while not self._shutdown and not concurrent_fut_thread._shutdown:
+            if max_retry and self._retry_count > max_retry:
+                logger.warning(f"exceed {max_retry=}, abort")
                 return self.shutdown(wait=True)
 
             if callable(watchdog_func):
@@ -114,21 +104,21 @@ def _watchdog(
                 except Exception as e:
                     logger.warning(f"custom watchdog func failed: {e!r}, abort")
                     return self.shutdown(wait=True)
-            time.sleep(watchdog_check_interval)
+            time.sleep(interval)
 
     def _task_done_cb(
         self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any]
     ) -> None:
-        self._concurrent_semaphore.release()  # NOTE: always release se first
         self._fut_queue.put_nowait(fut)
+        self._concurrent_semaphore.release()
 
         # ------ on task succeeded ------ #
         if not fut.exception():
             self._finished_task = next(self._finished_task_counter)
             return
 
         # ------ on threadpool shutdown(by watchdog) ------ #
-        if self._shutdown or _retry_task_map_global_shutdown:
+        if self._shutdown or self._broken:
             return
 
         # ------ on task failed ------ #
@@ -138,23 +128,31 @@ def _task_done_cb(
                 partial(self._task_done_cb, item=item, func=func)
             )
 
-    def _dispatcher(self, func: Callable[[T], RT], iterable: Iterable[T]) -> None:
-        try:
-            for _tasks_count, item in enumerate(iterable, start=1):
-                self._concurrent_semaphore.acquire()
-                # NOTE: on pool shutdown, the submit method will throw exceptions
-                fut = self.submit(func, item)
-                fut.add_done_callback(partial(self._task_done_cb, item=item, func=func))
-        except Exception as e:
-            logger.error(f"tasks dispatcher failed: {e!r}, abort")
-            self.shutdown(wait=True)
-            return
+    def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]:
+        while True:
+            if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
+                logger.warning(
+                    f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}"
+                )
+                raise TasksEnsureFailed
 
-        self._total_task_num = _tasks_count
-        logger.info(f"finish dispatch {_tasks_count} tasks")
+            try:
+                yield self._fut_queue.get_nowait()
+            except Empty:
+                if (
+                    self._total_task_num == 0
+                    or self._finished_task != self._total_task_num
+                ):
+                    time.sleep(interval)
+                    continue
+                return
 
     def ensure_tasks(
-        self, func: Callable[[T], RT], iterable: Iterable[T]
+        self,
+        func: Callable[[T], RT],
+        iterable: Iterable[T],
+        *,
+        ensure_tasks_pull_interval: int = 1,
     ) -> Generator[Future[RT], None, None]:
         """Ensure all the items in <iterable> are processed by <func> in the pool.
@@ -171,33 +169,34 @@
             The Future instance of each processed tasks.
         """
         with self._start_lock:
-            if self._started or self._shutdown or _retry_task_map_global_shutdown:
+            if self._started:
                 try:
-                    raise ValueError(
-                        "pool shutdowned or ensure_tasks cannot be started more than once"
-                    )
+                    raise ValueError("ensure_tasks cannot be started more than once")
                 finally:  # do not hold refs to input params
                     del self, func, iterable
             self._started = True
 
-        # ------ dispatch tasks from iterable ------ #
-        threading.Thread(
-            target=self._dispatcher, args=(func, iterable), daemon=True
-        ).start()
+        def _dispatcher() -> None:
+            try:
+                for _tasks_count, item in enumerate(iterable, start=1):
+                    self._concurrent_semaphore.acquire()
+                    fut = self.submit(func, item)
+                    fut.add_done_callback(
+                        partial(self._task_done_cb, item=item, func=func)
+                    )
+            except Exception as e:
+                logger.error(f"tasks dispatcher failed: {e!r}, abort")
+                self.shutdown(wait=True)
+                return
 
-        # ------ ensure all tasks are finished ------ #
-        while self._total_task_num == 0 or self._finished_task != self._total_task_num:
-            # shutdown by upper caller or interpreter exits
-            if self._shutdown or _retry_task_map_global_shutdown:
-                _err_msg = f"failed to ensure all tasks, {self._finished_task=}, {self._total_task_num=}"
-                logger.warning(_err_msg)
+            self._total_task_num = _tasks_count
+            logger.info(f"finish dispatch {_tasks_count} tasks")
 
-                try:
-                    raise TasksEnsureFailed(_err_msg)
-                finally:
-                    del self, func, iterable
+        threading.Thread(target=_dispatcher, daemon=True).start()
 
-            try:
-                yield self._fut_queue.get_nowait()
-            except Empty:
-                time.sleep(self.ensure_tasks_pull_interval)
+        # ------ ensure all tasks are finished ------ #
+        # NOTE: also see base.Executor.map method, let the yield hidden in
+        #   a generator so that the first fut will be dispatched before
+        #   we start to get from fut_queue.
+        return self._fut_gen(ensure_tasks_pull_interval)
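
The NOTE above points at a subtlety borrowed from base.Executor.map: ensure_tasks is now a plain function that returns a generator, rather than being a generator function itself, so the dispatcher thread starts the moment ensure_tasks is called instead of on the caller's first next(). A self-contained illustration of that difference (function names are ad hoc):

def eager():
    print("dispatcher started")  # runs immediately at call time
    def _gen():
        yield from range(3)
    return _gen()

def lazy():
    print("dispatcher started")  # deferred until iteration begins
    yield from range(3)

eager_gen = eager()  # prints right away, like the new ensure_tasks
lazy_gen = lazy()    # prints nothing until next(lazy_gen)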

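With this merge, the pull interval moves from the constructor to a keyword-only argument of ensure_tasks. A minimal usage sketch against the new API; the executor class is ThreadPoolExecutorWithRetry as defined in this file, but its max_concurrent parameter sits in a part of the file collapsed out of this diff, and process_one plus the inputs are hypothetical:

from otaclient_common.retry_task_map import (
    TasksEnsureFailed,
    ThreadPoolExecutorWithRetry,
)

def process_one(item: int) -> int:  # hypothetical work function, retried on failure
    return item * 2

with ThreadPoolExecutorWithRetry(
    max_concurrent=64,   # limit on pending scheduled tasks (from the collapsed signature)
    max_total_retry=16,  # the watchdog aborts everything after 16 retries in total
    thread_name_prefix="demo",
) as pool:
    try:
        for fut in pool.ensure_tasks(
            process_one, range(100), ensure_tasks_pull_interval=1
        ):
            if fut.exception() is None:
                fut.result()  # consume successful results as they arrive
    except TasksEnsureFailed:
        pass  # the pool was interrupted before all tasks finished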