diff --git a/consumer_example.py b/consumer_example.py index 57bb721..52a2ffb 100644 --- a/consumer_example.py +++ b/consumer_example.py @@ -1,13 +1,69 @@ -from message_flow import MessageFlow, Message, Payload, Header - +from types import TracebackType +from typing import Type +from message_flow import MessageFlow, Message, Payload, Header, BaseMiddleware +from message_flow.utils import logger class OrderCreated(Message): order_id: str = Payload() tenant_id: str = Header() +class MockMiddleware(BaseMiddleware): + def on_consume(self) -> None: + self.headers.update({"a": 123}) + print("h") + return super().on_consume() + + def after_consume(self, error: Exception | None = None) -> None: + print("11") + return super().after_consume(error) + + +class MockMiddleware1(BaseMiddleware): + def on_consume(self) -> None: + self.headers.update({"a": 123}) + print("h1") + return super().on_consume() + + def after_consume(self, error: Exception | None = None) -> None: + print("1122") + return super().after_consume(error) + + def on_produce(self) -> None: + print("produce") + return super().on_produce() + + def after_produce(self, error: Exception | None = None) -> None: + print("after produce") + return super().after_produce(error) + + +class CustomMiddleware(BaseMiddleware): + def on_consume(self) -> None: + logger.info("Message with %s payload and %s headers received.", self.payload, self.headers) + return super().on_consume() + + def after_consume(self, error: Exception | None = None) -> None: + logger.info("Message with %s payload and %s headers processed.", self.payload, self.headers) + return super().after_consume(error) + + def on_produce(self) -> None: + logger.info("Producing message with %s payload and %s headers.", self.payload, self.headers) + return super().on_produce() + + def after_produce(self, error: Exception | None = None) -> None: + logger.info("Message with %s payload and %s headers produced.", self.payload, self.headers) + return super().after_produce(error) + + app = MessageFlow() +app.add_middleware(MockMiddleware) +app.add_middleware(MockMiddleware1) +app.add_middleware(CustomMiddleware) @app.subscribe(address="orders", message=OrderCreated) def order_created_handler(order_created: OrderCreated) -> None: print("Event received", order_created.order_id, order_created.tenant_id) + return order_created + +app.dispatch() \ No newline at end of file diff --git a/producer_example.py b/producer_example.py index 5c8c7bf..6445e41 100644 --- a/producer_example.py +++ b/producer_example.py @@ -9,5 +9,5 @@ class OrderCreated(Message): if __name__ == "__main__": app = MessageFlow() - app.publish(OrderCreated(order_id="order_id", tenant_id="tenant_id"), channel_address="orders") + app.send(OrderCreated(order_id="order_id", tenant_id="tenant_id"), channel_address="orders", reply_to_address="orders_") \ No newline at end of file diff --git a/src/message_flow/app/__init__.pyi b/src/message_flow/app/__init__.pyi index f225ce0..db28e46 100644 --- a/src/message_flow/app/__init__.pyi +++ b/src/message_flow/app/__init__.pyi @@ -1,2 +1,3 @@ +from .base_middleware import * from .message_flow import * from .messaging import * diff --git a/src/message_flow/app/_message_management/dispatcher.py b/src/message_flow/app/_message_management/dispatcher.py index 3e824eb..694cbe3 100644 --- a/src/message_flow/app/_message_management/dispatcher.py +++ b/src/message_flow/app/_message_management/dispatcher.py @@ -1,8 +1,10 @@ import logging +from contextlib import ExitStack from typing import final from ...utils import internal from .._internal import Channels +from ..base_middleware import BaseMiddleware from ..messaging import MessageConsumer from .producer import Producer from .routing_headers import RoutingHeaders @@ -12,13 +14,18 @@ @internal class Dispatcher: def __init__( - self, channels: Channels, message_consumer: MessageConsumer, producer: Producer, logger: logging.Logger + self, + channels: Channels, + message_consumer: MessageConsumer, + producer: Producer, + logger: logging.Logger, ) -> None: self._logger = logger self._channels = channels self._message_consumer = message_consumer self._producer = producer + self._middlewares: list[type[BaseMiddleware]] = [] def initialize(self) -> None: self._logger.debug("Initializing dispatcher") @@ -28,11 +35,27 @@ def initialize(self) -> None: ) self._logger.debug("Initialized dispatcher") + def add_middleware(self, middleware: type[BaseMiddleware]) -> None: + self._middlewares.append(middleware) + def message_handler(self, payload: bytes, headers: dict[str, str]) -> None: if ( handler := self._channels.operation_of(headers[RoutingHeaders.ADDRESS], headers[RoutingHeaders.TYPE]) ) is None: return - if (message := handler(handler.message.from_payload_and_headers(payload, headers))) is not None: - self._producer.send(headers[RoutingHeaders.REPLY_TO], message) + with ExitStack() as dispatcher_stack: + self._execute_consume_middlewares(dispatcher_stack, payload, headers) + + if (message := handler(handler.message.from_payload_and_headers(payload, headers))) is not None: + self._execute_produce_middlewares(dispatcher_stack, message.payload, message.headers) + + self._producer.send(headers[RoutingHeaders.REPLY_TO], message) + + def _execute_consume_middlewares(self, stack: ExitStack, payload: bytes, headers: dict[str, str]) -> None: + for middleware in self._middlewares: + stack.enter_context(middleware(payload, headers).consume()) + + def _execute_produce_middlewares(self, stack: ExitStack, payload: bytes, headers: dict[str, str]) -> None: + for middleware in self._middlewares: + stack.enter_context(middleware(payload, headers).produce()) diff --git a/src/message_flow/app/base_middleware.py b/src/message_flow/app/base_middleware.py new file mode 100644 index 0000000..9fed2d3 --- /dev/null +++ b/src/message_flow/app/base_middleware.py @@ -0,0 +1,55 @@ +from contextlib import contextmanager +from typing import Generator + +from ..utils import external + + +@external +class BaseMiddleware: + def __init__(self, payload: bytes, headers: dict[str, str]) -> None: + self.payload = payload + self.headers = headers + + def on_consume(self) -> None: + pass + + def after_consume( + self, + error: Exception | None = None, + ) -> None: + if error is not None: + raise error + + @contextmanager + def consume(self) -> Generator[None, None, None]: + consume_error: Exception | None = None + + try: + self.on_consume() + yield + except Exception as error: + consume_error = error + + self.after_consume(consume_error) + + def on_produce(self) -> None: + pass + + def after_produce( + self, + error: Exception | None = None, + ) -> None: + if error is not None: + raise error + + @contextmanager + def produce(self) -> Generator[None, None, None]: + produce_error: Exception | None = None + + try: + self.on_produce() + yield + except Exception as error: + produce_error = error + + self.after_produce(produce_error) diff --git a/src/message_flow/app/message_flow.py b/src/message_flow/app/message_flow.py index 2efe69e..9dbc3b2 100644 --- a/src/message_flow/app/message_flow.py +++ b/src/message_flow/app/message_flow.py @@ -12,6 +12,7 @@ from ._internal import AsyncAPIStudioPage, Channels, Info, MessageFlowSchema from ._message_management import Dispatcher, Producer from ._simple_messaging import SimpleMessageConsumer, SimpleMessageProducer +from .base_middleware import BaseMiddleware from .messaging import MessageConsumer, MessageProducer MessageHandler = Callable[[Message], Message | None] @@ -556,3 +557,42 @@ async def async_api_docs_html(req: Request) -> HTMLResponse: return HTMLResponse(documentation_page) fast_api.add_route(documentation_url, async_api_docs_html, include_in_schema=False) + + def add_middleware( + self, middleware: Annotated[type[BaseMiddleware], Doc("Message processing Middleware.")] + ) -> None: + """ + Add Middleware. + + **Example** + + ```python title="Adding processing Middleware" + import logging + from message_flow import MessageFlow, BaseMiddleware + + logger = logging.getLogger(__name__) + + + class CustomMiddleware(BaseMiddleware): + def on_consume(self) -> None: + logger.info("Message with %s payload and %s headers received.", self.payload, self.headers) + return super().on_consume() + + def after_consume(self, error: Exception | None = None) -> None: + logger.info("Message with %s payload and %s headers processed.", self.payload, self.headers) + return super().after_consume(error) + + def on_produce(self) -> None: + logger.info("Producing message with %s payload and %s headers.", self.payload, self.headers) + return super().on_produce() + + def after_produce(self, error: Exception | None = None) -> None: + logger.info("Message with %s payload and %s headers produced.", self.payload, self.headers) + return super().after_produce(error) + + app = MessageFlow() + app.add_middleware(CustomMiddleware) + + ``` + """ + self.dispatcher.add_middleware(middleware=middleware) diff --git a/src/message_flow/cli/_cli_app.py b/src/message_flow/cli/_cli_app.py index 875b8e9..c6315be 100644 --- a/src/message_flow/cli/_cli_app.py +++ b/src/message_flow/cli/_cli_app.py @@ -2,7 +2,7 @@ from collections import defaultdict from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from typing import DefaultDict +from typing import DefaultDict, final import typer @@ -12,6 +12,7 @@ from ._logging_level import LoggingLevel +@final @internal class CLIApp: LOGGING_LEVELS: DefaultDict[str, int] = defaultdict( diff --git a/src/message_flow/cli/_documentation_server.py b/src/message_flow/cli/_documentation_server.py index 2db8b16..6342f7d 100644 --- a/src/message_flow/cli/_documentation_server.py +++ b/src/message_flow/cli/_documentation_server.py @@ -1,8 +1,10 @@ from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import final from ..utils import internal, logger +@final @internal class DocumentationServer: def __init__(self, studio_page: str, host: str, port: int) -> None: diff --git a/src/message_flow/cli/_logging_level.py b/src/message_flow/cli/_logging_level.py index a9146d4..648e5ef 100644 --- a/src/message_flow/cli/_logging_level.py +++ b/src/message_flow/cli/_logging_level.py @@ -1,8 +1,10 @@ from enum import Enum +from typing import final from ..utils import internal +@final @internal class LoggingLevel(str, Enum): CRITICAL = "critical"