From 41c6fb37a19c376b010aae13c22f1ac43cbd80fe Mon Sep 17 00:00:00 2001
From: Ivan Ogasawara
Date: Sat, 17 Aug 2024 23:59:04 -0400
Subject: [PATCH] feat: Add wrap-up decorator for managing celery tasks (#20)

---
 src/retsu/celery.py              | 82 +++++++++++++++++++++++++++-
 src/retsu/core.py                | 86 ++++++++++++++++++++++++++++-
 tests/celery_tasks.py            | 43 +++++++++++++--
 tests/conftest.py                | 58 +++++++++----------
 tests/test_task_celery_wrapup.py | 93 ++++++++++++++++++++++++++++++++
 5 files changed, 324 insertions(+), 38 deletions(-)
 create mode 100644 tests/test_task_celery_wrapup.py

diff --git a/src/retsu/celery.py b/src/retsu/celery.py
index 4d07b1b..689ca77 100644
--- a/src/retsu/celery.py
+++ b/src/retsu/celery.py
@@ -2,14 +2,25 @@

 from __future__ import annotations

-from typing import Any, Optional
+import logging
+import time
+import uuid
+
+from functools import wraps
+from typing import Any, Callable, Optional

 import celery
+import redis

 from celery import chain, chord, group
 from public import public

-from retsu.core import MultiProcess, SingleProcess
+from retsu.core import (
+    MultiProcess,
+    RandomSemaphoreManager,
+    SequenceSemaphoreManager,
+    SingleProcess,
+)


 class CeleryProcess:
@@ -121,3 +132,70 @@ class SingleCeleryProcess(CeleryProcess, SingleProcess):
     """Single Process for Celery."""

     ...
+
+
+def limit_random_concurrent_tasks(
+    max_concurrent_tasks: int,
+    redis_client: redis.Redis,
+) -> Callable[[Any], Any]:
+    """Limit the number of concurrent Celery tasks."""
+
+    def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+        semaphore_manager = RandomSemaphoreManager(
+            key=f"celery_task_semaphore_random_{func.__name__}",
+            max_concurrent_tasks=max_concurrent_tasks,
+            redis_client=redis_client,
+        )
+
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            # Acquire a semaphore slot, polling until one is free
+            if not semaphore_manager.acquire():
+                logging.info(f"Task {func.__name__} is waiting for a slot...")
+                while not semaphore_manager.acquire():
+                    time.sleep(0.01)  # Polling interval
+
+            try:
+                return func(*args, **kwargs)
+            finally:
+                # Release the semaphore slot
+                semaphore_manager.release()
+
+        return wrapper
+
+    return decorator
+
+
+def limit_sequence_concurrent_tasks(
+    max_concurrent_tasks: int,
+    redis_client: redis.Redis,
+) -> Callable[[Any], Any]:
+    """Limit the number of concurrent Celery tasks, preserving FIFO order."""
+
+    def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+        semaphore_manager = SequenceSemaphoreManager(
+            key=f"celery_task_semaphore_sequence_{func.__name__}",
+            max_concurrent_tasks=max_concurrent_tasks,
+            redis_client=redis_client,
+        )
+
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            task_id = str(uuid.uuid4())  # Unique identifier for each task
+
+            # `acquire` blocks until this task reaches the front of the
+            # FIFO queue and a slot is free, so it cannot return without
+            # a slot (the previous `if acquired:` guard could silently
+            # return None when it fell through)
+            semaphore_manager.acquire(task_id)
+            try:
+                return func(*args, **kwargs)
+            finally:
+                # Release the semaphore slot and dequeue this task
+                semaphore_manager.release(task_id)
+
+        return wrapper
+
+    return decorator
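Note on usage: the two decorators above sit beneath `@app.task`, so Celery registers the already-wrapped function and the limit is enforced inside the worker at execution time; because the counter lives in Redis, the cap applies across every worker sharing that Redis instance. A minimal sketch of how a caller would apply it (the app name, broker URL, and task body are illustrative, not part of this patch):

    import redis
    from celery import Celery

    from retsu.celery import limit_random_concurrent_tasks

    app = Celery("example", broker="redis://localhost:6379/0")
    redis_client = redis.Redis(host="localhost", port=6379)

    @app.task
    @limit_random_concurrent_tasks(
        max_concurrent_tasks=3, redis_client=redis_client
    )
    def fetch_report(report_id: int) -> int:
        # At most 3 invocations of this body run at once, cluster-wide.
        return report_id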
diff --git a/src/retsu/core.py b/src/retsu/core.py
index 8d47bf3..cc83c11 100644
--- a/src/retsu/core.py
+++ b/src/retsu/core.py
@@ -4,11 +4,12 @@

 import logging
 import multiprocessing as mp
+import time
 import warnings

 from abc import abstractmethod
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any, Optional, cast
 from uuid import uuid4

 import redis
@@ -177,3 +178,86 @@ def stop(self) -> None:
         for task_name, process in self.tasks.items():
             process.stop()
+
+
+class RandomSemaphoreManager:
+    """Manage a Redis-backed semaphore that limits concurrent tasks."""
+
+    def __init__(
+        self, key: str, max_concurrent_tasks: int, redis_client: redis.Redis
+    ):
+        self.key: str = key
+        self.max_concurrent_tasks: int = max_concurrent_tasks
+        self.redis_client: redis.Redis = redis_client
+
+    def acquire(self) -> bool:
+        """Try to acquire a semaphore slot."""
+        # INCR is atomic and treats a missing key as 0, so there is no
+        # read-then-write race between competing workers.
+        # note: the cast silences mypy, which types redis commands as
+        # "Union[Awaitable[Any], Any]"
+        current_count = cast(int, self.redis_client.incr(self.key))
+
+        if current_count <= self.max_concurrent_tasks:
+            return True
+
+        # Over the limit: roll back the increment and report failure
+        self.redis_client.decr(self.key)
+        return False
+
+    def release(self) -> None:
+        """Release a semaphore slot."""
+        self.redis_client.decr(self.key)
+
+
+class SequenceSemaphoreManager:
+    """Manage a Redis-backed semaphore that also enforces FIFO order."""
+
+    def __init__(
+        self, key: str, max_concurrent_tasks: int, redis_client: redis.Redis
+    ):
+        self.key: str = key
+        self.max_concurrent_tasks: int = max_concurrent_tasks
+        self.redis_client: redis.Redis = redis_client
+
+    def acquire(self, task_id: str) -> bool:
+        """Acquire a semaphore slot, blocking until FIFO order allows it."""
+        task_bid = task_id.encode("utf8")
+        queue_name = f"{self.key}_queue"
+
+        # Add this task to the end of the waiting queue
+        self.redis_client.rpush(queue_name, task_id)
+
+        while True:
+            # Get the list of current tasks in the queue
+            # mypy: Item "Awaitable[List[Any]]" of
+            # "Union[Awaitable[List[Any]], List[Any]]"
+            # has no attribute "index"
+            queue_tasks = self.redis_client.lrange(queue_name, 0, -1)
+            count_tmp = cast(bytes, self.redis_client.get(self.key) or b"0")
+            current_count = int(count_tmp)
+
+            # Run only when this task is among the first
+            # `max_concurrent_tasks` entries of the queue and a slot
+            # is free
+            task_position = queue_tasks.index(task_bid)  # type: ignore
+
+            if (
+                task_position < self.max_concurrent_tasks
+                and current_count < self.max_concurrent_tasks
+            ):
+                self.redis_client.incr(self.key)
+                return True
+
+            # No slot available yet, or this task is still too far back
+            # in the queue: keep waiting
+            time.sleep(0.1)
+
+    def release(self, task_id: str) -> None:
+        """Release a semaphore slot and remove the task from the queue."""
+        self.redis_client.decr(self.key)
+        # Remove this task's own entry instead of popping the head:
+        # the task at the front of the queue may still be running.
+        self.redis_client.lrem(f"{self.key}_queue", 1, task_id)
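The two managers can also be exercised directly, which makes the acquire/release contract easy to see: `RandomSemaphoreManager.acquire` is non-blocking and returns False when all slots are taken, while `SequenceSemaphoreManager.acquire` blocks until its caller reaches the front of the queue. A quick non-Celery sketch (connection details and the key name are illustrative):

    import redis

    from retsu.core import RandomSemaphoreManager

    r = redis.Redis(host="localhost", port=6379)
    sem = RandomSemaphoreManager(
        key="demo_semaphore", max_concurrent_tasks=1, redis_client=r
    )

    assert sem.acquire() is True   # the only slot is taken
    assert sem.acquire() is False  # a second caller is turned away
    sem.release()                  # free the slot
    assert sem.acquire() is True   # it can be taken again
    sem.release()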
diff --git a/tests/celery_tasks.py b/tests/celery_tasks.py
index f5ce612..a6213df 100644
--- a/tests/celery_tasks.py
+++ b/tests/celery_tasks.py
@@ -4,6 +4,7 @@

 import os
 import sys
+import time

 from datetime import datetime
 from time import sleep
@@ -11,6 +12,10 @@
 import redis

 from celery import Celery
+from retsu.celery import (
+    limit_random_concurrent_tasks,
+    limit_sequence_concurrent_tasks,
+)

 redis_host: str = os.getenv("RETSU_REDIS_HOST", "localhost")
 redis_port: int = int(os.getenv("RETSU_REDIS_PORT", 6379))
@@ -33,10 +38,10 @@
     worker_task_log_format=(
         f"{LOG_FORMAT_PREFIX} %(task_name)s[%(task_id)s]: %(message)s"
     ),
-    task_annotations={"*": {"rate_limit": "10/s"}},
+    # task_annotations={"*": {"rate_limit": "10/s"}},
     task_track_started=True,
-    task_time_limit=30 * 60,
-    task_soft_time_limit=30 * 60,
+    # task_time_limit=30 * 60,
+    # task_soft_time_limit=30 * 60,
     worker_redirect_stdouts_level="DEBUG",
 )
@@ -68,3 +73,35 @@ def task_sleep(seconds: int, task_id: str) -> int:
     """Sleep for the given number of seconds and return a timestamp."""
     sleep(seconds)
     return int(datetime.now().timestamp())
+
+
+@app.task  # type: ignore
+@limit_random_concurrent_tasks(
+    max_concurrent_tasks=2, redis_client=redis_client
+)
+def task_random_get_time(
+    request_id: int, start_time: float
+) -> tuple[int, float]:
+    """Simple task capped at a maximum number of concurrent runs."""
+    print(
+        f"[Random] Started task {request_id} after:",
+        time.time() - start_time,
+    )
+    sleep(1)
+    return request_id, time.time()
+
+
+@app.task  # type: ignore
+@limit_sequence_concurrent_tasks(
+    max_concurrent_tasks=2, redis_client=redis_client
+)
+def task_sequence_get_time(
+    request_id: int, start_time: float
+) -> tuple[int, float]:
+    """Simple task capped at a maximum number of concurrent runs, FIFO."""
+    print(
+        f"[Sequence] Started task {request_id} after:",
+        time.time() - start_time,
+    )
+    sleep(1)
+    return request_id, time.time()

diff --git a/tests/conftest.py b/tests/conftest.py
index 194602e..aa000bb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,56 +2,50 @@

 from __future__ import annotations

-import subprocess
-import time
+import logging

-from typing import Generator
+from typing import Any, Generator

 import pytest
 import redis

+from celery.contrib.testing.worker import start_worker
+
 from retsu.queues import get_redis_queue_config

+from tests.celery_tasks import app as celery_app
+

 def redis_flush() -> None:
     """Wipe out the Redis database."""
+    logging.info("Wiping out the Redis database.")
     r = redis.Redis(**get_redis_queue_config())  # type: ignore
     r.flushdb()


+@pytest.fixture(scope="session")
+def celery_worker_parameters() -> dict[str, Any]:
+    """Parameters for the embedded Celery worker."""
+    return {
+        "loglevel": "debug",  # log level for the embedded worker
+        "concurrency": 4,  # number of worker processes
+        "perform_ping_check": False,
+        "pool": "prefork",
+    }
+
+
 @pytest.fixture(autouse=True, scope="session")
-def setup() -> Generator[None, None, None]:
+def setup(
+    celery_worker_parameters: dict[str, Any],
+) -> Generator[None, None, None]:
     """Set up the services needed by the tests."""
-    try:
-        # # Run the `sugar build` command
-        # subprocess.run(["sugar", "build"], check=True)
-        # # Run the `sugar ext restart --options -d` command
-        # subprocess.run(
-        #     ["sugar", "ext", "restart", "--options", "-d"], check=True
-        # )
-        # # Sleep for 5 seconds
-        # time.sleep(5)
-
-        # Clean Redis queues
-        redis_flush()
+    logging.info("Clean Redis queues")
+    redis_flush()

-        # Start the Celery worker
-        celery_process = subprocess.Popen(
-            [
-                "celery",
-                "-A",
-                "tests.celery_tasks",
-                "worker",
-                "--loglevel=debug",
-            ],
-        )
-
-        time.sleep(5)
-
-        yield
-    finally:
-        # Teardown: Terminate the Celery worker
-        celery_process.terminate()
-        celery_process.wait()
-        # subprocess.run(["sugar", "ext", "stop"], check=True)
+    logging.info("Start the Celery worker")
+    with start_worker(celery_app, **celery_worker_parameters) as worker:
+        # The embedded worker is available to every test in the session
+        # and is shut down automatically when the block exits.
+        yield worker
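For context, `start_worker` from `celery.contrib.testing.worker` runs an embedded worker in a background thread of the test process, which replaces the previous `subprocess.Popen` call and its fixed 5-second startup sleep; `perform_ping_check=False` is set because the default ping check expects the `celery.ping` task from `celery.contrib.testing.tasks` to be registered. A standalone sketch of the same pattern (the imports are illustrative):

    from celery.contrib.testing.worker import start_worker

    from tests.celery_tasks import app, task_sleep

    # Run an embedded worker for the duration of the block; no external
    # `celery worker` process is needed.
    with start_worker(app, concurrency=4, pool="prefork",
                      perform_ping_check=False):
        assert task_sleep.delay(1, task_id="demo").get(timeout=10) > 0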
diff --git a/tests/test_task_celery_wrapup.py b/tests/test_task_celery_wrapup.py
new file mode 100644
index 0000000..34c2ce5
--- /dev/null
+++ b/tests/test_task_celery_wrapup.py
@@ -0,0 +1,93 @@
+"""Test celery tasks with a wrap-up decorator."""
+
+from __future__ import annotations
+
+import time
+
+from celery.result import AsyncResult
+
+from .celery_tasks import (
+    task_random_get_time,
+    task_sequence_get_time,
+)
+
+
+def test_task_random_get_time() -> None:
+    """Test task_random_get_time."""
+    results: dict[int, float] = {}
+    # apply_async returns AsyncResult objects, not Task instances
+    tasks: list[AsyncResult] = []
+    start_time = time.time()
+
+    for i in range(10):
+        task_promise = task_random_get_time.s(
+            request_id=i, start_time=start_time
+        )
+        tasks.append(task_promise.apply_async())
+
+    for i in range(10):
+        task = tasks[i]
+        task_id, result = task.get(timeout=10)
+        assert i == task_id
+        results[task_id] = result
+
+    # With 10 one-second tasks and 2 concurrent slots, the whole batch
+    # finishes within roughly 5 seconds, so any two consecutive
+    # completion times differ by less than that.
+    previous_time = results[0]
+    previous_id = 0
+    tol = 0.2
+    for i in range(10):
+        current_time = results[i]
+        diff = abs(current_time - previous_time)
+        assert diff < 5 + tol, f"[EE] Task {previous_id}-{i}"
+        previous_time = current_time
+        previous_id = i
+
+
+def test_task_sequence_get_time() -> None:
+    """Test task_sequence_get_time."""
+    results: dict[int, float] = {}
+    tasks: list[AsyncResult] = []
+    start_time = time.time()
+
+    for i in range(10):
+        task_promise = task_sequence_get_time.s(
+            request_id=i, start_time=start_time
+        )
+        tasks.append(task_promise.apply_async())
+
+    for i in range(10):
+        task = tasks[i]
+        task_id, result = task.get(timeout=10)
+        assert i == task_id
+        results[task_id] = result
+
+    # Tasks run in FIFO pairs (2 slots), so consecutive completion
+    # times differ by ~0s within a pair and ~1s across pairs.
+    diffs = (
+        0,  # 0 -> 0
+        0,  # 0 -> 1
+        1,  # 1 -> 2
+        0,  # 2 -> 3
+        1,  # 3 -> 4
+        0,  # 4 -> 5
+        1,  # 5 -> 6
+        0,  # 6 -> 7
+        1,  # 7 -> 8
+        0,  # 8 -> 9
+    )
+
+    previous_time = results[0]
+    previous_id = 0
+    tol = 0.2
+    for i in range(10):
+        current_time = results[i]
+        diff = current_time - previous_time
+        diff_expected = diffs[i]
+        assert diff >= diff_expected - tol, f"[EE] Task {previous_id}-{i}"
+        assert diff <= diff_expected + tol, f"[EE] Task {previous_id}-{i}"
+        previous_time = current_time
+        previous_id = i
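The `diffs` table in the sequence test follows from simple arithmetic: with 2 slots and 1-second tasks, task i finishes about (i // 2) + 1 seconds after the batch starts, so consecutive completion times differ by ~0s within a pair and ~1s across pair boundaries. A small self-check of that reasoning (the helper name is ours, not part of the patch):

    def expected_diffs(n: int = 10, slots: int = 2) -> tuple[int, ...]:
        # finish[i] is the finish time, in seconds, of task i
        finish = [(i // slots) + 1 for i in range(n)]
        return tuple(finish[i] - finish[max(i - 1, 0)] for i in range(n))

    assert expected_diffs() == (0, 0, 1, 0, 1, 0, 1, 0, 1, 0)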