Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[ws]: handle unsupported paths and message types #592

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions src/aap_eda/wsapi/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,31 @@ class Event(Enum):
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%z"


class DefaultConsumer(AsyncWebsocketConsumer):
"""Default consumer for websocket connections.

This is the consumer to handle all the unexpected paths, it will close the
connection with an error message.
"""

async def connect(self):
await self.accept()
await self.send('{"error": "invalid path"}')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Alex-Izquierdo Should we reveal this, it might be better to just close the connection

await self.close()


class AnsibleRulebookConsumer(AsyncWebsocketConsumer):
async def receive(self, text_data=None, bytes_data=None):
data = json.loads(text_data)
logger.debug(f"AnsibleRulebookConsumer received: {data}")

msg_type = MessageType(data.get("type"))
msg_type = data.get("type")
try:
msg_type = MessageType(data.get("type"))
Alex-Izquierdo marked this conversation as resolved.
Show resolved Hide resolved
except ValueError:
logger.error(f"Unsupported message type: {data}")
await self.close()
Alex-Izquierdo marked this conversation as resolved.
Show resolved Hide resolved
return

try:
if msg_type == MessageType.WORKER:
Expand All @@ -94,8 +113,6 @@ async def receive(self, text_data=None, bytes_data=None):
logger.info("Websocket connection is closed.")
elif msg_type == MessageType.SESSION_STATS:
await self.handle_heartbeat(HeartbeatMessage.parse_obj(data))
else:
logger.warning(f"Unsupported message received: {data}")
except DatabaseError as err:
logger.error(f"Failed to parse {data} due to DB error: {err}")

Expand Down
19 changes: 16 additions & 3 deletions src/aap_eda/wsapi/routes.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
from channels.routing import URLRouter
from django.conf import settings
from django.urls import path
from django.urls import path, re_path

from . import consumers

default_path = re_path(r".*/?$", consumers.DefaultConsumer.as_asgi())


default_router = URLRouter(
[
default_path,
],
)

wsapi_router = URLRouter(
[path("ansible-rulebook", consumers.AnsibleRulebookConsumer.as_asgi())]
[
path("ansible-rulebook", consumers.AnsibleRulebookConsumer.as_asgi()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Alex-Izquierdo @bzwei has a PR which will put authentication on this path #540

default_path,
],
)

router = URLRouter(
[
path(f"{settings.API_PREFIX}/ws/", wsapi_router),
]
path("", default_router),
],
)
38 changes: 38 additions & 0 deletions tests/integration/wsapi/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import pytest_asyncio
from channels.db import database_sync_to_async
from channels.testing import WebsocketCommunicator
from django.conf import settings
from django.utils import timezone
from pydantic.error_wrappers import ValidationError

from aap_eda.asgi import application
from aap_eda.core import models
from aap_eda.wsapi.consumers import AnsibleRulebookConsumer

Expand Down Expand Up @@ -45,6 +47,42 @@
DUMMY_UUID2 = "8472ff2c-6045-4418-8d4e-46f6cfffffff"


@pytest.mark.parametrize("path", ["ws/unexpected", "unexpected"])
async def test_invalid_websocket_route(path: str):
"""Test that the websocket consumer rejects unsupported routes."""
communicator = WebsocketCommunicator(application, path)

connected, _ = await communicator.connect()
assert connected, "Connection failed"

response = await communicator.receive_from()
assert response == '{"error": "invalid path"}', "Invalid error message"
close_message = await communicator.receive_output()
assert (
close_message["type"] == "websocket.close"
), "Did not receive close message"

await communicator.disconnect()


async def test_valid_websocket_route_wrong_type():
"""Test that the websocket consumer rejects unsupported types."""
communicator = WebsocketCommunicator(
application,
f"{settings.API_PREFIX}/ws/ansible-rulebook",
)
connected, _ = await communicator.connect()
assert connected, "Connection failed"
nothing = await communicator.receive_nothing()
assert nothing, "Received unexpected message"
await communicator.send_to(text_data='{"type": "unsuported_type"}')
close_message = await communicator.receive_output()
assert (
close_message["type"] == "websocket.close"
), "Did not receive close message"
await communicator.disconnect()


@pytest.mark.django_db(transaction=True)
async def test_handle_workers(ws_communicator: WebsocketCommunicator):
activation_instance_with_extra_var = await _prepare_db_data()
Expand Down