Skip to content

Commit

Permalink
Adds a dead-letter queue to memory subscriptions to avoid infinite lo…
Browse files Browse the repository at this point in the history
…ops (#16051)
  • Loading branch information
desertaxle authored Nov 19, 2024
1 parent 743c4ee commit 39b6028
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 9 deletions.
95 changes: 87 additions & 8 deletions src/prefect/server/utilities/messaging/memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import copy
from contextlib import asynccontextmanager
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from datetime import timedelta
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
Expand All @@ -12,15 +14,19 @@
TypeVar,
Union,
)
from uuid import uuid4

import anyio
from cachetools import TTLCache
from pydantic_core import to_json
from typing_extensions import Self

from prefect.logging import get_logger
from prefect.server.utilities.messaging import Cache as _Cache
from prefect.server.utilities.messaging import Consumer as _Consumer
from prefect.server.utilities.messaging import Message, MessageHandler, StopConsumer
from prefect.server.utilities.messaging import Publisher as _Publisher
from prefect.settings.context import get_current_settings

logger = get_logger(__name__)

Expand All @@ -29,29 +35,101 @@
class MemoryMessage:
    """
    A message held in memory, carrying a payload and its attributes.
    """

    # The message payload.
    data: Union[bytes, str]
    # Metadata attached to the message, keyed by attribute name.
    attributes: Dict[str, Any]
    # Incremented by `Subscription.retry` each time the message is retried.
    retry_count: int = 0


class Subscription:
    """
    A subscription to a topic.

    Messages are delivered to the subscription's queue and retried up to a
    maximum number of times. If a message cannot be delivered after the maximum
    number of retries it is moved to the dead letter queue.

    The dead letter queue is a directory of JSON files containing the serialized
    message.

    Messages remain in the dead letter queue until they are removed manually.

    Attributes:
        topic: The topic that the subscription receives messages from.
        max_retries: The maximum number of times a message will be retried for
            this subscription.
        dead_letter_queue_path: The path to the dead letter queue folder.
    """

    topic: "Topic"
    _queue: asyncio.Queue
    _retry: asyncio.Queue

    def __init__(
        self,
        topic: "Topic",
        max_retries: int = 3,
        dead_letter_queue_path: Union[Path, str, None] = None,
    ) -> None:
        self.topic = topic
        self.max_retries = max_retries
        # Fall back to a "dlq" folder under the configured home directory when
        # no explicit path is given.
        self.dead_letter_queue_path = (
            Path(dead_letter_queue_path)
            if dead_letter_queue_path
            else get_current_settings().home / "dlq"
        )
        self._queue = asyncio.Queue()
        self._retry = asyncio.Queue()

    async def deliver(self, message: MemoryMessage) -> None:
        """
        Deliver a message to the subscription's queue.

        Args:
            message: The message to deliver.
        """
        await self._queue.put(message)

    async def retry(self, message: MemoryMessage) -> None:
        """
        Place a message back on the retry queue.

        If the message has retried more than the maximum number of times it is
        moved to the dead letter queue.

        Args:
            message: The message to retry.
        """
        message.retry_count += 1
        if message.retry_count > self.max_retries:
            logger.warning(
                "Message failed after %d retries and will be moved to the dead letter queue",
                message.retry_count,
                extra={"event_message": message},
            )
            await self.send_to_dead_letter_queue(message)
        else:
            await self._retry.put(message)

    async def get(self) -> MemoryMessage:
        """
        Get a message from the subscription's queue.

        Messages waiting on the retry queue are returned before new messages.
        """
        if not self._retry.empty():
            return await self._retry.get()
        return await self._queue.get()

    async def send_to_dead_letter_queue(self, message: MemoryMessage) -> None:
        """
        Send a message to the dead letter queue.

        The dead letter queue is a directory of JSON files containing the
        serialized messages. Failures to create the directory or to write the
        file are logged as warnings rather than raised.

        Args:
            message: The message to send to the dead letter queue.
        """
        # The path is read at call time so that callers may repoint
        # `dead_letter_queue_path` after the subscription has been created.
        dead_letter_dir = anyio.Path(self.dead_letter_queue_path)
        try:
            # Create the folder through anyio so the event loop is not blocked,
            # and inside the `try` so a failure here is handled like a failed
            # write instead of escaping from `retry`.
            await dead_letter_dir.mkdir(parents=True, exist_ok=True)
            await (dead_letter_dir / uuid4().hex).write_bytes(
                to_json(asdict(message))
            )
        except Exception as e:
            logger.warning("Failed to write message to dead letter queue", exc_info=e)


class Topic:
_topics: Dict[str, "Topic"] = {}
Expand Down Expand Up @@ -93,7 +171,8 @@ def clear(self):

async def publish(self, message: MemoryMessage) -> None:
    """
    Publish a message to every subscription on this topic.

    Args:
        message: The message to deliver to each subscription.
    """
    for subscription in self._subscriptions:
        # Ensure that each subscription gets its own copy of the message so
        # that per-subscription mutation (such as the retry count) is not shared
        await subscription.deliver(copy.deepcopy(message))


@asynccontextmanager
Expand Down
49 changes: 48 additions & 1 deletion tests/server/services/test_task_run_recorder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from datetime import timedelta
from itertools import permutations
from pathlib import Path
from typing import AsyncGenerator
from uuid import UUID

Expand All @@ -18,7 +19,7 @@
from prefect.server.schemas.core import FlowRun, TaskRunPolicy
from prefect.server.schemas.states import StateDetails, StateType
from prefect.server.services import task_run_recorder
from prefect.server.utilities.messaging import MessageHandler
from prefect.server.utilities.messaging import MessageHandler, create_publisher
from prefect.server.utilities.messaging.memory import MemoryMessage


Expand All @@ -29,6 +30,7 @@ async def test_start_and_stop_service():

await service.started_event.wait()
assert service.consumer_task is not None
assert service.consumer is not None

await service.stop()
assert service.consumer_task is None
Expand Down Expand Up @@ -753,3 +755,48 @@ async def test_task_run_recorder_handles_all_out_of_order_permutations(

state_types = set(state.type for state in states)
assert state_types == {StateType.PENDING, StateType.RUNNING, StateType.COMPLETED}


async def test_task_run_recorder_sends_repeated_failed_messages_to_dead_letter(
    pending_event: ReceivedEvent,
    tmp_path: Path,
):
    """
    Test to ensure situations like the one described in https://github.com/PrefectHQ/prefect/issues/15607
    don't overwhelm the task run recorder.
    """
    pending_transition_time = pendulum.datetime(2024, 1, 1, 0, 0, 0, 0, "UTC")
    assert pending_event.occurred == pending_transition_time

    service = task_run_recorder.TaskRunRecorder()

    service_task = asyncio.create_task(service.start())
    try:
        await service.started_event.wait()
        dead_letter_queue_path = tmp_path / "dlq"
        service.consumer.subscription.dead_letter_queue_path = dead_letter_queue_path

        async with create_publisher("events") as publisher:
            await publisher.publish_data(
                message(pending_event).data, message(pending_event).attributes
            )
            # Sending a task run event with the same task run id and timestamp but
            # a different id will raise an error when trying to insert it into the
            # database
            duplicate_pending_event = pending_event.model_copy()
            duplicate_pending_event.id = UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
            await publisher.publish_data(
                message(duplicate_pending_event).data,
                message(duplicate_pending_event).attributes,
            )

        async def wait_for_dead_letter() -> None:
            while not any(dead_letter_queue_path.glob("*")):
                await asyncio.sleep(0.1)

        # Bound the wait so a regression fails the test instead of hanging it
        await asyncio.wait_for(wait_for_dead_letter(), timeout=30)

        assert len(list(dead_letter_queue_path.glob("*"))) == 1
    finally:
        # Always stop the service, even when an assertion above fails
        service_task.cancel()
        try:
            await service_task
        except asyncio.CancelledError:
            pass
57 changes: 57 additions & 0 deletions tests/server/utilities/test_messaging.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import importlib
import json
from pathlib import Path
from typing import (
AsyncContextManager,
AsyncGenerator,
Expand All @@ -24,6 +26,12 @@
create_publisher,
ephemeral_subscription,
)
from prefect.server.utilities.messaging.memory import (
Consumer as MemoryConsumer,
)
from prefect.server.utilities.messaging.memory import (
MemoryMessage,
)
from prefect.settings import (
PREFECT_MESSAGING_BROKER,
PREFECT_MESSAGING_CACHE,
Expand Down Expand Up @@ -441,3 +449,52 @@ async def handler(message: Message):
# TODO: is there a way we can test that ephemeral subscriptions really have cleaned
# up after themselves after they have exited? This will differ significantly by
# each broker implementation, so it's hard to write a generic test.


async def test_repeatedly_failed_message_is_moved_to_dead_letter_queue(
    deduplicating_publisher: Publisher,
    consumer: MemoryConsumer,
    tmp_path: Path,
):
    """
    A message whose handler always fails is retried and then written to the
    dead letter queue instead of being redelivered forever.
    """
    captured_messages: List[Message] = []

    async def handler(message: Message):
        captured_messages.append(message)
        raise ValueError("Simulated failure")

    dead_letter_queue_path = tmp_path / "dlq"
    consumer.subscription.dead_letter_queue_path = dead_letter_queue_path

    async def wait_for_dead_letter() -> None:
        while not any(dead_letter_queue_path.glob("*")):
            await asyncio.sleep(0.1)

    consumer_task = asyncio.create_task(consumer.run(handler))
    try:
        async with deduplicating_publisher as p:
            await p.publish_data(
                b"hello, world", {"howdy": "partner", "my-message-id": "A"}
            )

        # Bound the wait so a regression fails the test instead of hanging it
        await asyncio.wait_for(wait_for_dead_letter(), timeout=30)
    finally:
        # Always stop the consumer, even if publishing or waiting failed
        consumer_task.cancel()
        try:
            await consumer_task
        except asyncio.CancelledError:
            pass

    # Message should have been moved to DLQ after multiple retries
    assert len(captured_messages) == 4  # Original attempt + 3 retries
    for message in captured_messages:
        assert message.data == b"hello, world"
        assert message.attributes == {"howdy": "partner", "my-message-id": "A"}

    # Verify message is in DLQ
    dlq_files = list(dead_letter_queue_path.glob("*"))
    assert len(dlq_files) == 1
    dlq_message = MemoryMessage(**json.loads(dlq_files[0].read_text()))
    assert dlq_message.data == "hello, world"
    assert dlq_message.attributes == {"howdy": "partner", "my-message-id": "A"}
    assert dlq_message.retry_count > 3

    remaining_message = await drain_one(consumer)
    assert not remaining_message

0 comments on commit 39b6028

Please sign in to comment.