Skip to content

Commit

Permalink
Merge branch 'main' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodong-Yang authored Dec 3, 2024
2 parents 001755f + 19c4c5c commit b0c2f17
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 12 deletions.
77 changes: 66 additions & 11 deletions src/otaclient_common/retry_task_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,23 @@
from queue import Empty, SimpleQueue
from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Optional

from typing_extensions import ParamSpec

from otaclient_common.common import wait_with_backoff
from otaclient_common.typing import RT, T

P = ParamSpec("P")

logger = logging.getLogger(__name__)


class TasksEnsureFailed(Exception):
"""Exception for tasks ensuring failed."""


CONTINUES_FAILURE_COUNT_ATTRNAME = "continues_failed_count"


class _ThreadPoolExecutorWithRetry(ThreadPoolExecutor):

def __init__(
Expand All @@ -47,6 +55,8 @@ def __init__(
watchdog_check_interval: int = 3, # seconds
initializer: Callable[..., Any] | None = None,
initargs: tuple = (),
backoff_factor: float = 0.01,
backoff_max: float = 1,
) -> None:
self._start_lock, self._started = threading.Lock(), False
self._total_task_num = 0
Expand All @@ -57,6 +67,8 @@ def __init__(
no tasks, the task execution gen should stop immediately.
3. only value >=1 is valid.
"""
self.backoff_factor = backoff_factor
self.backoff_max = backoff_max

self._retry_counter = itertools.count(start=1)
self._retry_count = 0
Expand All @@ -70,13 +82,26 @@ def __init__(
if callable(watchdog_func):
self._checker_funcs.append(watchdog_func)

self._thread_local = threading.local()
super().__init__(
max_workers=max_workers,
thread_name_prefix=thread_name_prefix,
initializer=initializer,
initializer=self._rtm_initializer_gen(initializer),
initargs=initargs,
)

def _rtm_initializer_gen(
self, _input_initializer: Callable[P, RT] | None
) -> Callable[P, None]:
def _real_initializer(*args: P.args, **kwargs: P.kwargs) -> None:
_thread_local = self._thread_local
setattr(_thread_local, CONTINUES_FAILURE_COUNT_ATTRNAME, 0)

if callable(_input_initializer):
_input_initializer(*args, **kwargs)

return _real_initializer

def _max_retry_check(self, max_total_retry: int) -> None:
if self._retry_count > max_total_retry:
raise TasksEnsureFailed("exceed max retry count, abort")
Expand Down Expand Up @@ -110,16 +135,40 @@ def _task_done_cb(
return # on shutdown, no need to put done fut into fut_queue
self._fut_queue.put_nowait(fut)

# ------ on task failed ------ #
if fut.exception():
self._retry_count = next(self._retry_counter)
try: # try to re-schedule the failed task
self.submit(func, item).add_done_callback(
partial(self._task_done_cb, item=item, func=func)
)
except Exception: # if re-schedule doesn't happen, release se
self._concurrent_semaphore.release()
else: # release semaphore when succeeded
_thread_local = self._thread_local

# release semaphore only on success
# reset continues failure count on success
if not fut.exception():
self._concurrent_semaphore.release()
_thread_local.continues_failed_count = 0
return

# NOTE: when for some reason the continues_failed_count is gone,
# handle the AttributeError here and re-assign the count.
try:
_continues_failed_count = getattr(
_thread_local, CONTINUES_FAILURE_COUNT_ATTRNAME
)
except AttributeError:
_continues_failed_count = 0

_continues_failed_count += 1
setattr(
_thread_local, CONTINUES_FAILURE_COUNT_ATTRNAME, _continues_failed_count
)
wait_with_backoff(
_continues_failed_count,
_backoff_factor=self.backoff_factor,
_backoff_max=self.backoff_max,
)

self._retry_count = next(self._retry_counter)
try: # try to re-schedule the failed task
self.submit(func, item).add_done_callback(
partial(self._task_done_cb, item=item, func=func)
)
except Exception: # if re-schedule doesn't happen, release se
self._concurrent_semaphore.release()

def _fut_gen(self, interval: int) -> Generator[Future[Any], Any, None]:
Expand Down Expand Up @@ -221,6 +270,8 @@ def __init__(
watchdog_check_interval: int = 3, # seconds
initializer: Callable[..., Any] | None = None,
initargs: tuple = (),
backoff_factor: float = 0.01,
backoff_max: float = 1,
) -> None:
"""Initialize a ThreadPoolExecutorWithRetry instance.
Expand All @@ -236,6 +287,10 @@ def __init__(
Defaults to None.
initargs (tuple): The same <initargs> param passed through to ThreadPoolExecutor.
Defaults to ().
backoff_factor (float): The backoff factor on thread-scope backoff wait for failed tasks rescheduling.
Defaults to 0.01.
backoff_max (float): The backoff max on thread-scope backoff wait for failed tasks rescheduling.
Defaults to 1.
"""
raise NotImplementedError

Expand Down
12 changes: 11 additions & 1 deletion tests/test_otaclient_common/test_retry_task_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
MAX_CONCURRENT = 128
MAX_WAIT_BEFORE_SUCCESS = 10
THREAD_INIT_MSG = "thread init message"
BACKOFF_FACTOR = 0.001 # for faster test
BACKOFF_MAX = 0.1


class _RetryTaskMapTestErr(Exception):
Expand All @@ -47,7 +49,7 @@ def _thread_initializer(msg: str) -> None:
class TestRetryTaskMap:

@pytest.fixture(autouse=True)
def setup(self):
def setup(self) -> None:
self._start_time = time.time()
self._success_wait_dict = {
idx: random.randint(0, MAX_WAIT_BEFORE_SUCCESS)
Expand Down Expand Up @@ -83,6 +85,8 @@ def _exit_on_exceed_max_count():
watchdog_func=_exit_on_exceed_max_count,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
backoff_factor=BACKOFF_FACTOR,
backoff_max=BACKOFF_MAX,
) as executor:
with pytest.raises(retry_task_map.TasksEnsureFailed):
for _fut in executor.ensure_tasks(
Expand All @@ -99,6 +103,8 @@ def test_retry_exceed_retry_limit(self):
max_total_retry=MAX_TOTAL_RETRY,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
backoff_factor=BACKOFF_FACTOR,
backoff_max=BACKOFF_MAX,
) as executor:
with pytest.raises(retry_task_map.TasksEnsureFailed):
for _fut in executor.ensure_tasks(
Expand All @@ -115,6 +121,8 @@ def test_retry_finally_succeeded(self):
max_concurrent=MAX_CONCURRENT,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
backoff_factor=BACKOFF_FACTOR,
backoff_max=BACKOFF_MAX,
) as executor:
for _fut in executor.ensure_tasks(
self.workload_failed_and_then_succeed, range(TASKS_COUNT)
Expand All @@ -130,6 +138,8 @@ def test_succeeded_in_one_try(self):
max_concurrent=MAX_CONCURRENT,
initializer=_thread_initializer,
initargs=(THREAD_INIT_MSG,),
backoff_factor=BACKOFF_FACTOR,
backoff_max=BACKOFF_MAX,
) as executor:
for _fut in executor.ensure_tasks(
self.workload_succeed, range(TASKS_COUNT)
Expand Down

0 comments on commit b0c2f17

Please sign in to comment.