Skip to content

Commit

Permalink
fix(dlq): Make InvalidMessage pickleable (#284)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhars authored Sep 13, 2023
1 parent 55e6f79 commit 5b7df38
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 0 deletions.
3 changes: 3 additions & 0 deletions arroyo/dlq.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __eq__(self, other: Any) -> bool:
and self.needs_commit == other.needs_commit
)

def __reduce__(self) -> Tuple[Any, Tuple[Any, ...]]:
return self.__class__, (self.partition, self.offset, self.needs_commit)


@dataclass(frozen=True)
class DlqLimit:
Expand Down
24 changes: 24 additions & 0 deletions tests/processing/strategies/test_run_task_with_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from arroyo.backends.kafka import KafkaPayload
from arroyo.dlq import InvalidMessage
from arroyo.processing.strategies import MessageRejected
from arroyo.processing.strategies.run_task_with_multiprocessing import (
MessageBatch,
Expand Down Expand Up @@ -631,3 +632,26 @@ def test_output_block_resizing_without_limits() -> None:
)
in TestingMetricsBackend.calls
)


def message_processor_raising_invalid_message(x: Message[KafkaPayload]) -> KafkaPayload:
raise InvalidMessage(Partition(topic=Topic("test_topic"), index=0), offset=1000)


def test_multiprocessing_with_invalid_message() -> None:
next_step = Mock()

strategy = RunTaskWithMultiprocessing(
message_processor_raising_invalid_message,
next_step,
num_processes=2,
max_batch_size=1,
max_batch_time=60,
)

strategy.submit(Message(Value(KafkaPayload(None, b"x" * 10, []), {})))

strategy.poll()
strategy.close()
with pytest.raises(InvalidMessage):
strategy.join(timeout=3)
8 changes: 8 additions & 0 deletions tests/test_dlq.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from datetime import datetime
from typing import Generator
from unittest.mock import ANY
Expand Down Expand Up @@ -109,3 +110,10 @@ def test_dlq_policy_wrapper() -> None:
)
wrapper.produce(message)
wrapper.flush({partition: 11})


def test_invalid_message_pickleable() -> None:
exc = InvalidMessage(Partition(Topic("test_topic"), 0), 2)
pickled_exc = pickle.dumps(exc)
unpickled_exc = pickle.loads(pickled_exc)
assert exc == unpickled_exc

0 comments on commit 5b7df38

Please sign in to comment.