From 2825ebbd178e099b3b8a11b0ac6ea9497168b262 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20M=C3=A9taireau?= Date: Mon, 3 Jun 2024 19:46:22 +0200 Subject: [PATCH] tests(sync): Change the way the websocket server is run in the tests Using [pytest-xprocess](https://pytest-xprocess.readthedocs.io/) proved not being as useful as I thought at first, because it was causing intermitent failures when starting the process. The code now directly uses `subprocess.popen` calls to start the server. The tests are grouped together using the following decorator: `@pytest.mark.xdist_group(name="websockets")` Tests now need to be run with the `pytest --dist loadgroup` so that all tests of the same group happen on the same process. More details on this blogpost: https://blog.notmyidea.org/start-a-process-when-using-pytest-xdist.html --- Makefile | 4 +- pyproject.toml | 1 - umap/tests/integration/conftest.py | 39 ++++---- umap/tests/integration/test_websocket_sync.py | 6 ++ umap/tests/test_datalayer_views.py | 1 - umap/websocket_server.py | 92 +++++++++++++++++++ 6 files changed, 122 insertions(+), 21 deletions(-) create mode 100644 umap/websocket_server.py diff --git a/Makefile b/Makefile index 5de3fd408..a774a4c4c 100644 --- a/Makefile +++ b/Makefile @@ -58,10 +58,10 @@ publish: ## Publish the Python package to Pypi test: testpy testjs testpy: - pytest -vv umap/tests/ + pytest -vv umap/tests/ --dist=loadgroup test-integration: - pytest -xv umap/tests/integration/ + pytest -xv umap/tests/integration/ --dist=loadgroup clean: rm -f dist/* diff --git a/pyproject.toml b/pyproject.toml index dbf0914fe..2b4852d2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,6 @@ test = [ "pytest-django==4.8.0", "pytest-playwright==0.5.0", "pytest-xdist>=3.5.0,<4", - "pytest-xprocess>=1.0.1", ] docker = [ "uwsgi==2.0.26", diff --git a/umap/tests/integration/conftest.py b/umap/tests/integration/conftest.py index 06367086b..b808a7070 100644 --- a/umap/tests/integration/conftest.py +++ b/umap/tests/integration/conftest.py @@ -1,9 +1,10 @@ import os +import subprocess +import time from pathlib import Path import pytest from playwright.sync_api import expect -from xprocess import ProcessStarter @pytest.fixture(autouse=True) @@ -37,19 +38,23 @@ def do_login(user): return do_login -@pytest.fixture() -def websocket_server(xprocess): - class Starter(ProcessStarter): - settings_path = ( - (Path(__file__).parent.parent / "settings.py").absolute().as_posix() - ) - os.environ["UMAP_SETTINGS"] = settings_path - # env = {"UMAP_SETTINGS": settings_path} - pattern = "Waiting for connections*" - args = ["python", "-m", "umap.ws"] - timeout = 1 - terminate_on_interrupt = True - - xprocess.ensure("websocket_server", Starter) - yield - xprocess.getinfo("websocket_server").terminate() +@pytest.fixture +def websocket_server(): + # Find the test-settings, and put them in the current environment + settings_path = (Path(__file__).parent.parent / "settings.py").absolute().as_posix() + os.environ["UMAP_SETTINGS"] = settings_path + + ds_proc = subprocess.Popen( + [ + "umap", + "run_websocket_server", + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + time.sleep(2) + # Ensure it started properly before yielding + assert not ds_proc.poll(), ds_proc.stdout.read().decode("utf-8") + yield ds_proc + # Shut it down at the end of the pytest session + ds_proc.terminate() diff --git a/umap/tests/integration/test_websocket_sync.py b/umap/tests/integration/test_websocket_sync.py index 87842340f..604bdc6ac 100644 --- a/umap/tests/integration/test_websocket_sync.py +++ b/umap/tests/integration/test_websocket_sync.py @@ -1,5 +1,6 @@ import re +import pytest from playwright.sync_api import expect from umap.models import Map @@ -9,6 +10,7 @@ DATALAYER_UPDATE = re.compile(r".*/datalayer/update/.*") +@pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_markers( context, live_server, websocket_server, tilelayer ): @@ -73,6 +75,7 @@ def test_websocket_connection_can_sync_markers( expect(b_marker_pane).to_have_count(1) +@pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_polygons( context, live_server, websocket_server, tilelayer ): @@ -156,6 +159,7 @@ def test_websocket_connection_can_sync_polygons( expect(b_polygons).to_have_count(0) +@pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_map_properties( context, live_server, websocket_server, tilelayer ): @@ -187,6 +191,7 @@ def test_websocket_connection_can_sync_map_properties( expect(peerA.locator(".leaflet-control-zoom")).to_be_hidden() +@pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_datalayer_properties( context, live_server, websocket_server, tilelayer ): @@ -215,6 +220,7 @@ def test_websocket_connection_can_sync_datalayer_properties( expect(peerB.get_by_role("combobox")).to_have_value("Choropleth") +@pytest.mark.xdist_group(name="websockets") def test_websocket_connection_can_sync_cloned_polygons( context, live_server, websocket_server, tilelayer ): diff --git a/umap/tests/test_datalayer_views.py b/umap/tests/test_datalayer_views.py index 6055e47ed..c59a12fba 100644 --- a/umap/tests/test_datalayer_views.py +++ b/umap/tests/test_datalayer_views.py @@ -1,5 +1,4 @@ import json -import time from copy import deepcopy from pathlib import Path diff --git a/umap/websocket_server.py b/umap/websocket_server.py new file mode 100644 index 000000000..3ba81394f --- /dev/null +++ b/umap/websocket_server.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +import asyncio +from collections import defaultdict +from typing import Literal, Optional + +import websockets +from django.conf import settings +from django.core.signing import TimestampSigner +from pydantic import BaseModel, ValidationError +from websockets import WebSocketClientProtocol +from websockets.server import serve + +from umap.models import Map, User # NOQA + +# Contains the list of websocket connections handled by this process. +# It's a mapping of map_id to a set of the active websocket connections +CONNECTIONS = defaultdict(set) + + +class JoinMessage(BaseModel): + kind: str = "join" + token: str + + +class OperationMessage(BaseModel): + kind: str = "operation" + verb: str = Literal["upsert", "update", "delete"] + subject: str = Literal["map", "layer", "feature"] + metadata: Optional[dict] = None + key: Optional[str] = None + + +async def join_and_listen( + map_id: int, permissions: list, user: str | int, websocket: WebSocketClientProtocol +): + """Join a "room" whith other connected peers. + + New messages will be broadcasted to other connected peers. + """ + print(f"{user} joined room #{map_id}") + CONNECTIONS[map_id].add(websocket) + try: + async for raw_message in websocket: + # recompute the peers-list at the time of message-sending. + # as doing so beforehand would miss new connections + peers = CONNECTIONS[map_id] - {websocket} + # Only relay valid "operation" messages + try: + OperationMessage.model_validate_json(raw_message) + websockets.broadcast(peers, raw_message) + except ValidationError as e: + error = f"An error occurred when receiving this message: {raw_message}" + print(error, e) + finally: + CONNECTIONS[map_id].remove(websocket) + + +async def handler(websocket): + """Main WebSocket handler. + + If permissions are granted, let the peer enter a room. + """ + raw_message = await websocket.recv() + + # The first event should always be 'join' + message: JoinMessage = JoinMessage.model_validate_json(raw_message) + signed = TimestampSigner().unsign_object(message.token, max_age=30) + user, map_id, permissions = signed.values() + + # Check if permissions for this map have been granted by the server + if "edit" in signed["permissions"]: + await join_and_listen(map_id, permissions, user, websocket) + + +def run(host, port): + if not settings.WEBSOCKET_ENABLED: + msg = ( + "WEBSOCKET_ENABLED should be set to True to run the WebSocket Server. " + "See the documentation at " + "https://docs.umap-project.org/en/stable/config/settings/#websocket_enabled " + "for more information." + ) + print(msg) + exit(1) + + async def _serve(): + async with serve(handler, host, port): + print(f"Waiting for connections on {host}:{port}") + await asyncio.Future() # run forever + + asyncio.run(_serve())