diff --git a/src/aap_eda/wsapi/consumers.py b/src/aap_eda/wsapi/consumers.py index 71129f40a..f4c2bc6f9 100644 --- a/src/aap_eda/wsapi/consumers.py +++ b/src/aap_eda/wsapi/consumers.py @@ -72,12 +72,31 @@ class Event(Enum): DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%z" +class DefaultConsumer(AsyncWebsocketConsumer): + """Default consumer for websocket connections. + + This is the consumer to handle all the unexpected paths, it will close the + connection with an error message. + """ + + async def connect(self): + await self.accept() + await self.send('{"error": "invalid path"}') + await self.close() + + class AnsibleRulebookConsumer(AsyncWebsocketConsumer): async def receive(self, text_data=None, bytes_data=None): data = json.loads(text_data) logger.debug(f"AnsibleRulebookConsumer received: {data}") - msg_type = MessageType(data.get("type")) + msg_type = data.get("type") + try: + msg_type = MessageType(data.get("type")) + except ValueError: + logger.error(f"Unsupported message type: {data}") + await self.close() + return try: if msg_type == MessageType.WORKER: @@ -94,8 +113,6 @@ async def receive(self, text_data=None, bytes_data=None): logger.info("Websocket connection is closed.") elif msg_type == MessageType.SESSION_STATS: await self.handle_heartbeat(HeartbeatMessage.parse_obj(data)) - else: - logger.warning(f"Unsupported message received: {data}") except DatabaseError as err: logger.error(f"Failed to parse {data} due to DB error: {err}") diff --git a/src/aap_eda/wsapi/routes.py b/src/aap_eda/wsapi/routes.py index 1ae995668..d83aa4031 100644 --- a/src/aap_eda/wsapi/routes.py +++ b/src/aap_eda/wsapi/routes.py @@ -1,15 +1,28 @@ from channels.routing import URLRouter from django.conf import settings -from django.urls import path +from django.urls import path, re_path from . import consumers +default_path = re_path(r".*/?$", consumers.DefaultConsumer.as_asgi()) + + +default_router = URLRouter( + [ + default_path, + ], +) + wsapi_router = URLRouter( - [path("ansible-rulebook", consumers.AnsibleRulebookConsumer.as_asgi())] + [ + path("ansible-rulebook", consumers.AnsibleRulebookConsumer.as_asgi()), + default_path, + ], ) router = URLRouter( [ path(f"{settings.API_PREFIX}/ws/", wsapi_router), - ] + path("", default_router), + ], ) diff --git a/tests/integration/wsapi/test_consumer.py b/tests/integration/wsapi/test_consumer.py index fa7826359..6b68d53d0 100644 --- a/tests/integration/wsapi/test_consumer.py +++ b/tests/integration/wsapi/test_consumer.py @@ -5,9 +5,11 @@ import pytest_asyncio from channels.db import database_sync_to_async from channels.testing import WebsocketCommunicator +from django.conf import settings from django.utils import timezone from pydantic.error_wrappers import ValidationError +from aap_eda.asgi import application from aap_eda.core import models from aap_eda.wsapi.consumers import AnsibleRulebookConsumer @@ -45,6 +47,42 @@ DUMMY_UUID2 = "8472ff2c-6045-4418-8d4e-46f6cfffffff" +@pytest.mark.parametrize("path", ["ws/unexpected", "unexpected"]) +async def test_invalid_websocket_route(path: str): + """Test that the websocket consumer rejects unsupported routes.""" + communicator = WebsocketCommunicator(application, path) + + connected, _ = await communicator.connect() + assert connected, "Connection failed" + + response = await communicator.receive_from() + assert response == '{"error": "invalid path"}', "Invalid error message" + close_message = await communicator.receive_output() + assert ( + close_message["type"] == "websocket.close" + ), "Did not receive close message" + + await communicator.disconnect() + + +async def test_valid_websocket_route_wrong_type(): + """Test that the websocket consumer rejects unsupported types.""" + communicator = WebsocketCommunicator( + application, + f"{settings.API_PREFIX}/ws/ansible-rulebook", + ) + connected, _ = await communicator.connect() + assert connected, "Connection failed" + nothing = await communicator.receive_nothing() + assert nothing, "Received unexpected message" + await communicator.send_to(text_data='{"type": "unsuported_type"}') + close_message = await communicator.receive_output() + assert ( + close_message["type"] == "websocket.close" + ), "Did not receive close message" + await communicator.disconnect() + + @pytest.mark.django_db(transaction=True) async def test_handle_workers(ws_communicator: WebsocketCommunicator): activation_instance_with_extra_var = await _prepare_db_data()