Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new websockets API, fix .stop() #313

Merged
merged 6 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ classifiers = [
"Operating System :: OS Independent"
]
dependencies = [
"websockets>=10.4",
"websockets>=13.1",
"numpy>=1.0.0",
"msgspec>=0.18.6",
"imageio>=2.0.0",
Expand Down
127 changes: 73 additions & 54 deletions src/viser/infra/_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dataclasses
import gzip
import http
import logging
import mimetypes
import queue
import threading
Expand All @@ -16,12 +17,13 @@

import msgspec
import rich
import websockets.connection
import websockets.asyncio.server
import websockets.datastructures
import websockets.exceptions
import websockets.server
from typing_extensions import Literal, assert_never, override
from websockets.legacy.server import WebSocketServerProtocol
from websockets import Headers
from websockets.asyncio.server import ServerConnection
from websockets.http11 import Request, Response

from ._async_message_buffer import AsyncMessageBuffer
from ._messages import Message
Expand Down Expand Up @@ -234,8 +236,9 @@ def __init__(
self._http_server_root = http_server_root
self._verbose = verbose
self._client_api_version: Literal[0, 1] = client_api_version
self._shutdown_event = threading.Event()
self._ws_server: websockets.WebSocketServer | None = None
self._background_event_loop: asyncio.AbstractEventLoop | None = None

self._stop_event: asyncio.Event | None = None

self._client_state_from_id: dict[int, _ClientHandleState] = {}

Expand All @@ -258,9 +261,9 @@ def start(self) -> None:

def stop(self) -> None:
"""Stop the server."""
assert self._ws_server is not None
self._ws_server.close()
self._ws_server = None
assert self._background_event_loop is not None
assert self._stop_event is not None
self._background_event_loop.call_soon_threadsafe(self._stop_event.set)

def on_client_connect(
self, cb: Callable[[WebsockClientConnection], None | Coroutine]
Expand Down Expand Up @@ -298,6 +301,8 @@ def _background_worker(self, ready_sem: threading.Semaphore) -> None:
# Need to make a new event loop for notebook compatbility.
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
self._stop_event = asyncio.Event()
self._background_event_loop = event_loop
self._broadcast_buffer = AsyncMessageBuffer(
event_loop, persistent_messages=True
)
Expand All @@ -306,8 +311,20 @@ def _background_worker(self, ready_sem: threading.Semaphore) -> None:
connection_count = 0
total_connections = 0

async def serve(websocket: WebSocketServerProtocol) -> None:
"""Server loop, run once per connection."""
async def ws_handler(
connection: websockets.asyncio.server.ServerConnection,
) -> None:
"""Handler for websocket connections."""

# <Hack>
# Suppress errors for: https://github.com/python-websockets/websockets/issues/1513
# TODO: remove this when websockets behavior changes upstream.
class NoHttpErrors(logging.Filter):
def filter(self, record):
return not record.getMessage() == "opening handshake failed"

connection.logger.logger.addFilter(NoHttpErrors()) # type: ignore
# </Hack>

async with count_lock:
nonlocal connection_count
Expand Down Expand Up @@ -352,18 +369,18 @@ def handle_incoming(message: Message) -> None:
# and consumers (which receive messages).
await asyncio.gather(
_message_producer(
websocket,
connection,
client_state.message_buffer,
client_id,
self._client_api_version,
),
_message_producer(
websocket,
connection,
self._broadcast_buffer,
client_id,
self._client_api_version,
),
_message_consumer(websocket, handle_incoming, message_class),
_message_consumer(connection, handle_incoming, message_class),
)
except (
websockets.exceptions.ConnectionClosedOK,
Expand Down Expand Up @@ -397,16 +414,16 @@ def handle_incoming(message: Message) -> None:
file_cache: dict[Path, bytes] = {}
file_cache_gzipped: dict[Path, bytes] = {}

async def viser_http_server(
path: str, request_headers: websockets.datastructures.Headers
) -> (
tuple[http.HTTPStatus, websockets.datastructures.HeadersLike, bytes] | None
):
def viser_http_server(
connection: ServerConnection,
request: Request,
) -> Response | None:
# Ignore websocket packets.
if request_headers.get("Upgrade") == "websocket":
if request.headers.get("Upgrade") == "websocket":
return None

# Strip out search params, get relative path.
path = request.path
path = path.partition("?")[0]
relpath = str(Path(path).relative_to("/"))
if relpath == ".":
Expand All @@ -415,9 +432,9 @@ async def viser_http_server(

source_path = http_server_root / relpath
if not source_path.exists():
return (http.HTTPStatus.NOT_FOUND, {}, b"404") # type: ignore
return Response(http.HTTPStatus.NOT_FOUND, "NOT FOUND", Headers())

use_gzip = "gzip" in request_headers.get("Accept-Encoding", "")
use_gzip = "gzip" in request.headers.get("Accept-Encoding", "")

# First, try some known MIME types. Using guess_type() can cause
# problems for Javascript on some Windows machines.
Expand Down Expand Up @@ -462,42 +479,44 @@ async def viser_http_server(
response_payload = file_cache[source_path]

# Try to read + send over file.
return (http.HTTPStatus.OK, response_headers, response_payload)

for _ in range(1000):
try:
serve_future = websockets.server.serve(
serve,
host,
port,
# Compression can be too slow for our use cases.
compression=None,
process_request=(
viser_http_server if http_server_root is not None else None
),
)
self._ws_server = serve_future.ws_server
event_loop.run_until_complete(serve_future)
break
except OSError: # Port not available.
port += 1
continue

if self._ws_server is None:
raise RuntimeError("Failed to bind to port!")

self._port = port

ready_sem.release()
event_loop.run_forever()

# This will run only when the event loop ends, which happens when the
# websocket server is closed.
return Response(
http.HTTPStatus.OK,
"OK",
websockets.datastructures.Headers(**response_headers),
response_payload,
)
# return (http.HTTPStatus.OK, response_headers, response_payload)

async def start_server() -> None:
port_attempt = port
for _ in range(1000):
try:
async with websockets.asyncio.server.serve(
ws_handler,
host,
port_attempt,
# Compression can be too slow for our use cases.
compression=None,
process_request=(
viser_http_server if http_server_root is not None else None
),
) as serve_future:
assert serve_future.server is not None
self._port = port_attempt
ready_sem.release()
assert self._stop_event is not None
await self._stop_event.wait()
return
except OSError: # Port not available.
port_attempt += 1
continue

event_loop.run_until_complete(start_server())
rich.print("[bold](viser)[/bold] Server stopped")


async def _message_producer(
websocket: WebSocketServerProtocol,
websocket: ServerConnection,
buffer: AsyncMessageBuffer,
client_id: int,
client_api_version: Literal[0, 1],
Expand All @@ -522,7 +541,7 @@ async def _message_producer(


async def _message_consumer(
websocket: WebSocketServerProtocol,
websocket: ServerConnection,
handle_message: Callable[[Message], None],
message_class: type[Message],
) -> None:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_server_stop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import socket
import time

import viser
import viser._client_autobuild


def test_server_port_is_freed():
# Mock the client autobuild to avoid building the client.
viser._client_autobuild.ensure_client_is_built = lambda: None

server = viser.ViserServer()
original_port = server.get_port()

# Assert that the port is not free.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(("localhost", original_port))
assert result == 0
sock.close()
server.stop()

time.sleep(0.05)

# Assert that the port is now free.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
result = sock.connect_ex(("localhost", original_port))
assert result != 0
Loading