Skip to content

Commit

Permalink
Added on_send_error middleware hook
Browse files Browse the repository at this point in the history
  • Loading branch information
hawang-wish committed Aug 29, 2024
1 parent 739be3c commit bb84f2b
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 3 deletions.
21 changes: 20 additions & 1 deletion taskiq/abc/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

if TYPE_CHECKING: # pragma: no cover # pragma: no cover
from taskiq.abc.broker import AsyncBroker
from taskiq.message import TaskiqMessage
from taskiq.message import BrokerMessage, TaskiqMessage
from taskiq.result import TaskiqResult


Expand Down Expand Up @@ -126,3 +126,22 @@ def on_error(
:param result: returned value.
:param exception: found exception.
"""

def on_send_error(
self,
message: "TaskiqMessage",
broker_message: "BrokerMessage",
exception: BaseException,
) -> "Union[Union[bool, None], Coroutine[Any, Any, Union[bool, None]]]":
"""
This function is called when exception is raised while sending a message.
In most cases, it would be a connection issue from the broker.
Any exceptions occurred by broker's formatter will not trigger this.
:param message: the sending TaskiqMessage (not BrokerMessage)
:param broker_message: the sending BrokerMessage (not TaskiqMessage)
:param exception: exception, not yet wrapped with SendTaskError
:return: True if the error should be omitted, False or None otherwise.
"""
22 changes: 20 additions & 2 deletions taskiq/kicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,28 @@ async def kiq(
for middleware in self.broker.middlewares:
if middleware.__class__.pre_send != TaskiqMiddleware.pre_send:
message = await maybe_awaitable(middleware.pre_send(message))

broker_message = self.broker.formatter.dumps(message)
try:
await self.broker.kick(self.broker.formatter.dumps(message))
await self.broker.kick(broker_message)
except Exception as exc:
raise SendTaskError from exc
omitting = False
for middleware in reversed(self.broker.middlewares):
if middleware.__class__.on_send_error != TaskiqMiddleware.on_send_error:
omitting = (
bool(
await maybe_awaitable(
middleware.on_send_error(
message,
broker_message,
exc,
),
),
)
or omitting
)
if not omitting:
raise SendTaskError from exc

for middleware in self.broker.middlewares:
if middleware.__class__.post_send != TaskiqMiddleware.post_send:
Expand Down
96 changes: 96 additions & 0 deletions tests/middlewares/test_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import asyncio

import pytest

from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.exceptions import SendTaskError
from taskiq.message import BrokerMessage, TaskiqMessage


@pytest.mark.anyio
async def test_on_send_error() -> None:
caught = []

class _TestMiddleware(TaskiqMiddleware):
def on_send_error(
self,
message: "TaskiqMessage",
broker_message: "BrokerMessage",
exception: BaseException,
) -> bool:
caught.append(1)
return True

broker = InMemoryBroker().with_middlewares(_TestMiddleware())

broker.kick = lambda *args, **kwargs: (_ for _ in ()).throw(Exception("test")) # type: ignore

await broker.startup()
await broker.task(lambda: None).kiq()
await broker.shutdown()

assert caught == [1]


@pytest.mark.anyio
async def test_on_send_error_raise() -> None:
caught = []

class _TestMiddleware(TaskiqMiddleware):
def on_send_error(
self,
message: "TaskiqMessage",
broker_message: "BrokerMessage",
exception: BaseException,
) -> None:
caught.append(0)

broker = InMemoryBroker().with_middlewares(_TestMiddleware())

broker.kick = lambda *args, **kwargs: (_ for _ in ()).throw(Exception("test")) # type: ignore

await broker.startup()

with pytest.raises(SendTaskError):
await broker.task(lambda: None).kiq()

await broker.shutdown()

assert caught == [0]


@pytest.mark.anyio
async def test_on_send_error_inverted() -> None:
caught = []

class _TestMiddleware1(TaskiqMiddleware):
def on_send_error(
self,
message: "TaskiqMessage",
broker_message: "BrokerMessage",
exception: BaseException,
) -> bool:
caught.append(1)
return True

class _TestMiddleware2(TaskiqMiddleware):
async def on_send_error(
self,
message: "TaskiqMessage",
broker_message: "BrokerMessage",
exception: BaseException,
) -> bool:
await asyncio.sleep(0)
caught.append(2)
return True

broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2())

broker.kick = lambda *args, **kwargs: (_ for _ in ()).throw(Exception("test")) # type: ignore

await broker.startup()
await broker.task(lambda: None).kiq()
await broker.shutdown()

assert caught == [2, 1]

0 comments on commit bb84f2b

Please sign in to comment.