Skip to content

Commit

Permalink
fix initial issues with the celery tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab committed May 31, 2024
1 parent 670e365 commit a562b48
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 29 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ jobs:
poetry install
- name: start up services
run: makim tests.setup
run: |
sugar build
sugar ext restart --options -d
- name: Run tests
run: makim tests.unit
Expand All @@ -64,3 +66,7 @@ jobs:
run: |
pre-commit install
makim tests.linter
- name: teardown services
run: |
sugar ext stop
15 changes: 9 additions & 6 deletions src/retsu/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from __future__ import annotations

from typing import Optional
from typing import Any, Optional

import celery

from celery import chain, chord
from celery import chain, chord, group
from public import public

from retsu.core import ParallelTask, SerialTask
Expand All @@ -15,7 +15,7 @@
class CeleryTask:
"""Celery Task class."""

def task(self, *args, task_id: str, **kwargs) -> None: # type: ignore
def task(self, *args, task_id: str, **kwargs) -> Any: # type: ignore
"""Define the task to be executed."""
chord_tasks, chord_callback = self.get_chord_tasks(
*args, task_id=task_id, **kwargs
Expand All @@ -27,19 +27,22 @@ def task(self, *args, task_id: str, **kwargs) -> None: # type: ignore
if chord_callback:
workflow_chord = chord(chord_tasks, chord_callback)
else:
workflow_chord = chord(chord_tasks)
workflow_chord = group(chord_tasks)
promise_chord = workflow_chord.apply_async()

if chain_tasks:
workflow_chain = chain(chord_tasks)
promise_chain = workflow_chain.apply_async()

# wait for the tasks
results: list[Any] = []
if chord_tasks:
promise_chord.get()
results.extend(promise_chord.get())

if chain_tasks:
promise_chain.get()
results.append(promise_chain.get())

return results

def get_chord_tasks( # type: ignore
self, *args, **kwargs
Expand Down
2 changes: 1 addition & 1 deletion src/retsu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def task(self, *args, task_id: str, **kwargs) -> Any: # type: ignore
"""Define the task to be executed."""
raise Exception("`task` not implemented yet.")

def prepare_task(self, data: Any) -> None:
def prepare_task(self, data: dict[str, Any]) -> None:
"""Call the task with the necessary arguments."""
task_id = data.pop("task_id")
self.result.metadata.update(task_id, "status", "running")
Expand Down
22 changes: 11 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,14 @@
def setup() -> Generator[None, None, None]:
"""Set up the services needed by the tests."""
try:
# Run the `sugar build` command
subprocess.run(["sugar", "build"], check=True)

# Run the `sugar ext restart --options -d` command
subprocess.run(
["sugar", "ext", "restart", "--options", "-d"], check=True
)

# Sleep for 5 seconds
time.sleep(5)
# # Run the `sugar build` command
# subprocess.run(["sugar", "build"], check=True)
# # Run the `sugar ext restart --options -d` command
# subprocess.run(
# ["sugar", "ext", "restart", "--options", "-d"], check=True
# )
# # Sleep for 5 seconds
# time.sleep(5)

# Change directory to `tests/`
os.chdir("tests/")
Expand All @@ -34,6 +32,8 @@ def setup() -> Generator[None, None, None]:
["celery", "-A", "celery_tasks", "worker", "--loglevel=debug"]
)

time.sleep(5)

# Change directory back to the original
os.chdir("..")

Expand All @@ -43,4 +43,4 @@ def setup() -> Generator[None, None, None]:
# Teardown: Terminate the Celery worker
celery_process.terminate()
celery_process.wait()
subprocess.run(["sugar", "ext", "stop"], check=True)
# subprocess.run(["sugar", "ext", "stop"], check=True)
26 changes: 16 additions & 10 deletions tests/test_task_celery_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@

from __future__ import annotations

from typing import Generator
from typing import Generator, Optional

import celery
import pytest

from retsu import Task
from retsu.celery import SerialCeleryTask

from .celery_tasks import task_sum
from .celery_tasks import task_sleep, task_sum


class MyResultTask(SerialCeleryTask):
"""Task for the test."""

def get_chord_tasks(self, *args, **kwargs) -> list[celery.Signature]:
def get_chord_tasks( # type: ignore
self, *args, **kwargs
) -> tuple[list[celery.Signature], Optional[celery.Signature]]:
"""Define the list of tasks for celery chord."""
x = kwargs.get("x")
y = kwargs.get("y")
Expand All @@ -28,12 +32,14 @@ def get_chord_tasks(self, *args, **kwargs) -> list[celery.Signature]:
class MyTimestampTask(SerialCeleryTask):
"""Task for the test."""

def get_chord_tasks(self, *args, **kwargs) -> list[celery.Signature]:
def get_chord_tasks( # type: ignore
self, *args, **kwargs
) -> tuple[list[celery.Signature], Optional[celery.Signature]]:
"""Define the list of tasks for celery chord."""
seconds = kwargs.get("seconds")
task_id = kwargs.get("task_id")
return (
[task_sum.s(x, y, task_id, task_id)],
[task_sleep.s(seconds, task_id, task_id)],
None,
)

Expand All @@ -56,8 +62,8 @@ def task_timestamp() -> Generator[Task, None, None]:
task.stop()


class TestSerialTask:
"""TestSerialTask."""
class TestSerialCeleryTask:
"""TestSerialCeleryTask."""

def test_serial_result(self, task_result: Task) -> None:
"""Run simple test for a serial task."""
Expand All @@ -66,11 +72,11 @@ def test_serial_result(self, task_result: Task) -> None:
task = task_result

for i in range(10):
task_id = task.request(a=i, b=i)
task_id = task.request(x=i, y=i)
results[task_id] = i + i

for task_id, expected in results.items():
result = task.result.get(task_id, timeout=2)
result = task.result.get(task_id, timeout=10)
assert (
result == expected
), f"Expected Result: {expected}, Actual Result: {result}"
Expand All @@ -82,7 +88,7 @@ def test_serial_timestamp(self, task_timestamp: Task) -> None:
task = task_timestamp

for sleep_time in range(5, 1, -1):
task_id = task.request(sleep=sleep_time)
task_id = task.request(seconds=sleep_time)
results.append((task_id, 0))

# gather results
Expand Down

0 comments on commit a562b48

Please sign in to comment.