diff --git a/pyproject.toml b/pyproject.toml index f75747514..8c2423e9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ "Operating System :: OS Independent" ] dependencies = [ - "websockets>=10.4", + "websockets>=13.1", "numpy>=1.0.0", "msgspec>=0.18.6", "imageio>=2.0.0", diff --git a/src/viser/infra/_infra.py b/src/viser/infra/_infra.py index ac46520d6..77dddb549 100644 --- a/src/viser/infra/_infra.py +++ b/src/viser/infra/_infra.py @@ -6,6 +6,7 @@ import dataclasses import gzip import http +import logging import mimetypes import queue import threading @@ -16,12 +17,13 @@ import msgspec import rich -import websockets.connection +import websockets.asyncio.server import websockets.datastructures import websockets.exceptions -import websockets.server from typing_extensions import Literal, assert_never, override -from websockets.legacy.server import WebSocketServerProtocol +from websockets import Headers +from websockets.asyncio.server import ServerConnection +from websockets.http11 import Request, Response from ._async_message_buffer import AsyncMessageBuffer from ._messages import Message @@ -234,8 +236,9 @@ def __init__( self._http_server_root = http_server_root self._verbose = verbose self._client_api_version: Literal[0, 1] = client_api_version - self._shutdown_event = threading.Event() - self._ws_server: websockets.WebSocketServer | None = None + self._background_event_loop: asyncio.AbstractEventLoop | None = None + + self._stop_event: asyncio.Event | None = None self._client_state_from_id: dict[int, _ClientHandleState] = {} @@ -258,9 +261,9 @@ def start(self) -> None: def stop(self) -> None: """Stop the server.""" - assert self._ws_server is not None - self._ws_server.close() - self._ws_server = None + assert self._background_event_loop is not None + assert self._stop_event is not None + self._background_event_loop.call_soon_threadsafe(self._stop_event.set) def on_client_connect( self, cb: Callable[[WebsockClientConnection], None | Coroutine] @@ -298,6 +301,8 @@ def _background_worker(self, ready_sem: threading.Semaphore) -> None: # Need to make a new event loop for notebook compatbility. event_loop = asyncio.new_event_loop() asyncio.set_event_loop(event_loop) + self._stop_event = asyncio.Event() + self._background_event_loop = event_loop self._broadcast_buffer = AsyncMessageBuffer( event_loop, persistent_messages=True ) @@ -306,8 +311,20 @@ def _background_worker(self, ready_sem: threading.Semaphore) -> None: connection_count = 0 total_connections = 0 - async def serve(websocket: WebSocketServerProtocol) -> None: - """Server loop, run once per connection.""" + async def ws_handler( + connection: websockets.asyncio.server.ServerConnection, + ) -> None: + """Handler for websocket connections.""" + + # + # Suppress errors for: https://github.com/python-websockets/websockets/issues/1513 + # TODO: remove this when websockets behavior changes upstream. + class NoHttpErrors(logging.Filter): + def filter(self, record): + return not record.getMessage() == "opening handshake failed" + + connection.logger.logger.addFilter(NoHttpErrors()) # type: ignore + # async with count_lock: nonlocal connection_count @@ -352,18 +369,18 @@ def handle_incoming(message: Message) -> None: # and consumers (which receive messages). await asyncio.gather( _message_producer( - websocket, + connection, client_state.message_buffer, client_id, self._client_api_version, ), _message_producer( - websocket, + connection, self._broadcast_buffer, client_id, self._client_api_version, ), - _message_consumer(websocket, handle_incoming, message_class), + _message_consumer(connection, handle_incoming, message_class), ) except ( websockets.exceptions.ConnectionClosedOK, @@ -397,16 +414,16 @@ def handle_incoming(message: Message) -> None: file_cache: dict[Path, bytes] = {} file_cache_gzipped: dict[Path, bytes] = {} - async def viser_http_server( - path: str, request_headers: websockets.datastructures.Headers - ) -> ( - tuple[http.HTTPStatus, websockets.datastructures.HeadersLike, bytes] | None - ): + def viser_http_server( + connection: ServerConnection, + request: Request, + ) -> Response | None: # Ignore websocket packets. - if request_headers.get("Upgrade") == "websocket": + if request.headers.get("Upgrade") == "websocket": return None # Strip out search params, get relative path. + path = request.path path = path.partition("?")[0] relpath = str(Path(path).relative_to("/")) if relpath == ".": @@ -415,9 +432,9 @@ async def viser_http_server( source_path = http_server_root / relpath if not source_path.exists(): - return (http.HTTPStatus.NOT_FOUND, {}, b"404") # type: ignore + return Response(http.HTTPStatus.NOT_FOUND, "NOT FOUND", Headers()) - use_gzip = "gzip" in request_headers.get("Accept-Encoding", "") + use_gzip = "gzip" in request.headers.get("Accept-Encoding", "") # First, try some known MIME types. Using guess_type() can cause # problems for Javascript on some Windows machines. @@ -462,42 +479,44 @@ async def viser_http_server( response_payload = file_cache[source_path] # Try to read + send over file. - return (http.HTTPStatus.OK, response_headers, response_payload) - - for _ in range(1000): - try: - serve_future = websockets.server.serve( - serve, - host, - port, - # Compression can be too slow for our use cases. - compression=None, - process_request=( - viser_http_server if http_server_root is not None else None - ), - ) - self._ws_server = serve_future.ws_server - event_loop.run_until_complete(serve_future) - break - except OSError: # Port not available. - port += 1 - continue - - if self._ws_server is None: - raise RuntimeError("Failed to bind to port!") - - self._port = port - - ready_sem.release() - event_loop.run_forever() - - # This will run only when the event loop ends, which happens when the - # websocket server is closed. + return Response( + http.HTTPStatus.OK, + "OK", + websockets.datastructures.Headers(**response_headers), + response_payload, + ) + # return (http.HTTPStatus.OK, response_headers, response_payload) + + async def start_server() -> None: + port_attempt = port + for _ in range(1000): + try: + async with websockets.asyncio.server.serve( + ws_handler, + host, + port_attempt, + # Compression can be too slow for our use cases. + compression=None, + process_request=( + viser_http_server if http_server_root is not None else None + ), + ) as serve_future: + assert serve_future.server is not None + self._port = port_attempt + ready_sem.release() + assert self._stop_event is not None + await self._stop_event.wait() + return + except OSError: # Port not available. + port_attempt += 1 + continue + + event_loop.run_until_complete(start_server()) rich.print("[bold](viser)[/bold] Server stopped") async def _message_producer( - websocket: WebSocketServerProtocol, + websocket: ServerConnection, buffer: AsyncMessageBuffer, client_id: int, client_api_version: Literal[0, 1], @@ -522,7 +541,7 @@ async def _message_producer( async def _message_consumer( - websocket: WebSocketServerProtocol, + websocket: ServerConnection, handle_message: Callable[[Message], None], message_class: type[Message], ) -> None: diff --git a/tests/test_server_stop.py b/tests/test_server_stop.py new file mode 100644 index 000000000..2dac5ab36 --- /dev/null +++ b/tests/test_server_stop.py @@ -0,0 +1,27 @@ +import socket +import time + +import viser +import viser._client_autobuild + + +def test_server_port_is_freed(): + # Mock the client autobuild to avoid building the client. + viser._client_autobuild.ensure_client_is_built = lambda: None + + server = viser.ViserServer() + original_port = server.get_port() + + # Assert that the port is not free. + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.connect_ex(("localhost", original_port)) + assert result == 0 + sock.close() + server.stop() + + time.sleep(0.05) + + # Assert that the port is now free. + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.connect_ex(("localhost", original_port)) + assert result != 0