Skip to content

Commit

Permalink
Use new websockets API, fix .stop() (#313)
Browse files Browse the repository at this point in the history
* Fix `server.stop()`

* HTTP server fixes

* Format + basic test

* Fix Python 3.8

* Don't build client in test

* Add sleep to server stop test
  • Loading branch information
brentyi authored Nov 4, 2024
1 parent 9ef686f commit fa1f12e
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 55 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ classifiers = [
"Operating System :: OS Independent"
]
dependencies = [
"websockets>=10.4",
"websockets>=13.1",
"numpy>=1.0.0",
"msgspec>=0.18.6",
"imageio>=2.0.0",
Expand Down
127 changes: 73 additions & 54 deletions src/viser/infra/_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dataclasses
import gzip
import http
import logging
import mimetypes
import queue
import threading
Expand All @@ -16,12 +17,13 @@

import msgspec
import rich
import websockets.connection
import websockets.asyncio.server
import websockets.datastructures
import websockets.exceptions
import websockets.server
from typing_extensions import Literal, assert_never, override
from websockets.legacy.server import WebSocketServerProtocol
from websockets import Headers
from websockets.asyncio.server import ServerConnection
from websockets.http11 import Request, Response

from ._async_message_buffer import AsyncMessageBuffer
from ._messages import Message
Expand Down Expand Up @@ -234,8 +236,9 @@ def __init__(
self._http_server_root = http_server_root
self._verbose = verbose
self._client_api_version: Literal[0, 1] = client_api_version
self._shutdown_event = threading.Event()
self._ws_server: websockets.WebSocketServer | None = None
self._background_event_loop: asyncio.AbstractEventLoop | None = None

self._stop_event: asyncio.Event | None = None

self._client_state_from_id: dict[int, _ClientHandleState] = {}

Expand All @@ -258,9 +261,9 @@ def start(self) -> None:

def stop(self) -> None:
"""Stop the server."""
assert self._ws_server is not None
self._ws_server.close()
self._ws_server = None
assert self._background_event_loop is not None
assert self._stop_event is not None
self._background_event_loop.call_soon_threadsafe(self._stop_event.set)

def on_client_connect(
self, cb: Callable[[WebsockClientConnection], None | Coroutine]
Expand Down Expand Up @@ -298,6 +301,8 @@ def _background_worker(self, ready_sem: threading.Semaphore) -> None:
# Need to make a new event loop for notebook compatbility.
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
self._stop_event = asyncio.Event()
self._background_event_loop = event_loop
self._broadcast_buffer = AsyncMessageBuffer(
event_loop, persistent_messages=True
)
Expand All @@ -306,8 +311,20 @@ def _background_worker(self, ready_sem: threading.Semaphore) -> None:
connection_count = 0
total_connections = 0

async def serve(websocket: WebSocketServerProtocol) -> None:
"""Server loop, run once per connection."""
async def ws_handler(
connection: websockets.asyncio.server.ServerConnection,
) -> None:
"""Handler for websocket connections."""

# <Hack>
# Suppress errors for: https://github.com/python-websockets/websockets/issues/1513
# TODO: remove this when websockets behavior changes upstream.
class NoHttpErrors(logging.Filter):
def filter(self, record):
return not record.getMessage() == "opening handshake failed"

connection.logger.logger.addFilter(NoHttpErrors()) # type: ignore
# </Hack>

async with count_lock:
nonlocal connection_count
Expand Down Expand Up @@ -352,18 +369,18 @@ def handle_incoming(message: Message) -> None:
# and consumers (which receive messages).
await asyncio.gather(
_message_producer(
websocket,
connection,
client_state.message_buffer,
client_id,
self._client_api_version,
),
_message_producer(
websocket,
connection,
self._broadcast_buffer,
client_id,
self._client_api_version,
),
_message_consumer(websocket, handle_incoming, message_class),
_message_consumer(connection, handle_incoming, message_class),
)
except (
websockets.exceptions.ConnectionClosedOK,
Expand Down Expand Up @@ -397,16 +414,16 @@ def handle_incoming(message: Message) -> None:
file_cache: dict[Path, bytes] = {}
file_cache_gzipped: dict[Path, bytes] = {}

async def viser_http_server(
path: str, request_headers: websockets.datastructures.Headers
) -> (
tuple[http.HTTPStatus, websockets.datastructures.HeadersLike, bytes] | None
):
def viser_http_server(
connection: ServerConnection,
request: Request,
) -> Response | None:
# Ignore websocket packets.
if request_headers.get("Upgrade") == "websocket":
if request.headers.get("Upgrade") == "websocket":
return None

# Strip out search params, get relative path.
path = request.path
path = path.partition("?")[0]
relpath = str(Path(path).relative_to("/"))
if relpath == ".":
Expand All @@ -415,9 +432,9 @@ async def viser_http_server(

source_path = http_server_root / relpath
if not source_path.exists():
return (http.HTTPStatus.NOT_FOUND, {}, b"404") # type: ignore
return Response(http.HTTPStatus.NOT_FOUND, "NOT FOUND", Headers())

use_gzip = "gzip" in request_headers.get("Accept-Encoding", "")
use_gzip = "gzip" in request.headers.get("Accept-Encoding", "")

# First, try some known MIME types. Using guess_type() can cause
# problems for Javascript on some Windows machines.
Expand Down Expand Up @@ -462,42 +479,44 @@ async def viser_http_server(
response_payload = file_cache[source_path]

# Try to read + send over file.
return (http.HTTPStatus.OK, response_headers, response_payload)

for _ in range(1000):
try:
serve_future = websockets.server.serve(
serve,
host,
port,
# Compression can be too slow for our use cases.
compression=None,
process_request=(
viser_http_server if http_server_root is not None else None
),
)
self._ws_server = serve_future.ws_server
event_loop.run_until_complete(serve_future)
break
except OSError: # Port not available.
port += 1
continue

if self._ws_server is None:
raise RuntimeError("Failed to bind to port!")

self._port = port

ready_sem.release()
event_loop.run_forever()

# This will run only when the event loop ends, which happens when the
# websocket server is closed.
return Response(
http.HTTPStatus.OK,
"OK",
websockets.datastructures.Headers(**response_headers),
response_payload,
)
# return (http.HTTPStatus.OK, response_headers, response_payload)

async def start_server() -> None:
port_attempt = port
for _ in range(1000):
try:
async with websockets.asyncio.server.serve(
ws_handler,
host,
port_attempt,
# Compression can be too slow for our use cases.
compression=None,
process_request=(
viser_http_server if http_server_root is not None else None
),
) as serve_future:
assert serve_future.server is not None
self._port = port_attempt
ready_sem.release()
assert self._stop_event is not None
await self._stop_event.wait()
return
except OSError: # Port not available.
port_attempt += 1
continue

event_loop.run_until_complete(start_server())
rich.print("[bold](viser)[/bold] Server stopped")


async def _message_producer(
websocket: WebSocketServerProtocol,
websocket: ServerConnection,
buffer: AsyncMessageBuffer,
client_id: int,
client_api_version: Literal[0, 1],
Expand All @@ -522,7 +541,7 @@ async def _message_producer(


async def _message_consumer(
websocket: WebSocketServerProtocol,
websocket: ServerConnection,
handle_message: Callable[[Message], None],
message_class: type[Message],
) -> None:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_server_stop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import socket
import time

import viser
import viser._client_autobuild


def test_server_port_is_freed():
# Mock the client autobuild to avoid building the client.
viser._client_autobuild.ensure_client_is_built = lambda: None

server = viser.ViserServer()
original_port = server.get_port()

# Assert that the port is not free.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(("localhost", original_port))
assert result == 0
sock.close()
server.stop()

time.sleep(0.05)

# Assert that the port is now free.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(("localhost", original_port))
assert result != 0

0 comments on commit fa1f12e

Please sign in to comment.