test_retry_task_map: tests now cover thread worker initializer
Bodong-Yang committed Jun 24, 2024
1 parent fd4f1af · commit 4da2b8e
Showing 1 changed file with 46 additions and 22 deletions.
68 changes: 46 additions & 22 deletions tests/test_otaclient_common/test_retry_task_map.py
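The change itself is mechanical: the per-class test constants move to module level, a module-level `_thread_initializer` is added, and every `ThreadPoolExecutorWithRetry(...)` call in the tests now also passes `initializer=_thread_initializer, initargs=(THREAD_INIT_MSG,)`, so each worker thread logs an init message before it starts picking up tasks. The `initializer`/`initargs` pair has the same shape as the hook on the standard library's `concurrent.futures.ThreadPoolExecutor`, which `ThreadPoolExecutorWithRetry` appears to extend. A minimal stdlib-only sketch of that pattern (nothing otaclient-specific; names and values below are illustrative) looks like this:

import logging
import threading
from concurrent.futures import ThreadPoolExecutor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _thread_initializer(msg: str) -> None:
    # Runs once in every worker thread the pool spawns, before any task executes there.
    logger.info("thread worker #%d initialized: %s", threading.get_native_id(), msg)


with ThreadPoolExecutor(
    max_workers=4,
    initializer=_thread_initializer,
    initargs=("thread init message",),
) as pool:
    # Each spawned worker logs its init message exactly once; tasks run afterwards.
    results = list(pool.map(lambda n: n * n, range(8)))

print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]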
@@ -17,6 +17,7 @@

import logging
import random
+import threading
import time

import pytest
@@ -25,39 +26,48 @@

logger = logging.getLogger(__name__)

+# ------ test setup ------ #
+WAIT_CONST = 100_000_000
+TASKS_COUNT = 2000
+MAX_CONCURRENT = 128
+MAX_WAIT_BEFORE_SUCCESS = 10
+THREAD_INIT_MSG = "thread init message"


class _RetryTaskMapTestErr(Exception):
    """"""


+def _thread_initializer(msg: str) -> None:
+    """For testing thread worker initializer."""
+    thread_native_id = threading.get_native_id()
+    logger.info(f"thread worker #{thread_native_id} initialized: {msg}")


class TestRetryTaskMap:
-    WAIT_CONST = 100_000_000
-    TASKS_COUNT = 2000
-    MAX_CONCURRENT = 128
-    MAX_WAIT_BEFORE_SUCCESS = 10

    @pytest.fixture(autouse=True)
    def setup(self):
        self._start_time = time.time()
        self._success_wait_dict = {
-            idx: random.randint(0, self.MAX_WAIT_BEFORE_SUCCESS)
-            for idx in range(self.TASKS_COUNT)
+            idx: random.randint(0, MAX_WAIT_BEFORE_SUCCESS)
+            for idx in range(TASKS_COUNT)
        }
-        self._succeeded_tasks = [False for _ in range(self.TASKS_COUNT)]
+        self._succeeded_tasks = [False for _ in range(TASKS_COUNT)]

    def workload_aways_failed(self, idx: int) -> int:
-        time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
+        time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
        raise _RetryTaskMapTestErr

    def workload_failed_and_then_succeed(self, idx: int) -> int:
-        time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
+        time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
        if time.time() > self._start_time + self._success_wait_dict[idx]:
            self._succeeded_tasks[idx] = True
            return idx
        raise _RetryTaskMapTestErr

    def workload_succeed(self, idx: int) -> int:
-        time.sleep((self.TASKS_COUNT - random.randint(0, idx)) / self.WAIT_CONST)
+        time.sleep((TASKS_COUNT - random.randint(0, idx)) / WAIT_CONST)
        self._succeeded_tasks[idx] = True
        return idx

@@ -69,48 +79,62 @@ def _exit_on_exceed_max_count():
                raise ValueError(f"{failure_count=} > {MAX_RETRY=}")

        with retry_task_map.ThreadPoolExecutorWithRetry(
-            max_concurrent=self.MAX_CONCURRENT,
+            max_concurrent=MAX_CONCURRENT,
            watchdog_func=_exit_on_exceed_max_count,
+            initializer=_thread_initializer,
+            initargs=(THREAD_INIT_MSG,),
        ) as executor:
            with pytest.raises(retry_task_map.TasksEnsureFailed):
                for _fut in executor.ensure_tasks(
-                    self.workload_aways_failed, range(self.TASKS_COUNT)
+                    self.workload_aways_failed, range(TASKS_COUNT)
                ):
                    if _fut.exception():
                        failure_count += 1

    def test_retry_exceed_retry_limit(self):
+        MAX_TOTAL_RETRY = 200
+        failure_count = 0
        with retry_task_map.ThreadPoolExecutorWithRetry(
-            max_concurrent=self.MAX_CONCURRENT, max_total_retry=200
+            max_concurrent=MAX_CONCURRENT,
+            max_total_retry=MAX_TOTAL_RETRY,
+            initializer=_thread_initializer,
+            initargs=(THREAD_INIT_MSG,),
        ) as executor:
            with pytest.raises(retry_task_map.TasksEnsureFailed):
-                for _ in executor.ensure_tasks(
-                    self.workload_aways_failed, range(self.TASKS_COUNT)
+                for _fut in executor.ensure_tasks(
+                    self.workload_aways_failed, range(TASKS_COUNT)
                ):
-                    pass
+                    if _fut.exception():
+                        failure_count += 1

+        assert failure_count >= MAX_TOTAL_RETRY

    def test_retry_finally_succeeded(self):
        count = 0
        with retry_task_map.ThreadPoolExecutorWithRetry(
-            max_concurrent=self.MAX_CONCURRENT
+            max_concurrent=MAX_CONCURRENT,
+            initializer=_thread_initializer,
+            initargs=(THREAD_INIT_MSG,),
        ) as executor:
            for _fut in executor.ensure_tasks(
-                self.workload_failed_and_then_succeed, range(self.TASKS_COUNT)
+                self.workload_failed_and_then_succeed, range(TASKS_COUNT)
            ):
                if not _fut.exception():
                    count += 1
        assert all(self._succeeded_tasks)
-        assert self.TASKS_COUNT == count
+        assert TASKS_COUNT == count

    def test_succeeded_in_one_try(self):
        count = 0
        with retry_task_map.ThreadPoolExecutorWithRetry(
-            max_concurrent=self.MAX_CONCURRENT
+            max_concurrent=MAX_CONCURRENT,
+            initializer=_thread_initializer,
+            initargs=(THREAD_INIT_MSG,),
        ) as executor:
            for _fut in executor.ensure_tasks(
-                self.workload_succeed, range(self.TASKS_COUNT)
+                self.workload_succeed, range(TASKS_COUNT)
            ):
                if not _fut.exception():
                    count += 1
        assert all(self._succeeded_tasks)
-        assert self.TASKS_COUNT == count
+        assert TASKS_COUNT == count
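Note that pytest captures the `logger.info(...)` output from `_thread_initializer` by default. To actually see the per-thread init messages when running this module locally, live logging has to be enabled, e.g. `pytest tests/test_otaclient_common/test_retry_task_map.py --log-cli-level=INFO` (assuming the repository's standard pytest setup).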
