From 7ce37d12990029d2d2d02446b4fd5fc7040205f0 Mon Sep 17 00:00:00 2001
From: Ivan Ogasawara
Date: Sat, 17 Aug 2024 14:41:22 -0400
Subject: [PATCH] feat: Add initial support for async processes (#19)

---
 poetry.lock                     |  33 +++++-
 pyproject.toml                  |   5 +-
 src/retsu/asyncio/__init__.py   |   1 +
 src/retsu/asyncio/celery.py     |  90 +++++++++++++++
 src/retsu/asyncio/core.py       | 107 ++++++++++++++++++
 src/retsu/asyncio/queues.py     |  60 ++++++++++
 src/retsu/asyncio/results.py    | 192 ++++++++++++++++++++++++++++++++
 src/retsu/queues.py             |  11 +-
 tests/test_task_celery_async.py |  99 ++++++++++++++++
 9 files changed, 595 insertions(+), 3 deletions(-)
 create mode 100644 src/retsu/asyncio/__init__.py
 create mode 100644 src/retsu/asyncio/celery.py
 create mode 100644 src/retsu/asyncio/core.py
 create mode 100644 src/retsu/asyncio/queues.py
 create mode 100644 src/retsu/asyncio/results.py
 create mode 100644 tests/test_task_celery_async.py

diff --git a/poetry.lock b/poetry.lock
index 23bca34..8c409be 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -68,6 +68,19 @@ files = [
     {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
 ]
 
+[[package]]
+name = "asyncio"
+version = "3.4.3"
+description = "reference implementation of PEP 3156"
+optional = false
+python-versions = "*"
+files = [
+    {file = "asyncio-3.4.3-cp33-none-win32.whl", hash = "sha256:b62c9157d36187eca799c378e572c969f0da87cd5fc42ca372d92cdb06e7e1de"},
+    {file = "asyncio-3.4.3-cp33-none-win_amd64.whl", hash = "sha256:c46a87b48213d7464f22d9a497b9eef8c1928b68320a2fa94240f969f6fec08c"},
+    {file = "asyncio-3.4.3-py3-none-any.whl", hash = "sha256:c4d18b22701821de07bd6aea8b53d21449ec0ec5680645e5317062ea21817d2d"},
+    {file = "asyncio-3.4.3.tar.gz", hash = "sha256:83360ff8bc97980e4ff25c964c7bd3923d333d177aa4f7fb736b019f26c7cb41"},
+]
+
 [[package]]
 name = "atpublic"
 version = "4.1.0"
@@ -2229,6 +2242,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
 
 [package.extras]
 dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
 
+[[package]]
+name = "pytest-asyncio"
+version = "0.23.8"
+description = "Pytest support for asyncio"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pytest_asyncio-0.23.8-py3-none-any.whl", hash = "sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2"},
+    {file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"},
+]
+
+[package.dependencies]
+pytest = ">=7.0.0,<9"
+
+[package.extras]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
+testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
+
 [[package]]
 name = "pytest-cov"
 version = "5.0.0"
@@ -3319,4 +3350,4 @@ django = ["django"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4"
-content-hash = "f15de669df8a6ae75d497e5471d60371341bb1f49b133aa677bd7d1ec084529a"
+content-hash = "d763e68645a083c84c5f5126172135fd276966de311e66b38407d70c68f7f7d2"
diff --git a/pyproject.toml b/pyproject.toml
index 45fb9e1..f9d24a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ celery = ">=5"
 redis = ">=5"
 django = { version = ">=3", optional = true }
 typing-extensions = ">=4.12.0"
+asyncio = ">=3.4.3"
 
 [tool.poetry.extras]
 django = [
@@ -63,8 +64,10 @@ containers-sugar = "1.13.0"
 compose-go = "2.27.0"
 django = ">=3"
 django-stubs = ">=3"
+pytest-asyncio = ">=0.23.8"
 
 [tool.pytest.ini_options]
+# asyncio_mode = "auto"
 testpaths = [
   "tests",
 ]
@@ -117,4 +120,4 @@ ignore_missing_imports = true
 warn_unused_ignores = true
 warn_redundant_casts = true
 warn_unused_configs = true
-exclude = ["examples/", "scripts/"]
+exclude = ["examples/", "scripts/", "src/retsu/asyncio/"]
"tests", ] @@ -117,4 +120,4 @@ ignore_missing_imports = true warn_unused_ignores = true warn_redundant_casts = true warn_unused_configs = true -exclude = ["examples/", "scripts/"] +exclude = ["examples/", "scripts/", "src/retsu/asyncio/"] diff --git a/src/retsu/asyncio/__init__.py b/src/retsu/asyncio/__init__.py new file mode 100644 index 0000000..5be292a --- /dev/null +++ b/src/retsu/asyncio/__init__.py @@ -0,0 +1 @@ +"""Async Retsu package.""" diff --git a/src/retsu/asyncio/celery.py b/src/retsu/asyncio/celery.py new file mode 100644 index 0000000..092ad95 --- /dev/null +++ b/src/retsu/asyncio/celery.py @@ -0,0 +1,90 @@ +"""Retsu tasks with celery.""" + +from __future__ import annotations + +from typing import Any, Optional + +import celery + +from celery import chain, chord, group +from public import public + +from retsu.asyncio.core import AsyncProcess + + +@public +class CeleryAsyncProcess(AsyncProcess): + """Async Celery Process class.""" + + async def process(self, *args, task_id: str, **kwargs) -> Any: + """Define the async process to be executed.""" + chord_tasks, chord_callback = await self.get_chord_tasks( + *args, + task_id=task_id, + **kwargs, + ) + group_tasks = await self.get_group_tasks( + *args, + task_id=task_id, + **kwargs, + ) + chain_tasks = await self.get_chain_tasks( + *args, + task_id=task_id, + **kwargs, + ) + + # Start the tasks asynchronously + results = [] + if chord_tasks: + workflow_chord = chord(chord_tasks, chord_callback) + promise_chord = workflow_chord.apply_async() + results.extend(promise_chord.get()) + + if group_tasks: + workflow_group = group(group_tasks) + promise_group = workflow_group.apply_async() + results.extend(promise_group.get()) + + if chain_tasks: + workflow_chain = chain(chain_tasks) + promise_chain = workflow_chain.apply_async() + results.extend(promise_chain.get()) + + return results + + async def get_chord_tasks( # type: ignore + self, *args, **kwargs + ) -> tuple[list[celery.Signature], Optional[celery.Signature]]: + """ + Run tasks with chord. + + Return + ------ + tuple: + list of tasks for the chord, and the task to be used as a callback + """ + chord_tasks: list[celery.Signature] = [] + callback_task = None + return (chord_tasks, callback_task) + + async def get_group_tasks( # type: ignore + self, *args, **kwargs + ) -> list[celery.Signature]: + """ + Run tasks with group. 
diff --git a/src/retsu/asyncio/core.py b/src/retsu/asyncio/core.py
new file mode 100644
index 0000000..751bc2b
--- /dev/null
+++ b/src/retsu/asyncio/core.py
@@ -0,0 +1,107 @@
+"""Async core module."""
+
+import asyncio
+import logging
+
+from abc import abstractmethod
+from datetime import datetime
+from typing import Any
+from uuid import uuid4
+
+from redis import asyncio as aioredis
+
+from retsu.asyncio.queues import RedisRetsuAsyncQueue
+from retsu.asyncio.results import (
+    ResultProcessManagerAsync,
+    create_result_task_manager_async,
+)
+from retsu.core import Process
+from retsu.queues import get_redis_queue_config
+
+
+class AsyncProcess(Process):
+    """Main class for handling an async process."""
+
+    def __init__(self, workers: int = 1) -> None:
+        """Initialize an async process object."""
+        _klass = self.__class__
+        queue_in_name = f"{_klass.__module__}.{_klass.__qualname__}"
+
+        self._client = aioredis.Redis(**get_redis_queue_config())
+        self.active = True
+        self.workers = workers
+        self.result: ResultProcessManagerAsync = (
+            create_result_task_manager_async()
+        )
+        self.queue_in = RedisRetsuAsyncQueue(queue_in_name)
+        self.tasks = []
+
+    async def start(self) -> None:
+        """Start async tasks."""
+        logging.info(f"Starting async process {self.__class__.__name__}")
+        for _ in range(self.workers):
+            task = asyncio.create_task(self.run())
+            self.tasks.append(task)
+
+    async def stop(self) -> None:
+        """Stop async tasks."""
+        logging.info(f"Stopping async process {self.__class__.__name__}")
+        self.active = False
+        for task in self.tasks:
+            task.cancel()
+            try:
+                # Ensure the task is properly awaited before moving on
+                await task
+            except asyncio.CancelledError:
+                logging.info(f"Task {task.get_name()} has been cancelled.")
+
+    async def request(self, *args, **kwargs) -> str:  # type: ignore
+        """Feed the queue with data from the request for the process."""
+        task_id = uuid4().hex
+        metadata = {
+            "status": "starting",
+            "created_at": datetime.now().isoformat(),
+            "updated_at": datetime.now().isoformat(),
+        }
+        await self.result.create(task_id, metadata)  # Ensure this is awaited
+        await self.queue_in.put(
+            {
+                "task_id": task_id,
+                "args": args,
+                "kwargs": kwargs,
+            }
+        )
+        return task_id
+
+    @abstractmethod
+    async def process(self, *args, task_id: str, **kwargs) -> Any:  # type: ignore
+        """Define the async process to be executed."""
+        raise NotImplementedError("`process` not implemented yet.")
+
+    async def prepare_process(self, data: dict[str, Any]) -> None:
+        """Call the process with the necessary arguments."""
+        task_id = data.pop("task_id")
+        await self.result.metadata.update(task_id, "status", "running")
+        result = await self.process(
+            *data["args"],
+            task_id=task_id,
+            **data["kwargs"],
+        )
+        await self.result.save(task_id, result)
+        await self.result.metadata.update(task_id, "status", "completed")
+
+    async def run(self) -> None:
+        """Run the async process with data from the queue."""
+        while self.active:
+            try:
+                data = await self.queue_in.get()
+                await self.prepare_process(data)
+            except asyncio.CancelledError:
+                logging.info(
+                    f"Task {asyncio.current_task().get_name()} cancelled."
+                )
+                break  # Break out of the loop if the task is canceled
+            except Exception as e:
+                logging.error(f"Error in process: {e}")
+                break  # Break out of the loop on any other exceptions
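The intended lifecycle mirrors the fixtures in the test module at the end of this patch. A sketch, reusing the hypothetical `SumProcess` from the previous example and assuming Redis and a Celery worker are reachable:

import asyncio


async def main() -> None:
    process = SumProcess(workers=2)
    await process.start()  # spawns `workers` consumer tasks over the queue
    task_id = await process.request(x=1, y=2)  # enqueues and returns the id
    result = await process.result.get(task_id, timeout=10)  # e.g. [3]
    await process.stop()  # cancels and awaits the consumer tasks


asyncio.run(main())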
diff --git a/src/retsu/asyncio/queues.py b/src/retsu/asyncio/queues.py
new file mode 100644
index 0000000..0ffcf44
--- /dev/null
+++ b/src/retsu/asyncio/queues.py
@@ -0,0 +1,60 @@
+"""Functions for handling queues and their configurations."""
+
+from __future__ import annotations
+
+import asyncio
+import pickle
+
+from abc import abstractmethod
+from typing import Any
+
+from public import public
+from redis import asyncio as aioredis
+
+from retsu.queues import BaseRetsuQueue, get_redis_queue_config
+
+
+@public
+class BaseRetsuAsyncQueue(BaseRetsuQueue):
+    """Base class for async queues."""
+
+    def __init__(self, name: str) -> None:
+        """Initialize BaseRetsuAsyncQueue."""
+        self.name = name
+
+    @abstractmethod
+    async def put(self, data: Any) -> None:
+        """Put data into the end of the queue."""
+        ...
+
+    @abstractmethod
+    async def get(self) -> Any:
+        """Get the next data from the queue."""
+        ...
+
+
+@public
+class RedisRetsuAsyncQueue(BaseRetsuAsyncQueue):
+    """Async RedisRetsuQueue class."""
+
+    def __init__(self, name: str) -> None:
+        """Initialize RedisRetsuAsyncQueue with an async Redis client."""
+        super().__init__(name)
+        self._client = aioredis.Redis(
+            **get_redis_queue_config(),  # Async Redis client configuration
+            decode_responses=False,
+        )
+
+    async def put(self, data: Any) -> None:
+        """Put data into the end of the queue asynchronously."""
+        await self._client.rpush(self.name, pickle.dumps(data))
+
+    async def get(self) -> Any:
+        """Get the next data from the queue asynchronously."""
+        while True:
+            data = await self._client.lpop(self.name)
+            if data is None:
+                await asyncio.sleep(0.1)  # Non-blocking sleep for 100ms
+                continue
+
+            return pickle.loads(data)
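Using the queue itself is a plain pickle round trip through Redis. A sketch, assuming a Redis instance reachable with the settings from `get_redis_queue_config()` and a hypothetical queue name:

import asyncio

from retsu.asyncio.queues import RedisRetsuAsyncQueue


async def demo() -> None:
    queue = RedisRetsuAsyncQueue("demo-queue")
    await queue.put({"task_id": "abc123", "args": (), "kwargs": {}})
    data = await queue.get()  # polls lpop every 100ms until data arrives
    assert data["task_id"] == "abc123"


asyncio.run(demo())

Redis's server-side blocking `blpop` could replace the 100 ms polling loop; the polling form presumably keeps consumer cancellation straightforward.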
diff --git a/src/retsu/asyncio/results.py b/src/retsu/asyncio/results.py
new file mode 100644
index 0000000..50b7c43
--- /dev/null
+++ b/src/retsu/asyncio/results.py
@@ -0,0 +1,192 @@
+"""Retsu results classes with async support."""
+
+from __future__ import annotations
+
+import asyncio
+import pickle
+
+from datetime import datetime
+from typing import Any, Callable, Optional, cast
+
+try:
+    # Python 3.11+
+    from typing import Unpack  # type: ignore[attr-defined]
+except ImportError:
+    # < Python 3.11
+    from typing_extensions import Unpack
+
+from public import public
+from redis import asyncio as aioredis  # Import the asyncio version of redis
+
+from retsu.queues import get_redis_queue_config
+
+
+class ProcessMetadataManagerAsync:
+    """Manage process metadata asynchronously."""
+
+    def __init__(self, client: aioredis.Redis):
+        """Initialize ProcessMetadataManagerAsync."""
+        self.client = client
+        self.step = StepMetadataManagerAsync(self.client)
+
+    async def get_all(self, task_id: str) -> dict[str, bytes]:
+        """Get the entire metadata for a given process asynchronously."""
+        result = await self.client.hgetall(f"process:{task_id}:metadata")
+        return cast(dict[str, bytes], result)
+
+    async def get(self, task_id: str, attribute: str) -> bytes:
+        """Get a specific metadata attr for a given process asynchronously."""
+        result = await self.client.hget(
+            f"process:{task_id}:metadata", attribute
+        )
+        return cast(bytes, result)
+
+    async def create(self, task_id: str, metadata: dict[str, Any]) -> None:
+        """Create an initial metadata for given process asynchronously."""
+        await self.client.hset(
+            f"process:{task_id}:metadata", mapping=metadata
+        )
+
+    async def update(self, task_id: str, attribute: str, value: Any) -> None:
+        """Update the value of given attr for a given process in async."""
+        await self.client.hset(
+            f"process:{task_id}:metadata", attribute, value
+        )
+        await self.client.hset(
+            f"process:{task_id}:metadata",
+            "updated_at",
+            datetime.now().isoformat(),
+        )
+
+
+class StepMetadataManagerAsync:
+    """Manage metadata for steps of a process asynchronously."""
+
+    def __init__(self, redis_client: aioredis.Redis):
+        """Initialize StepMetadataManagerAsync."""
+        self.client = redis_client
+
+    async def get_all(self, task_id: str, step_id: str) -> dict[str, bytes]:
+        """Get the metadata for a given process and step asynchronously."""
+        result = await self.client.hgetall(
+            f"process:{task_id}:step:{step_id}"
+        )
+        return cast(dict[str, bytes], result)
+
+    async def get(self, task_id: str, step_id: str, attribute: str) -> bytes:
+        """Get the value of a given attr for a given process+step in async."""
+        result = await self.client.hget(
+            f"process:{task_id}:step:{step_id}", attribute
+        )
+        return cast(bytes, result)
+
+    async def create(
+        self, task_id: str, step_id: str, metadata: dict[str, Any]
+    ) -> None:
+        """Create an initial metadata for given process+step in async."""
+        await self.client.hset(
+            f"process:{task_id}:step:{step_id}", mapping=metadata
+        )
+
+    async def update(
+        self, task_id: str, step_id: str, attribute: str, value: Any
+    ) -> None:
+        """Update the value of given attr for a given process+step async."""
+        if attribute == "status" and value not in ["started", "completed"]:
+            raise Exception("Status should be started or completed.")
+
+        await self.client.hset(
+            f"process:{task_id}:step:{step_id}", attribute, value
+        )
+        await self.client.hset(
+            f"process:{task_id}:step:{step_id}",
+            "updated_at",
+            datetime.now().isoformat(),
+        )
+
+
" + f"Process status: {status}" + ) + result = await self.metadata.get(task_id, "result") + return pickle.loads(result) if result else result + + async def load(self, task_id: str) -> dict[str, Any]: + """Load the whole metadata for a given process asynchronously.""" + return await self.metadata.get_all(task_id) + + async def create(self, task_id: str, metadata: dict[str, Any]) -> None: + """Create a new metadata for a given process asynchronously.""" + await self.metadata.create(task_id, metadata) + + async def save(self, task_id: str, result: Any) -> None: + """Save the result for a given process asynchronously.""" + await self.metadata.update(task_id, "result", pickle.dumps(result)) + + async def status(self, task_id: str) -> str: + """Get the status for a given process asynchronously.""" + status = await self.metadata.get(task_id, "status") + return status.decode("utf8") + + +@public +def create_result_task_manager_async() -> ResultProcessManagerAsync: + """Create a ResultProcessManagerAsync from the environment vars.""" + return ResultProcessManagerAsync(**get_redis_queue_config()) # type: ignore + + +@public +def track_step_async( + task_metadata: ProcessMetadataManagerAsync, +) -> Callable[..., Any]: + """Decorate a function with ProcessMetadataManagerAsync.""" + + def decorator(task_func: Callable[..., Any]) -> Callable[..., Any]: + """Return a decorator for the given process.""" + + async def wrapper( + *args: Unpack[Any], **kwargs: Unpack[dict[str, Any]] + ) -> Any: + """Wrap a function for registering the process metadata async.""" + task_id = kwargs["task_id"] + step_id = kwargs.get("step_id", task_func.__name__) + + step_metadata = task_metadata.step + + await step_metadata.update(task_id, step_id, "status", "started") + result = await task_func(*args, **kwargs) + await step_metadata.update(task_id, step_id, "status", "completed") + result_pickled = pickle.dumps(result) + await step_metadata.update( + task_id, step_id, "result", result_pickled + ) + return result + + return wrapper + + return decorator diff --git a/src/retsu/queues.py b/src/retsu/queues.py index 2d03f83..5f10f22 100644 --- a/src/retsu/queues.py +++ b/src/retsu/queues.py @@ -32,6 +32,15 @@ def __init__(self, name: str) -> None: """Initialize BaseRetsuQueue.""" self.name = name + +@public +class BaseRetsuRegularQueue(BaseRetsuQueue): + """Base Queue class.""" + + def __init__(self, name: str) -> None: + """Initialize BaseRetsuQueue.""" + self.name = name + @abstractmethod def put(self, data: Any) -> None: """Put data into the end of the queue.""" @@ -44,7 +53,7 @@ def get(self) -> Any: @public -class RedisRetsuQueue(BaseRetsuQueue): +class RedisRetsuQueue(BaseRetsuRegularQueue): """RedisRetsuQueue class.""" def __init__(self, name: str) -> None: diff --git a/tests/test_task_celery_async.py b/tests/test_task_celery_async.py new file mode 100644 index 0000000..fb47db6 --- /dev/null +++ b/tests/test_task_celery_async.py @@ -0,0 +1,99 @@ +"""Tests for retsu package.""" + +from __future__ import annotations + +import celery +import pytest + +from retsu.asyncio.celery import CeleryAsyncProcess +from retsu.asyncio.core import AsyncProcess + +from .celery_tasks import task_sleep, task_sum + +# note: async process is not fully implemented +pytest.skip(allow_module_level=True) + + +class MyResultTask(CeleryAsyncProcess): + """Async Process for the test.""" + + async def get_group_tasks(self, *args, **kwargs) -> list[celery.Signature]: + """Define the list of tasks for celery chord.""" + x = kwargs.get("x") + y = 
diff --git a/src/retsu/queues.py b/src/retsu/queues.py
index 2d03f83..5f10f22 100644
--- a/src/retsu/queues.py
+++ b/src/retsu/queues.py
@@ -32,6 +32,15 @@ def __init__(self, name: str) -> None:
         """Initialize BaseRetsuQueue."""
         self.name = name
 
+
+@public
+class BaseRetsuRegularQueue(BaseRetsuQueue):
+    """Base class for regular (synchronous) queues."""
+
+    def __init__(self, name: str) -> None:
+        """Initialize BaseRetsuRegularQueue."""
+        self.name = name
+
     @abstractmethod
     def put(self, data: Any) -> None:
         """Put data into the end of the queue."""
@@ -44,7 +53,7 @@ def get(self) -> Any:
         """Get the next data from the queue."""
 
 @public
-class RedisRetsuQueue(BaseRetsuQueue):
+class RedisRetsuQueue(BaseRetsuRegularQueue):
     """RedisRetsuQueue class."""
 
     def __init__(self, name: str) -> None:
diff --git a/tests/test_task_celery_async.py b/tests/test_task_celery_async.py
new file mode 100644
index 0000000..fb47db6
--- /dev/null
+++ b/tests/test_task_celery_async.py
@@ -0,0 +1,99 @@
+"""Tests for retsu package."""
+
+from __future__ import annotations
+
+import celery
+import pytest
+
+from retsu.asyncio.celery import CeleryAsyncProcess
+from retsu.asyncio.core import AsyncProcess
+
+from .celery_tasks import task_sleep, task_sum
+
+pytest.skip(
+    "async process is not fully implemented", allow_module_level=True
+)
+
+
+class MyResultTask(CeleryAsyncProcess):
+    """Async Process for the test."""
+
+    async def get_group_tasks(self, *args, **kwargs) -> list[celery.Signature]:
+        """Define the list of tasks for celery group."""
+        x = kwargs.get("x")
+        y = kwargs.get("y")
+        task_id = kwargs.get("task_id")
+        return [task_sum.s(x, y, task_id)]
+
+
+class MyTimestampTask(CeleryAsyncProcess):
+    """Async Process for the test."""
+
+    async def get_group_tasks(self, *args, **kwargs) -> list[celery.Signature]:
+        """Define the list of tasks for celery group."""
+        seconds = kwargs.get("seconds")
+        task_id = kwargs.get("task_id")
+        return [task_sleep.s(seconds, task_id)]
+
+
+@pytest.fixture
+async def task_result() -> AsyncProcess:
+    """Create a fixture for MyResultTask."""
+    process = MyResultTask(workers=2)
+    await process.start()
+    yield process
+    await process.stop()
+
+
+@pytest.fixture
+async def task_timestamp() -> AsyncProcess:
+    """Create a fixture for MyTimestampTask."""
+    process = MyTimestampTask(workers=5)
+    await process.start()
+    yield process
+    await process.stop()
+
+
+@pytest.mark.asyncio
+class TestMultiCeleryAsyncProcess:
+    """TestMultiCeleryAsyncProcess."""
+
+    async def test_multi_async_result(self, task_result: AsyncProcess) -> None:
+        """Run simple test for a multi-process."""
+        results: dict[str, int] = {}
+        process = task_result  # the fixture yields the process instance
+
+        for i in range(10):
+            task_id = await process.request(x=i, y=i)
+            results[task_id] = i + i
+
+        for task_id, expected in results.items():
+            result = await process.result.get(task_id, timeout=10)
+            assert (
+                result[0] == expected
+            ), f"Expected Result: {expected}, Actual Result: {result}"
+
+    async def test_multi_async_timestamp(
+        self, task_timestamp: AsyncProcess
+    ) -> None:
+        """Run simple test for a multi-process."""
+        results: list[tuple[str, int]] = []
+        process = task_timestamp  # the fixture yields the process instance
+
+        for sleep_time in range(5, 1, -1):
+            task_id = await process.request(seconds=sleep_time * 1.5)
+            results.append((task_id, 0))
+
+        # Gather results
+        for i, (task_id, _) in enumerate(results):
+            results[i] = (
+                task_id,
+                await process.result.get(task_id, timeout=10),
+            )
+        # Check results
+        previous_timestamp = results[0][1]
+        for _, current_timestamp in results[1:]:
+            assert current_timestamp < previous_timestamp, (
+                f"Previous timestamp: {previous_timestamp}, "
+                f"Current timestamp: {current_timestamp}"
+            )
+            previous_timestamp = current_timestamp