diff --git a/taskiq/message.py b/taskiq/message.py index 675f7cf..787cb25 100644 --- a/taskiq/message.py +++ b/taskiq/message.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from pydantic import BaseModel @@ -20,6 +20,7 @@ class TaskiqMessage(BaseModel): labels_types: Optional[Dict[str, int]] = None args: List[Any] kwargs: Dict[str, Any] + ack: Optional[Callable[[], Union[None, Awaitable[None]]]] = None def parse_labels(self) -> None: """ diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index c9b1d66..adef89f 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -101,9 +101,15 @@ async def callback( # noqa: C901, PLR0912 :param raise_err: raise an error if cannot save result in result_backend. """ - message_data = message.data if isinstance(message, AckableMessage) else message + message_data, message_ack = ( + (message.data, message.ack) + if isinstance(message, AckableMessage) + else (message, None) + ) try: taskiq_msg = self.broker.formatter.loads(message=message_data) + if message_ack: + taskiq_msg.ack = message_ack taskiq_msg.parse_labels() except Exception as exc: logger.warning( @@ -143,7 +149,7 @@ async def callback( # noqa: C901, PLR0912 message, AckableMessage, ): - await maybe_awaitable(message.ack()) + await maybe_awaitable(taskiq_msg.ack()) # type: ignore result = await self.run_task( target=task.original_func, @@ -154,7 +160,7 @@ async def callback( # noqa: C901, PLR0912 message, AckableMessage, ): - await maybe_awaitable(message.ack()) + await maybe_awaitable(taskiq_msg.ack()) # type: ignore for middleware in self.broker.middlewares: if middleware.__class__.post_execute != TaskiqMiddleware.post_execute: @@ -181,7 +187,7 @@ async def callback( # noqa: C901, PLR0912 message, AckableMessage, ): - await maybe_awaitable(message.ack()) + await maybe_awaitable(taskiq_msg.ack()) # type: ignore async def run_task( # noqa: C901, PLR0912, PLR0915 self, diff --git a/tests/formatters/test_json_formatter.py b/tests/formatters/test_json_formatter.py index 17a3718..0823205 100644 --- a/tests/formatters/test_json_formatter.py +++ b/tests/formatters/test_json_formatter.py @@ -23,7 +23,8 @@ async def test_json_dumps() -> None: b'{"task_id":"task-id","task_name":"task.name",' b'"labels":{"label1":1,"label2":"text"},' b'"labels_types":null,' - b'"args":[1,"a"],"kwargs":{"p1":"v1"}}' + b'"args":[1,"a"],"kwargs":{"p1":"v1"},' + b'"ack":null}' ), labels={"label1": 1, "label2": "text"}, ) diff --git a/tests/formatters/test_proxy_formatter.py b/tests/formatters/test_proxy_formatter.py index 8d583f1..48ce6c2 100644 --- a/tests/formatters/test_proxy_formatter.py +++ b/tests/formatters/test_proxy_formatter.py @@ -22,7 +22,8 @@ async def test_proxy_dumps() -> None: b'{"task_id": "task-id", "task_name": "task.name", ' b'"labels": {"label1": 1, "label2": "text"}, ' b'"labels_types": null, ' - b'"args": [1, "a"], "kwargs": {"p1": "v1"}}' + b'"args": [1, "a"], "kwargs": {"p1": "v1"}, ' + b'"ack": null}' ), labels={"label1": 1, "label2": "text"}, )