Skip to content

Commit

Permalink
Add Middlewares to the Message Flow
Browse files Browse the repository at this point in the history
  • Loading branch information
voro6yov committed Apr 6, 2024
1 parent 9936e1a commit d4eaf9e
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 7 deletions.
60 changes: 58 additions & 2 deletions consumer_example.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,69 @@
from message_flow import MessageFlow, Message, Payload, Header

from types import TracebackType
from typing import Type
from message_flow import MessageFlow, Message, Payload, Header, BaseMiddleware
from message_flow.utils import logger

class OrderCreated(Message):
order_id: str = Payload()
tenant_id: str = Header()


class MockMiddleware(BaseMiddleware):
def on_consume(self) -> None:
self.headers.update({"a": 123})
print("h")
return super().on_consume()

def after_consume(self, error: Exception | None = None) -> None:
print("11")
return super().after_consume(error)


class MockMiddleware1(BaseMiddleware):
def on_consume(self) -> None:
self.headers.update({"a": 123})
print("h1")
return super().on_consume()

def after_consume(self, error: Exception | None = None) -> None:
print("1122")
return super().after_consume(error)

def on_produce(self) -> None:
print("produce")
return super().on_produce()

def after_produce(self, error: Exception | None = None) -> None:
print("after produce")
return super().after_produce(error)


class CustomMiddleware(BaseMiddleware):
def on_consume(self) -> None:
logger.info("Message with %s payload and %s headers received.", self.payload, self.headers)
return super().on_consume()

def after_consume(self, error: Exception | None = None) -> None:
logger.info("Message with %s payload and %s headers processed.", self.payload, self.headers)
return super().after_consume(error)

def on_produce(self) -> None:
logger.info("Producing message with %s payload and %s headers.", self.payload, self.headers)
return super().on_produce()

def after_produce(self, error: Exception | None = None) -> None:
logger.info("Message with %s payload and %s headers produced.", self.payload, self.headers)
return super().after_produce(error)


app = MessageFlow()
app.add_middleware(MockMiddleware)
app.add_middleware(MockMiddleware1)
app.add_middleware(CustomMiddleware)

@app.subscribe(address="orders", message=OrderCreated)
def order_created_handler(order_created: OrderCreated) -> None:
print("Event received", order_created.order_id, order_created.tenant_id)
return order_created

app.dispatch()
2 changes: 1 addition & 1 deletion producer_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ class OrderCreated(Message):
if __name__ == "__main__":
app = MessageFlow()

app.publish(OrderCreated(order_id="order_id", tenant_id="tenant_id"), channel_address="orders")
app.send(OrderCreated(order_id="order_id", tenant_id="tenant_id"), channel_address="orders", reply_to_address="orders_")

1 change: 1 addition & 0 deletions src/message_flow/app/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .base_middleware import *
from .message_flow import *
from .messaging import *
29 changes: 26 additions & 3 deletions src/message_flow/app/_message_management/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
from contextlib import ExitStack
from typing import final

from ...utils import internal
from .._internal import Channels
from ..base_middleware import BaseMiddleware
from ..messaging import MessageConsumer
from .producer import Producer
from .routing_headers import RoutingHeaders
Expand All @@ -12,13 +14,18 @@
@internal
class Dispatcher:
def __init__(
self, channels: Channels, message_consumer: MessageConsumer, producer: Producer, logger: logging.Logger
self,
channels: Channels,
message_consumer: MessageConsumer,
producer: Producer,
logger: logging.Logger,
) -> None:
self._logger = logger

self._channels = channels
self._message_consumer = message_consumer
self._producer = producer
self._middlewares: list[type[BaseMiddleware]] = []

def initialize(self) -> None:
self._logger.debug("Initializing dispatcher")
Expand All @@ -28,11 +35,27 @@ def initialize(self) -> None:
)
self._logger.debug("Initialized dispatcher")

def add_middleware(self, middleware: type[BaseMiddleware]) -> None:
self._middlewares.append(middleware)

def message_handler(self, payload: bytes, headers: dict[str, str]) -> None:
if (
handler := self._channels.operation_of(headers[RoutingHeaders.ADDRESS], headers[RoutingHeaders.TYPE])
) is None:
return

if (message := handler(handler.message.from_payload_and_headers(payload, headers))) is not None:
self._producer.send(headers[RoutingHeaders.REPLY_TO], message)
with ExitStack() as dispatcher_stack:
self._execute_consume_middlewares(dispatcher_stack, payload, headers)

if (message := handler(handler.message.from_payload_and_headers(payload, headers))) is not None:
self._execute_produce_middlewares(dispatcher_stack, message.payload, message.headers)

self._producer.send(headers[RoutingHeaders.REPLY_TO], message)

def _execute_consume_middlewares(self, stack: ExitStack, payload: bytes, headers: dict[str, str]) -> None:
for middleware in self._middlewares:
stack.enter_context(middleware(payload, headers).consume())

def _execute_produce_middlewares(self, stack: ExitStack, payload: bytes, headers: dict[str, str]) -> None:
for middleware in self._middlewares:
stack.enter_context(middleware(payload, headers).produce())
55 changes: 55 additions & 0 deletions src/message_flow/app/base_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from contextlib import contextmanager
from typing import Generator

from ..utils import external


@external
class BaseMiddleware:
def __init__(self, payload: bytes, headers: dict[str, str]) -> None:
self.payload = payload
self.headers = headers

def on_consume(self) -> None:
pass

def after_consume(
self,
error: Exception | None = None,
) -> None:
if error is not None:
raise error

@contextmanager
def consume(self) -> Generator[None, None, None]:
consume_error: Exception | None = None

try:
self.on_consume()
yield
except Exception as error:
consume_error = error

self.after_consume(consume_error)

def on_produce(self) -> None:
pass

def after_produce(
self,
error: Exception | None = None,
) -> None:
if error is not None:
raise error

@contextmanager
def produce(self) -> Generator[None, None, None]:
produce_error: Exception | None = None

try:
self.on_produce()
yield
except Exception as error:
produce_error = error

self.after_produce(produce_error)
40 changes: 40 additions & 0 deletions src/message_flow/app/message_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ._internal import AsyncAPIStudioPage, Channels, Info, MessageFlowSchema
from ._message_management import Dispatcher, Producer
from ._simple_messaging import SimpleMessageConsumer, SimpleMessageProducer
from .base_middleware import BaseMiddleware
from .messaging import MessageConsumer, MessageProducer

MessageHandler = Callable[[Message], Message | None]
Expand Down Expand Up @@ -556,3 +557,42 @@ async def async_api_docs_html(req: Request) -> HTMLResponse:
return HTMLResponse(documentation_page)

fast_api.add_route(documentation_url, async_api_docs_html, include_in_schema=False)

def add_middleware(
self, middleware: Annotated[type[BaseMiddleware], Doc("Message processing Middleware.")]
) -> None:
"""
Add Middleware.
**Example**
```python title="Adding processing Middleware"
import logging
from message_flow import MessageFlow, BaseMiddleware
logger = logging.getLogger(__name__)
class CustomMiddleware(BaseMiddleware):
def on_consume(self) -> None:
logger.info("Message with %s payload and %s headers received.", self.payload, self.headers)
return super().on_consume()
def after_consume(self, error: Exception | None = None) -> None:
logger.info("Message with %s payload and %s headers processed.", self.payload, self.headers)
return super().after_consume(error)
def on_produce(self) -> None:
logger.info("Producing message with %s payload and %s headers.", self.payload, self.headers)
return super().on_produce()
def after_produce(self, error: Exception | None = None) -> None:
logger.info("Message with %s payload and %s headers produced.", self.payload, self.headers)
return super().after_produce(error)
app = MessageFlow()
app.add_middleware(CustomMiddleware)
```
"""
self.dispatcher.add_middleware(middleware=middleware)
3 changes: 2 additions & 1 deletion src/message_flow/cli/_cli_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from typing import DefaultDict
from typing import DefaultDict, final

import typer

Expand All @@ -12,6 +12,7 @@
from ._logging_level import LoggingLevel


@final
@internal
class CLIApp:
LOGGING_LEVELS: DefaultDict[str, int] = defaultdict(
Expand Down
2 changes: 2 additions & 0 deletions src/message_flow/cli/_documentation_server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import final

from ..utils import internal, logger


@final
@internal
class DocumentationServer:
def __init__(self, studio_page: str, host: str, port: int) -> None:
Expand Down
2 changes: 2 additions & 0 deletions src/message_flow/cli/_logging_level.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from enum import Enum
from typing import final

from ..utils import internal


@final
@internal
class LoggingLevel(str, Enum):
CRITICAL = "critical"
Expand Down

0 comments on commit d4eaf9e

Please sign in to comment.