diff --git a/example/tasks/config.py b/example/tasks/config.py index 6f06094..311e5eb 100644 --- a/example/tasks/config.py +++ b/example/tasks/config.py @@ -1,6 +1,7 @@ """Configuration for Celery app.""" import os +import sys import redis @@ -47,4 +48,4 @@ print("Redis connection is working.") except redis.ConnectionError as e: print(f"Failed to connect to Redis: {e}") - exit(1) + sys.exit(1) diff --git a/src/retsu/celery.py b/src/retsu/celery.py index 8a6e90b..1508f02 100644 --- a/src/retsu/celery.py +++ b/src/retsu/celery.py @@ -7,6 +7,7 @@ import celery from celery import chain, chord +from public import public from retsu.core import ParallelTask, SerialTask @@ -63,12 +64,14 @@ def get_chain_tasks( # type: ignore return chain_tasks +@public class ParallelCeleryTask(CeleryTask, ParallelTask): """Parallel Task for Celery.""" ... +@public class SerialCeleryTask(CeleryTask, SerialTask): """Serial Task for Celery.""" diff --git a/tests/celery_tasks.py b/tests/celery_tasks.py index 15a8a25..fb41d9a 100644 --- a/tests/celery_tasks.py +++ b/tests/celery_tasks.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +import sys from datetime import datetime from time import sleep @@ -52,7 +53,7 @@ print("Redis connection is working.") except redis.ConnectionError as e: print(f"Failed to connect to Redis: {e}") - exit(1) + sys.exit(1) @app.task # type: ignore diff --git a/tests/test_task_celery_serial.py b/tests/test_task_celery_serial.py index b211439..509602a 100644 --- a/tests/test_task_celery_serial.py +++ b/tests/test_task_celery_serial.py @@ -2,34 +2,40 @@ from __future__ import annotations -from datetime import datetime -from time import sleep -from typing import Any, Generator +from typing import Generator import pytest -from retsu import SerialTask, Task +from retsu.celery import SerialCeleryTask +from .celery_tasks import task_sum -class MyResultTask(SerialTask): + +class MyResultTask(SerialCeleryTask): """Task for the test.""" - def task(self, *args, task_id: str, **kwargs) -> Any: # type: ignore - """Return the sum of the given 2 numbers.""" - a = kwargs.pop("a", 0) - b = kwargs.pop("b", 0) - result = a + b - return result + def get_chord_tasks(self, *args, **kwargs) -> list[celery.Signature]: + """Define the list of tasks for celery chord.""" + x = kwargs.get("x") + y = kwargs.get("y") + task_id = kwargs.get("task_id") + return ( + [task_sum.s(x, y, task_id)], + None, + ) -class MyTimestampTask(SerialTask): +class MyTimestampTask(SerialCeleryTask): """Task for the test.""" - def task(self, *args, task_id: str, **kwargs) -> Any: # type: ignore - """Sleep the given seconds, and return the current timestamp.""" - sleep_time = kwargs.pop("sleep", 0) - sleep(sleep_time) - return datetime.now().timestamp() + def get_chord_tasks(self, *args, **kwargs) -> list[celery.Signature]: + """Define the list of tasks for celery chord.""" + seconds = kwargs.get("seconds") + task_id = kwargs.get("task_id") + return ( + [task_sum.s(x, y, task_id, task_id)], + None, + ) @pytest.fixture