fix(backport v3.8.x): backport PR#395: retry_task_map: fix mem leak when network is totally cut off #436

Merged 3 commits on Nov 29, 2024
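The core of this fix, visible in the _watchdog and _fut_gen hunks below, is to drain the executor's internal queues once the pool shuts down, so that pending work items and done futures cannot keep accumulating while every task fails and is re-scheduled. A minimal sketch of the drain pattern used in both places (the queue here is illustrative, not the executor's actual queue):

import contextlib
from queue import Empty, SimpleQueue

q: SimpleQueue = SimpleQueue()

# discard everything currently queued without blocking: get_nowait()
# raises Empty once the queue is exhausted, which ends the loop
with contextlib.suppress(Empty):
    while True:
        q.get_nowait()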
2 changes: 2 additions & 0 deletions .github/workflows/test.yaml
@@ -4,9 +4,11 @@ on:
pull_request:
branches:
- main
- v*
push:
branches:
- main
- v*
paths:
- "src/**"
- "tests/**"
2 changes: 1 addition & 1 deletion src/otaclient/app/configs.py
@@ -97,7 +97,7 @@ class BaseConfig(_InternalSettings):
"otaclient": INFO,
"otaclient_api": INFO,
"otaclient_common": INFO,
"otaproxy": INFO,
"ota_proxy": INFO,
}
LOG_FORMAT = (
"[%(asctime)s][%(levelname)s]-%(name)s:%(funcName)s:%(lineno)d,%(message)s"
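The configs.py hunk renames the log-level key from "otaproxy" to "ota_proxy". Presumably the key must match the real package name, since modules create their loggers via logging.getLogger(__name__) and a level set on a name no module uses takes no effect. A minimal sketch of that behavior (the module name "ota_proxy.server" is a hypothetical stand-in):

import logging

logging.basicConfig(format="[%(levelname)s]-%(name)s:%(message)s")

# modules under the ota_proxy package get loggers named after their dotted
# module path, e.g. logging.getLogger(__name__) -> "ota_proxy.server"
child = logging.getLogger("ota_proxy.server")

# setting the level on "otaproxy" would configure a logger no module ever
# uses; the key has to be the actual package name "ota_proxy"
logging.getLogger("ota_proxy").setLevel(logging.INFO)

child.info("emitted: the parent logger 'ota_proxy' allows INFO records")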
177 changes: 122 additions & 55 deletions src/otaclient_common/retry_task_map.py
@@ -24,7 +24,7 @@
from concurrent.futures import Future, ThreadPoolExecutor
from functools import partial
from queue import Empty, SimpleQueue
from typing import Any, Callable, Generator, Iterable, Optional
from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Optional

from otaclient_common.typing import RT, T

@@ -35,7 +35,7 @@ class TasksEnsureFailed(Exception):
"""Exception for tasks ensuring failed."""


class ThreadPoolExecutorWithRetry(ThreadPoolExecutor):
class _ThreadPoolExecutorWithRetry(ThreadPoolExecutor):

def __init__(
self,
@@ -48,21 +48,6 @@ def __init__(
initializer: Callable[..., Any] | None = None,
initargs: tuple = (),
) -> None:
"""Initialize a ThreadPoolExecutorWithRetry instance.

Args:
max_concurrent (int): Limit the number pending scheduled tasks.
max_workers (Optional[int], optional): Max number of worker threads in the pool. Defaults to None.
max_total_retry (Optional[int], optional): Max total retry counts before abort. Defaults to None.
thread_name_prefix (str, optional): Defaults to "".
watchdog_func (Optional[Callable]): A custom func to be called on watchdog thread, when
this func raises exception, the watchdog will interrupt the tasks execution. Defaults to None.
watchdog_check_interval (int): Defaults to 3(seconds).
initializer (Callable[..., Any] | None): The same <initializer> param passed through to ThreadPoolExecutor.
Defaults to None.
initargs (tuple): The same <initargs> param passed through to ThreadPoolExecutor.
Defaults to ().
"""
self._start_lock, self._started = threading.Lock(), False
self._total_task_num = 0
"""
@@ -78,65 +78,80 @@ def __init__(
self._concurrent_semaphore = threading.Semaphore(max_concurrent)
self._fut_queue: SimpleQueue[Future[Any]] = SimpleQueue()

self._watchdog_check_interval = watchdog_check_interval
self._checker_funcs: list[Callable[[], Any]] = []
if isinstance(max_total_retry, int) and max_total_retry > 0:
self._checker_funcs.append(partial(self._max_retry_check, max_total_retry))
if callable(watchdog_func):
self._checker_funcs.append(watchdog_func)

super().__init__(
max_workers=max_workers,
thread_name_prefix=thread_name_prefix,
initializer=initializer,
initargs=initargs,
)

if max_total_retry or callable(watchdog_func):
threading.Thread(
target=self._watchdog,
args=(max_total_retry, watchdog_func, watchdog_check_interval),
daemon=True,
).start()
def _max_retry_check(self, max_total_retry: int) -> None:
if self._retry_count > max_total_retry:
raise TasksEnsureFailed("exceed max retry count, abort")

def _watchdog(
self,
max_retry: int | None,
watchdog_func: Callable[..., Any] | None,
_checker_funcs: list[Callable[[], Any]],
interval: int,
) -> None:
"""Watchdog will shutdown the threadpool on certain conditions being met."""
while not self._shutdown and not concurrent_fut_thread._shutdown:
if max_retry and self._retry_count > max_retry:
logger.warning(f"exceed {max_retry=}, abort")
return self.shutdown(wait=True)

if callable(watchdog_func):
try:
watchdog_func()
except Exception as e:
logger.warning(f"custom watchdog func failed: {e!r}, abort")
return self.shutdown(wait=True)
while not (self._shutdown or self._broken or concurrent_fut_thread._shutdown):
time.sleep(interval)
try:
for _func in _checker_funcs:
_func()
except Exception as e:
logger.warning(
f"watchdog failed: {e!r}, shutdown the pool and draining the workitem queue on shutdown.."
)
self.shutdown(wait=False)
# drain the worker queues to cancel all the futs
with contextlib.suppress(Empty):
while True:
self._work_queue.get_nowait()

def _task_done_cb(
self, fut: Future[Any], /, *, item: T, func: Callable[[T], Any]
) -> None:
self._concurrent_semaphore.release() # always release se first
if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
self._concurrent_semaphore.release() # on shutdown, always release a se
return # on shutdown, no need to put done fut into fut_queue
self._fut_queue.put_nowait(fut)

# ------ on task failed ------ #
if fut.exception():
self._retry_count = next(self._retry_counter)
with contextlib.suppress(Exception): # on threadpool shutdown
try: # try to re-schedule the failed task
self.submit(func, item).add_done_callback(
partial(self._task_done_cb, item=item, func=func)
)
except Exception: # if re-schedule doesn't happen, release se
self._concurrent_semaphore.release()
else: # release semaphore when succeeded
self._concurrent_semaphore.release()

def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]:
"""Generator which yields the done future, controlled by the caller."""
finished_tasks = 0
while finished_tasks == 0 or finished_tasks != self._total_task_num:
if self._total_task_num < 0:
return

if self._shutdown or self._broken or concurrent_fut_thread._shutdown:
logger.warning(
f"failed to ensure all tasks, {finished_tasks=}, {self._total_task_num=}"
f"dispatcher exits on threadpool shutdown, {finished_tasks=}, {self._total_task_num=}"
)
raise TasksEnsureFailed
with contextlib.suppress(Empty):
while True: # drain the _fut_queue
self._fut_queue.get_nowait()
raise TasksEnsureFailed # raise exc to upper caller

try:
done_fut = self._fut_queue.get_nowait()
@@ -153,20 +153,6 @@ def ensure_tasks(
*,
ensure_tasks_pull_interval: int = 1,
) -> Generator[Future[RT], None, None]:
"""Ensure all the items in <iterable> are processed by <func> in the pool.

Args:
func (Callable[[T], RT]): The function to take the item from <iterable>.
iterable (Iterable[T]): The iterable of items to be processed by <func>.

Raises:
ValueError: If the pool is shutdown or broken, or this method has already
being called once.
TasksEnsureFailed: If failed to ensure all the tasks are finished.

Yields:
The Future instance of each processed tasks.
"""
with self._start_lock:
if self._started:
try:
@@ -175,19 +161,34 @@ def ensure_tasks(
del self, func, iterable
self._started = True

if self._checker_funcs:
threading.Thread(
target=self._watchdog,
args=(self._checker_funcs, self._watchdog_check_interval),
daemon=True,
).start()

# ------ dispatch tasks from iterable ------ #
def _dispatcher() -> None:
_tasks_count = -1 # means no task is scheduled
try:
for _tasks_count, item in enumerate(iterable, start=1):
if (
self._shutdown
or self._broken
or concurrent_fut_thread._shutdown
):
logger.warning("threadpool is closing, exit")
return # directly exit on shutdown

self._concurrent_semaphore.acquire()
fut = self.submit(func, item)
fut.add_done_callback(
partial(self._task_done_cb, item=item, func=func)
)
except Exception as e:
logger.error(f"tasks dispatcher failed: {e!r}, abort")
self.shutdown(wait=True)
self.shutdown(wait=False)
return

self._total_task_num = _tasks_count
@@ -203,3 +204,69 @@ def _dispatcher() -> None:
# a generator so that the first fut will be dispatched before
# we start to get from fut_queue.
return self._fut_gen(ensure_tasks_pull_interval)


# only expose the APIs we want to expose
if TYPE_CHECKING:

class ThreadPoolExecutorWithRetry:

def __init__(
self,
max_concurrent: int,
max_workers: Optional[int] = None,
max_total_retry: Optional[int] = None,
thread_name_prefix: str = "",
watchdog_func: Optional[Callable] = None,
watchdog_check_interval: int = 3, # seconds
initializer: Callable[..., Any] | None = None,
initargs: tuple = (),
) -> None:
"""Initialize a ThreadPoolExecutorWithRetry instance.

Args:
max_concurrent (int): Limit the number of pending scheduled tasks.
max_workers (Optional[int], optional): Max number of worker threads in the pool. Defaults to None.
max_total_retry (Optional[int], optional): Max total retry counts before abort. Defaults to None.
thread_name_prefix (str, optional): Defaults to "".
watchdog_func (Optional[Callable]): A custom func to be called on the watchdog thread; when
this func raises an exception, the watchdog will interrupt the tasks execution. Defaults to None.
watchdog_check_interval (int): Defaults to 3(seconds).
initializer (Callable[..., Any] | None): The same <initializer> param passed through to ThreadPoolExecutor.
Defaults to None.
initargs (tuple): The same <initargs> param passed through to ThreadPoolExecutor.
Defaults to ().
"""
raise NotImplementedError

def ensure_tasks(
self,
func: Callable[[T], RT],
iterable: Iterable[T],
*,
ensure_tasks_pull_interval: int = 1,
) -> Generator[Future[RT], None, None]:
"""Ensure all the items in <iterable> are processed by <func> in the pool.

Args:
func (Callable[[T], RT]): The function to take the item from <iterable>.
iterable (Iterable[T]): The iterable of items to be processed by <func>.

Raises:
ValueError: If the pool is shutdown or broken, or this method has already
been called once.
TasksEnsureFailed: If failed to ensure all the tasks are finished.

Yields:
The Future instance of each processed task.
"""
raise NotImplementedError

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError

else:
ThreadPoolExecutorWithRetry = _ThreadPoolExecutorWithRetry
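
For reference, a minimal usage sketch of the class as exposed by this module; the process_one worker and the items list are hypothetical, while the parameters follow the signature documented above:

from otaclient_common.retry_task_map import (
    TasksEnsureFailed,
    ThreadPoolExecutorWithRetry,
)

def process_one(item: str) -> str:
    # hypothetical worker: raising here makes _task_done_cb re-schedule
    # the item and bump the retry counter
    return item.upper()

items = ["a", "b", "c"]
try:
    with ThreadPoolExecutorWithRetry(
        max_concurrent=128,  # bounds in-flight tasks via the semaphore
        max_workers=4,
        max_total_retry=32,  # watchdog aborts the pool past this count
    ) as pool:
        # ensure_tasks yields every finished Future, including failed ones
        # that will be retried; filter on fut.exception()
        for fut in pool.ensure_tasks(process_one, items):
            if fut.exception() is None:
                print(fut.result())
except TasksEnsureFailed:
    print("aborted: pool shut down before all items completed")

With the backported fix, aborting (e.g. when the network is fully cut off and every task keeps failing) also drains the internal work queue, so the abort no longer leaves an ever-growing backlog of queued work items and futures behind.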