Skip to content

Commit

Permalink
dispatcher: shutdown the threadpool on exception
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodong-Yang committed Jun 5, 2024
1 parent 83cc92b commit 1553e46
Showing 1 changed file with 27 additions and 17 deletions.
44 changes: 27 additions & 17 deletions src/otaclient_common/retry_task_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,33 +32,35 @@


class TasksEnsureFailed(Exception):
pass
"""Exception for tasks ensuring failed."""


class ThreadPoolExecutorWithRetry(ThreadPoolExecutor):

WATCH_DOG_CHECK_INTERVAL = 3
ENSURE_TASKS_PULL_INTERVAL = 1

def __init__(
self,
max_concurrent: int,
max_workers: Optional[int] = None,
max_total_retry: Optional[int] = None,
thread_name_prefix: str = "",
watchdog_func: Optional[Callable] = None,
watchdog_check_interval: int = 3, # seconds
ensure_tasks_pull_interval: int = 1, # second
) -> None:
"""Initialize a ThreadPoolExecutorWithRetry instance.
Args:
max_concurrent (int, optional): How many tasks should be kept in the memory.
max_concurrent (int): Limit the number pending scheduled tasks.
max_workers (Optional[int], optional): Max number of worker threads in the pool. Defaults to None.
max_total_retry (Optional[int], optional): Max total retry counts before abort. Defaults to None.
thread_name_prefix (str, optional): Defaults to "".
watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, break threadpool when
this func raises exception. Defaults to None.
watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when
this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None.
watchdog_check_interval (int): Defaults to 3(seconds).
ensure_tasks_pull_interval (int): Defaults to 1(second).
"""
self.max_total_retry = max_total_retry
self.ensure_tasks_pull_interval = ensure_tasks_pull_interval

self._start_lock, self._started = threading.Lock(), False
self._total_task_num = 0
Expand Down Expand Up @@ -88,7 +90,7 @@ def _watchdog() -> None:
except Exception as e:
logger.warning(f"custom watchdog func failed: {e!r}, abort")
return self.shutdown(wait=True)
time.sleep(self.WATCH_DOG_CHECK_INTERVAL)
time.sleep(watchdog_check_interval)

self.submit(_watchdog)

Expand All @@ -111,7 +113,7 @@ def _task_done_cb(
self._retry_count = next(self._retry_counter)
self._fut_queue.put_nowait(fut)

with contextlib.suppress(Exception):
with contextlib.suppress(Exception): # on threadpool shutdown
self.submit(func, item).add_done_callback(
partial(self._task_done_cb, item=item, func=func)
)
Expand Down Expand Up @@ -141,15 +143,23 @@ def ensure_tasks(
self._started = True

# ------ dispatch tasks from iterable ------ #
def _dispatcher():
def _dispatcher() -> None:
_fut_queue = self._fut_queue
for _tasks_count, item in enumerate(iterable, start=1):
self._concurrent_semaphore.acquire()
fut = self.submit(func, item)
fut.add_done_callback(partial(self._task_done_cb, item=item, func=func))
_fut_queue.put_nowait(fut)
try:
for _tasks_count, item in enumerate(iterable, start=1):
self._concurrent_semaphore.acquire()
fut = self.submit(func, item)
fut.add_done_callback(
partial(self._task_done_cb, item=item, func=func)
)
_fut_queue.put_nowait(fut)
except Exception as e:
logger.error(f"tasks dispatcher failed: {e!r}, abort")
self.shutdown(wait=True)
return

self._total_task_num = _tasks_count
logger.info(f"finish dispatch {_tasks_count} of tasks")
logger.info(f"finish dispatch {_tasks_count} tasks")

self.submit(_dispatcher)

Expand All @@ -158,7 +168,7 @@ def _dispatcher():
try:
yield self._fut_queue.get_nowait()
except Empty:
time.sleep(self.ENSURE_TASKS_PULL_INTERVAL)
time.sleep(self.ensure_tasks_pull_interval)
continue

if self._shutdown or self._broken:
Expand Down

0 comments on commit 1553e46

Please sign in to comment.