From 0cf6fadf493df646e107cebeb85595ef7b63f68a Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 15 Oct 2024 01:54:46 -0700 Subject: [PATCH 1/4] Concurrency refactor --- src/viser/_gui_api.py | 16 +++- src/viser/_gui_handles.py | 75 ++++++++++++----- src/viser/_scene_api.py | 49 +++++++---- src/viser/_scene_handles.py | 39 +++++++-- src/viser/_viser.py | 94 ++++++++++++++------- src/viser/infra/_async_message_buffer.py | 21 ++++- src/viser/infra/_infra.py | 100 ++++++++++------------- 7 files changed, 254 insertions(+), 140 deletions(-) diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 553df47a..8e60c88d 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -1,11 +1,13 @@ from __future__ import annotations +import asyncio import builtins import colorsys import dataclasses import functools import threading import time +from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import ( @@ -183,12 +185,14 @@ def __init__( self, owner: ViserServer | ClientHandle, # Who do I belong to? thread_executor: ThreadPoolExecutor, + event_loop: AbstractEventLoop, ) -> None: from ._viser import ViserServer self._owner = owner """Entity that owns this API.""" self._thread_executor = thread_executor + self._event_loop = event_loop self._websock_interface = ( owner._websock_server @@ -217,7 +221,7 @@ def __init__( self._handle_file_transfer_part, ) - def _handle_gui_updates( + async def _handle_gui_updates( self, client_id: ClientId, message: _messages.GuiUpdateMessage ) -> None: """Callback for handling GUI messages.""" @@ -273,7 +277,10 @@ def _handle_gui_updates( else: assert False - cb(GuiEvent(client, client_id, handle)) + if asyncio.iscoroutinefunction(cb): + self._event_loop.create_task(cb(GuiEvent(client, client_id, handle))) + else: + self._thread_executor.submit(cb, GuiEvent(client, client_id, handle)) if handle_state.sync_cb is not None: handle_state.sync_cb(client_id, updates_cast) @@ -355,7 +362,10 @@ def _handle_file_transfer_part( else: assert False - cb(GuiEvent(client, client_id, handle)) + if asyncio.iscoroutinefunction(cb): + self._event_loop.create_task(cb(GuiEvent(client, client_id, handle))) + else: + self._thread_executor.submit(cb, GuiEvent(client, client_id, handle)) def _get_container_uuid(self) -> str: """Get container ID associated with the current thread.""" diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index b621f23e..dc3fb201 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -1,11 +1,13 @@ from __future__ import annotations +import asyncio import base64 import dataclasses import re import time import uuid import warnings +from collections.abc import Coroutine from functools import cached_property from pathlib import Path from typing import ( @@ -51,7 +53,7 @@ GuiVector2Props, GuiVector3Props, ) -from ._scene_api import _encode_image_binary +from ._scene_api import NoneOrCoroutine, _encode_image_binary from .infra import ClientId if TYPE_CHECKING: @@ -63,6 +65,7 @@ T = TypeVar("T") TGuiHandle = TypeVar("TGuiHandle", bound="_GuiInputHandle") +NoneOrCoroutine = TypeVar("NoneOrCoroutine", None, Coroutine) def _make_uuid() -> str: @@ -96,7 +99,7 @@ class _GuiHandleState(Generic[T]): """Container that this GUI input was placed into.""" update_timestamp: float = 0.0 - update_cb: list[Callable[[GuiEvent], None]] = dataclasses.field( + update_cb: list[Callable[[GuiEvent], None | Coroutine]] = dataclasses.field( default_factory=list ) """Registered functions to call when this input is updated.""" @@ -220,17 +223,19 @@ def value(self, value: T | np.ndarray) -> None: # Call update callbacks. for cb in self._impl.update_cb: - # Pushing callbacks into separate threads helps prevent deadlocks when we - # have a lock in a callback. TODO: revisit other callbacks. - self._impl.gui_api._thread_executor.submit( - lambda: cb( + if asyncio.iscoroutinefunction(cb): + self._impl.gui_api._event_loop.create_task( + cb(GuiEvent(client_id=None, client=None, target=self)) + ) + else: + self._impl.gui_api._thread_executor.submit( + cb, GuiEvent( client_id=None, client=None, target=self, - ) + ), ) - ) @property def update_timestamp(self) -> float: @@ -255,11 +260,16 @@ class GuiInputHandle(_GuiInputHandle[T], Generic[T]): """ def on_update( - self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], Any] - ) -> Callable[[GuiEvent[TGuiHandle]], None]: - """Attach a function to call when a GUI input is updated. Callbacks stack (need - to be manually removed via :meth:`remove_update_callback()`) and will be called - from a thread.""" + self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine] + ) -> Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]: + """Attach a function to call when a GUI input is updated. + + Note: + - If `func` is a regular function (defined with `def`), it will be executed in a thread pool. + - If `func` is an async function (defined with `async def`), it will be executed in the event loop. + + Using async functions may help reduce race conditions in certain scenarios. + """ self._impl.update_cb.append(func) return func @@ -396,9 +406,16 @@ class GuiButtonHandle(_GuiInputHandle[bool]): """ def on_click( - self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] - ) -> Callable[[GuiEvent[TGuiHandle]], None]: - """Attach a function to call when a button is pressed. Happens in a thread.""" + self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine] + ) -> Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]: + """Attach a function to call when a button is pressed. + + Note: + - If `func` is a regular function (defined with `def`), it will be executed in a thread pool. + - If `func` is an async function (defined with `async def`), it will be executed in the event loop. + + Using async functions may help reduce race conditions in certain scenarios. + """ self._impl.update_cb.append(func) return func @@ -425,9 +442,16 @@ class GuiUploadButtonHandle(_GuiInputHandle[UploadedFile]): """ def on_upload( - self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] - ) -> Callable[[GuiEvent[TGuiHandle]], None]: - """Attach a function to call when a button is pressed. Happens in a thread.""" + self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine] + ) -> Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]: + """Attach a function to call when a file is uploaded. + + Note: + - If `func` is a regular function (defined with `def`), it will be executed in a thread pool. + - If `func` is an async function (defined with `async def`), it will be executed in the event loop. + + Using async functions may help reduce race conditions in certain scenarios. + """ self._impl.update_cb.append(func) return func @@ -442,9 +466,16 @@ class GuiButtonGroupHandle(_GuiInputHandle[str], GuiButtonGroupProps): """ def on_click( - self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] - ) -> Callable[[GuiEvent[TGuiHandle]], None]: - """Attach a function to call when a button is pressed. Happens in a thread.""" + self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine] + ) -> Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]: + """Attach a function to call when a button in the group is clicked. + + Note: + - If `func` is a regular function (defined with `def`), it will be executed in a thread pool. + - If `func` is an async function (defined with `async def`), it will be executed in the event loop. + + Using async functions may help reduce race conditions in certain scenarios. + """ self._impl.update_cb.append(func) return func diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 2487272f..02c6d9b0 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1,8 +1,10 @@ from __future__ import annotations +import asyncio import io import time import warnings +from collections.abc import Coroutine from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Callable, Tuple, TypeVar, Union, cast, get_args @@ -39,6 +41,7 @@ SplineCubicBezierHandle, SpotLightHandle, TransformControlsHandle, + _ClickableSceneNodeHandle, _TransformControlsState, colors_to_uint8, ) @@ -57,6 +60,8 @@ Tuple[int, int, int], Tuple[float, float, float], np.ndarray ] +NoneOrCoroutine = TypeVar("NoneOrCoroutine", None, Coroutine) + def _encode_rgb(rgb: RgbTupleOrArray) -> tuple[int, int, int]: if isinstance(rgb, np.ndarray): @@ -115,9 +120,13 @@ def __init__( self, owner: ViserServer | ClientHandle, # Who do I belong to? thread_executor: ThreadPoolExecutor, + event_loop: asyncio.AbstractEventLoop, ) -> None: from ._viser import ViserServer + self._thread_executor = thread_executor + self._event_loop = event_loop + self._owner = owner """Entity that owns this API.""" @@ -133,8 +142,10 @@ def __init__( ] = {} self._handle_from_node_name: dict[str, SceneNodeHandle] = {} - self._scene_pointer_cb: Callable[[ScenePointerEvent], None] | None = None - self._scene_pointer_done_cb: Callable[[], None] = lambda: None + self._scene_pointer_cb: ( + Callable[[ScenePointerEvent], None | Coroutine] | None + ) = None + self._scene_pointer_done_cb: Callable[[], None | Coroutine] = lambda: None self._scene_pointer_event_type: _messages.ScenePointerEventType | None = None # Set up world axes handle. @@ -159,8 +170,6 @@ def __init__( self._handle_scene_pointer_updates, ) - self._thread_executor = thread_executor - def set_up_direction( self, direction: Literal["+x", "+y", "+z", "-x", "-y", "-z"] @@ -1543,7 +1552,7 @@ def _get_client_handle(self, client_id: ClientId) -> ClientHandle: assert client_id == self._owner.client_id return self._owner - def _handle_transform_controls_updates( + async def _handle_transform_controls_updates( self, client_id: ClientId, message: _messages.TransformControlsUpdateMessage ) -> None: """Callback for handling transform gizmo messages.""" @@ -1560,11 +1569,14 @@ def _handle_transform_controls_updates( # Trigger callbacks. for cb in handle._impl_aux.update_cb: - cb(handle) + if asyncio.iscoroutinefunction(cb): + await cb(handle) + else: + self._thread_executor.submit(cb, handle) if handle._impl_aux.sync_cb is not None: handle._impl_aux.sync_cb(client_id, handle) - def _handle_node_click_updates( + async def _handle_node_click_updates( self, client_id: ClientId, message: _messages.SceneNodeClickMessage ) -> None: """Callback for handling click messages.""" @@ -1576,15 +1588,18 @@ def _handle_node_click_updates( client=self._get_client_handle(client_id), client_id=client_id, event="click", - target=handle, + target=cast(_ClickableSceneNodeHandle, handle), ray_origin=message.ray_origin, ray_direction=message.ray_direction, screen_pos=message.screen_pos, instance_index=message.instance_index, ) - cb(event) # type: ignore + if asyncio.iscoroutinefunction(cb): + await cb(event) + else: + self._thread_executor.submit(cb, event) - def _handle_scene_pointer_updates( + async def _handle_scene_pointer_updates( self, client_id: ClientId, message: _messages.ScenePointerMessage ): """Callback for handling click messages.""" @@ -1599,7 +1614,10 @@ def _handle_scene_pointer_updates( # Call the callback if it exists, and the after-run callback. if self._scene_pointer_cb is None: return - self._scene_pointer_cb(event) + if asyncio.iscoroutinefunction(self._scene_pointer_cb): + await self._scene_pointer_cb(event) + else: + self._thread_executor.submit(self._scene_pointer_cb, event) def on_pointer_event( self, event_type: Literal["click", "rect-select"] @@ -1654,8 +1672,8 @@ def decorator( def on_pointer_callback_removed( self, - func: Callable[[], None], - ) -> Callable[[], None]: + func: Callable[[], NoneOrCoroutine], + ) -> Callable[[], NoneOrCoroutine]: """Add a callback to run automatically when the callback for a scene pointer event is removed. This will be triggered exactly once, either manually (via :meth:`remove_pointer_callback()`) or automatically (if @@ -1690,7 +1708,10 @@ def remove_pointer_callback( self._owner.flush() # Run cleanup callback. - self._scene_pointer_done_cb() + if asyncio.iscoroutinefunction(self._scene_pointer_done_cb): + self._event_loop.create_task(self._scene_pointer_done_cb()) + else: + self._scene_pointer_done_cb() # Reset the callback and event type, on the python side. self._scene_pointer_cb = None diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index acc22df4..47929640 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -3,6 +3,7 @@ import copy import dataclasses import warnings +from collections.abc import Coroutine from functools import cached_property from typing import ( TYPE_CHECKING, @@ -132,7 +133,7 @@ class _SceneNodeHandleState: ) visible: bool = True click_cb: list[ - Callable[[SceneNodePointerEvent[_ClickableSceneNodeHandle]], None] + Callable[[SceneNodePointerEvent[_ClickableSceneNodeHandle]], None | Coroutine] ] = dataclasses.field(default_factory=list) removed: bool = False @@ -272,12 +273,22 @@ class SceneNodePointerEvent(Generic[TSceneNodeHandle]): """Instance ID of the clicked object, if applicable. Currently this is `None` for all objects except for the output of :meth:`SceneApi.add_batched_axes()`.""" +NoneOrCoroutine = TypeVar("NoneOrCoroutine", None, Coroutine) + + class _ClickableSceneNodeHandle(SceneNodeHandle): def on_click( self: Self, - func: Callable[[SceneNodePointerEvent[Self]], None], - ) -> Callable[[SceneNodePointerEvent[Self]], None]: - """Attach a callback for when a scene node is clicked.""" + func: Callable[[SceneNodePointerEvent[Self]], NoneOrCoroutine], + ) -> Callable[[SceneNodePointerEvent[Self]], NoneOrCoroutine]: + """Attach a callback for when a scene node is clicked. + + The callback can be either a standard function or an async function: + - Standard functions (def) will be executed in a threadpool. + - Async functions (async def) will be executed in the event loop. + + Using async functions may help reduce race conditions in certain scenarios. + """ self._impl.api._websock_interface.queue_message( _messages.SetSceneNodeClickableMessage(self._impl.name, True) ) @@ -285,7 +296,10 @@ def on_click( self._impl.click_cb = [] self._impl.click_cb.append( cast( - Callable[[SceneNodePointerEvent[_ClickableSceneNodeHandle]], None], func + Callable[ + [SceneNodePointerEvent[_ClickableSceneNodeHandle]], None | Coroutine + ], + func, ) ) return func @@ -528,7 +542,7 @@ class LabelHandle( @dataclasses.dataclass class _TransformControlsState: last_updated: float - update_cb: list[Callable[[TransformControlsHandle], None]] + update_cb: list[Callable[[TransformControlsHandle], None | Coroutine]] sync_cb: None | Callable[[ClientId, TransformControlsHandle], None] = None @@ -548,9 +562,16 @@ def update_timestamp(self) -> float: return self._impl_aux.last_updated def on_update( - self, func: Callable[[TransformControlsHandle], None] - ) -> Callable[[TransformControlsHandle], None]: - """Attach a callback for when the gizmo is moved.""" + self, func: Callable[[TransformControlsHandle], NoneOrCoroutine] + ) -> Callable[[TransformControlsHandle], NoneOrCoroutine]: + """Attach a callback for when the gizmo is moved. + + The callback can be either a standard function or an async function: + - Standard functions (def) will be executed in a threadpool. + - Async functions (async def) will be executed in the event loop. + + Using async functions may help reduce race conditions in certain scenarios. + """ self._impl_aux.update_cb.append(func) return func diff --git a/src/viser/_viser.py b/src/viser/_viser.py index 1ff65aa9..94fe4d18 100644 --- a/src/viser/_viser.py +++ b/src/viser/_viser.py @@ -1,13 +1,16 @@ from __future__ import annotations +import asyncio import dataclasses import io import mimetypes import threading import time import warnings +from collections.abc import Coroutine +from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager +from typing import TYPE_CHECKING, Any, Callable, ContextManager, TypeVar, cast import imageio.v3 as iio import numpy as np @@ -22,7 +25,7 @@ from . import transforms as tf from ._gui_api import Color, GuiApi, _make_uuid from ._notification_handle import NotificationHandle, _NotificationHandleState -from ._scene_api import SceneApi, cast_vector +from ._scene_api import NoneOrCoroutine, SceneApi, cast_vector from ._tunnel import ViserTunnel from .infra._infra import RecordHandle @@ -75,7 +78,7 @@ class _CameraHandleState: look_at: npt.NDArray[np.float64] up_direction: npt.NDArray[np.float64] update_timestamp: float - camera_cb: list[Callable[[CameraHandle], None]] + camera_cb: list[Callable[[CameraHandle], None | Coroutine]] class CameraHandle: @@ -249,8 +252,8 @@ def up_direction( ) def on_update( - self, callback: Callable[[CameraHandle], None] - ) -> Callable[[CameraHandle], None]: + self, callback: Callable[[CameraHandle], NoneOrCoroutine] + ) -> Callable[[CameraHandle], NoneOrCoroutine]: """Attach a callback to run when a new camera message is received.""" self._state.camera_cb.append(callback) return callback @@ -307,6 +310,9 @@ def got_render_cb( return out +NoneOrCoroutine = TypeVar("NoneOrCoroutine", None, Coroutine) + + # Don't inherit from _BackwardsCompatibilityShim during type checking, because # this will unnecessarily suppress type errors. (from the overriding of # __getattr__). @@ -330,13 +336,9 @@ def __init__( self._viser_server = server # Public attributes. - self.scene: SceneApi = SceneApi( - self, thread_executor=server._websock_server._thread_executor - ) + self.scene: SceneApi = SceneApi(self, thread_executor=server._thread_executor) """Handle for interacting with the 3D scene.""" - self.gui: GuiApi = GuiApi( - self, thread_executor=server._websock_server._thread_executor - ) + self.gui: GuiApi = GuiApi(self, thread_executor=server._thread_executor) """Handle for interacting with the GUI.""" self.client_id: int = conn.client_id """Unique ID for this client.""" @@ -494,16 +496,20 @@ def __init__( self._connection = server self._connected_clients: dict[int, ClientHandle] = {} self._client_lock = threading.Lock() - self._client_connect_cb: list[Callable[[ClientHandle], None]] = [] - self._client_disconnect_cb: list[Callable[[ClientHandle], None]] = [] + self._client_connect_cb: list[Callable[[ClientHandle], None | Coroutine]] = [] + self._client_disconnect_cb: list[ + Callable[[ClientHandle], None | Coroutine] + ] = [] + + self._thread_executor = ThreadPoolExecutor(max_workers=32) # For new clients, register and add a handler for camera messages. @server.on_client_connect - def _(conn: infra.WebsockClientConnection) -> None: + async def _(conn: infra.WebsockClientConnection) -> None: client = ClientHandle(conn, server=self) first = True - def handle_camera_message( + async def handle_camera_message( client_id: infra.ClientId, message: _messages.ViewerCameraMessage ) -> None: nonlocal first @@ -530,39 +536,58 @@ def handle_camera_message( with self._client_lock: self._connected_clients[conn.client_id] = client for cb in self._client_connect_cb: - cb(client) + if asyncio.iscoroutinefunction(cb): + await cb(client) + else: + self._thread_executor.submit(cb, client) for camera_cb in client.camera._state.camera_cb: - camera_cb(client.camera) + if asyncio.iscoroutinefunction(camera_cb): + await camera_cb(client.camera) + else: + self._thread_executor.submit(camera_cb, client.camera) conn.register_handler(_messages.ViewerCameraMessage, handle_camera_message) # Remove clients when they disconnect. @server.on_client_disconnect - def _(conn: infra.WebsockClientConnection) -> None: + async def _(conn: infra.WebsockClientConnection) -> None: with self._client_lock: if conn.client_id not in self._connected_clients: return handle = self._connected_clients.pop(conn.client_id) for cb in self._client_disconnect_cb: - cb(handle) + if asyncio.iscoroutinefunction(cb): + await cb(handle) + else: + self._thread_executor.submit(cb, handle) # Start the server. server.start() + self._event_loop = server._broadcast_buffer.event_loop - self.scene: SceneApi = SceneApi(self, thread_executor=server._thread_executor) + self.scene: SceneApi = SceneApi( + self, thread_executor=self._thread_executor, event_loop=self._event_loop + ) """Handle for interacting with the 3D scene.""" - self.gui: GuiApi = GuiApi(self, thread_executor=server._thread_executor) + self.gui: GuiApi = GuiApi( + self, thread_executor=self._thread_executor, event_loop=self._event_loop + ) """Handle for interacting with the GUI.""" server.register_handler( _messages.ShareUrlDisconnect, lambda client_id, msg: self.disconnect_share_url(), ) + + def request_share_url_no_return() -> None: # To suppress type error. + self.request_share_url() + server.register_handler( - _messages.ShareUrlRequest, lambda client_id, msg: self.request_share_url() + _messages.ShareUrlRequest, + lambda client_id, msg: cast(None, request_share_url_no_return()), ) # Form status print. @@ -686,8 +711,8 @@ def get_clients(self) -> dict[int, ClientHandle]: return self._connected_clients.copy() def on_client_connect( - self, cb: Callable[[ClientHandle], None] - ) -> Callable[[ClientHandle], None]: + self, cb: Callable[[ClientHandle], NoneOrCoroutine] + ) -> Callable[[ClientHandle], NoneOrCoroutine]: """Attach a callback to run for newly connected clients.""" with self._client_lock: clients = self._connected_clients.copy().values() @@ -702,12 +727,16 @@ def on_client_connect( # This makes sure that the the callback is applied to any clients that # connect between the two lines. for client in clients: - cb(client) - return cb + if asyncio.iscoroutinefunction(cb): + self._event_loop.create_task(cb(client)) + else: + self._thread_executor.submit(cb, client) + + return cb # type: ignore def on_client_disconnect( - self, cb: Callable[[ClientHandle], None] - ) -> Callable[[ClientHandle], None]: + self, cb: Callable[[ClientHandle], NoneOrCoroutine] + ) -> Callable[[ClientHandle], NoneOrCoroutine]: """Attach a callback to run when clients disconnect.""" self._client_disconnect_cb.append(cb) return cb @@ -743,9 +772,14 @@ def send_file_download( for client in self.get_clients().values(): client.send_file_download(filename, content, chunk_size) + def get_event_loop(self) -> asyncio.AbstractEventLoop: + """Get the asyncio event loop used by the Viser background thread. This + can be useful for safe concurrent operations.""" + return self._event_loop + def _start_scene_recording(self) -> RecordHandle: - """Start recording outgoing messages for playback or - embedding. Includes only the scene. + """Start recording outgoing messages for playback or embedding. + Includes only the scene. **Work-in-progress.** This API may be changed or removed. """ diff --git a/src/viser/infra/_async_message_buffer.py b/src/viser/infra/_async_message_buffer.py index 48879cfa..235d8709 100644 --- a/src/viser/infra/_async_message_buffer.py +++ b/src/viser/infra/_async_message_buffer.py @@ -30,6 +30,7 @@ class AsyncMessageBuffer: max_window_size: int = 128 window_duration_sec: float = 1.0 / 60.0 done: bool = False + atomic_counter: int = 0 def remove_from_buffer(self, match_fn: Callable[[Message], bool]) -> None: """Remove messages that match some condition.""" @@ -66,7 +67,19 @@ def push(self, message: Message) -> None: self.id_from_redundancy_key[redundancy_key] = new_message_id # Pulse message event to notify consumers that a new message is available. - self.event_loop.call_soon_threadsafe(self.message_event.set) + # But only do so if we're not in an atomic block. + if self.atomic_counter == 0: + self.event_loop.call_soon_threadsafe(self.message_event.set) + + def atomic_start(self) -> None: + """Start an atomic block. No new messages/windows should be sent.""" + self.atomic_counter += 1 + + def atomic_end(self) -> None: + """End an atomic block.""" + self.atomic_counter -= 1 + if self.atomic_counter == 0: + self.event_loop.call_soon_threadsafe(self.message_event.set) def flush(self) -> None: """Flush the message buffer; signals to yield a message window immediately.""" @@ -89,13 +102,15 @@ async def window_generator( are available.""" last_sent_id = -1 - flush_wait = asyncio.create_task(self.flush_event.wait()) + flush_wait = self.event_loop.create_task(self.flush_event.wait()) while not self.done: window: List[Message] = [] most_recent_message_id = self.message_counter - 1 while ( last_sent_id < most_recent_message_id and len(window) < self.max_window_size + # We should only be polling for new messages if we aren't in an atomic block. + and self.atomic_counter == 0 ): last_sent_id += 1 if self.persistent_messages: @@ -128,4 +143,4 @@ async def window_generator( del pending if flush_wait in done and not self.done: self.flush_event.clear() - flush_wait = asyncio.create_task(self.flush_event.wait()) + flush_wait = self.event_loop.create_task(self.flush_event.wait()) diff --git a/src/viser/infra/_infra.py b/src/viser/infra/_infra.py index b85a83c0..4e8c2b03 100644 --- a/src/viser/infra/_infra.py +++ b/src/viser/infra/_infra.py @@ -10,6 +10,7 @@ import queue import threading from asyncio.events import AbstractEventLoop +from collections.abc import Coroutine from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Callable, Generator, NewType, TypeVar @@ -89,12 +90,10 @@ def end_and_serialize(self) -> bytes: class WebsockMessageHandler: """Mix-in for adding message handling to a class.""" - def __init__(self, thread_executor: ThreadPoolExecutor) -> None: - self._thread_executor = thread_executor + def __init__(self) -> None: self._incoming_handlers: dict[ - type[Message], list[Callable[[ClientId, Message], None]] + type[Message], list[Callable[[ClientId, Message], None | Coroutine]] ] = {} - self._atomic_lock = threading.Lock() self._queued_messages: queue.Queue = queue.Queue() self._locked_thread_id = -1 @@ -111,7 +110,7 @@ def start_recording(self, filter: Callable[[Message], bool]) -> RecordHandle: def register_handler( self, message_cls: type[TMessage], - callback: Callable[[ClientId, TMessage], Any], + callback: Callable[[ClientId, TMessage], None | Coroutine], ) -> None: """Register a handler for a particular message type.""" if message_cls not in self._incoming_handlers: @@ -121,7 +120,7 @@ def register_handler( def unregister_handler( self, message_cls: type[TMessage], - callback: Callable[[ClientId, TMessage], Any] | None = None, + callback: Callable[[ClientId, TMessage], None | Coroutine] | None = None, ): """Unregister a handler for a particular message type.""" assert ( @@ -132,34 +131,26 @@ def unregister_handler( else: self._incoming_handlers[message_cls].remove(callback) # type: ignore - def _handle_incoming_message(self, client_id: ClientId, message: Message) -> None: + async def _handle_incoming_message( + self, client_id: ClientId, message: Message + ) -> None: """Handle incoming messages.""" if type(message) in self._incoming_handlers: for cb in self._incoming_handlers[type(message)]: - cb(client_id, message) + if asyncio.iscoroutinefunction(cb): + await cb(client_id, message) + else: + cb(client_id, message) @abc.abstractmethod def get_message_buffer(self) -> AsyncMessageBuffer: ... def queue_message(self, message: Message) -> None: - """Wrapped method for sending messages safely.""" + """Wrapped method for sending messages.""" if self._record_handle is not None: self._record_handle._insert_message(message) - got_lock = self._atomic_lock.acquire(blocking=False) - if got_lock: - self.get_message_buffer().push(message) - self._atomic_lock.release() - else: - # Send when lock is acquirable, while retaining message order. - # This could be optimized! - self._queued_messages.put(message) - - def try_again() -> None: - with self._atomic_lock: - self.get_message_buffer().push(self._queued_messages.get()) - - self._thread_executor.submit(try_again) + self.get_message_buffer().push(message) @contextlib.contextmanager def atomic(self) -> Generator[None, None, None]: @@ -174,19 +165,9 @@ def atomic(self) -> Generator[None, None, None]: Context manager. """ # If called multiple times in the same thread, we ignore inner calls. - thread_id = threading.get_ident() - if thread_id == self._locked_thread_id: - got_lock = False - else: - self._atomic_lock.acquire() - self._locked_thread_id = thread_id - got_lock = True - + self.get_message_buffer().atomic_start() yield - - if got_lock: - self._atomic_lock.release() - self._locked_thread_id = -1 + self.get_message_buffer().atomic_end() class WebsockClientConnection(WebsockMessageHandler): @@ -196,12 +177,11 @@ class WebsockClientConnection(WebsockMessageHandler): def __init__( self, client_id: int, - thread_executor: ThreadPoolExecutor, client_state: _ClientHandleState, ) -> None: self.client_id = client_id self._state = client_state - super().__init__(thread_executor) + super().__init__() @override def get_message_buffer(self) -> AsyncMessageBuffer: @@ -239,11 +219,15 @@ def __init__( verbose: bool = True, client_api_version: Literal[0, 1] = 0, ): - super().__init__(thread_executor=ThreadPoolExecutor(max_workers=32)) + super().__init__() # Track connected clients. - self._client_connect_cb: list[Callable[[WebsockClientConnection], None]] = [] - self._client_disconnect_cb: list[Callable[[WebsockClientConnection], None]] = [] + self._client_connect_cb: list[ + Callable[[WebsockClientConnection], None | Coroutine] + ] = [] + self._client_disconnect_cb: list[ + Callable[[WebsockClientConnection], None | Coroutine] + ] = [] self._host = host self._port = port @@ -278,14 +262,15 @@ def stop(self) -> None: assert self._ws_server is not None self._ws_server.close() self._ws_server = None - self._thread_executor.shutdown(wait=True) - def on_client_connect(self, cb: Callable[[WebsockClientConnection], Any]) -> None: + def on_client_connect( + self, cb: Callable[[WebsockClientConnection], None | Coroutine] + ) -> None: """Attach a callback to run for newly connected clients.""" self._client_connect_cb.append(cb) def on_client_disconnect( - self, cb: Callable[[WebsockClientConnection], Any] + self, cb: Callable[[WebsockClientConnection], None | Coroutine] ) -> None: """Attach a callback to run when clients disconnect.""" self._client_disconnect_cb.append(cb) @@ -298,7 +283,6 @@ def get_message_buffer(self) -> AsyncMessageBuffer: def flush(self) -> None: """Flush the outgoing message buffer for broadcasted messages. Any buffered messages will immediately be sent. (by default they are windowed)""" - # TODO: we should add a flush event. self._broadcast_buffer.flush() def flush_client(self, client_id: int) -> None: @@ -346,28 +330,23 @@ async def serve(websocket: WebSocketServerProtocol) -> None: AsyncMessageBuffer(event_loop, persistent_messages=False), event_loop, ) - client_connection = WebsockClientConnection( - client_id, self._thread_executor, client_state - ) + client_connection = WebsockClientConnection(client_id, client_state) self._client_state_from_id[client_id] = client_state def handle_incoming(message: Message) -> None: - self._thread_executor.submit( - error_print_wrapper( - lambda: self._handle_incoming_message(client_id, message) - ) + event_loop.create_task( + self._handle_incoming_message(client_id, message) ) - self._thread_executor.submit( - error_print_wrapper( - lambda: client_connection._handle_incoming_message( - client_id, message - ) - ) + event_loop.create_task( + client_connection._handle_incoming_message(client_id, message) ) # New connection callbacks. for cb in self._client_connect_cb: - cb(client_connection) + if asyncio.iscoroutinefunction(cb): + await cb(client_connection) + else: + cb(client_connection) try: # For each client: infinite loop over producers (which send messages) @@ -401,7 +380,10 @@ def handle_incoming(message: Message) -> None: # Disconnection callbacks. for cb in self._client_disconnect_cb: - cb(client_connection) + if asyncio.iscoroutinefunction(cb): + await cb(client_connection) + else: + cb(client_connection) # Cleanup. self._client_state_from_id.pop(client_id) From 093e61ee695466b62f15ff6a61a50d516179e5da Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 15 Oct 2024 02:05:56 -0700 Subject: [PATCH 2/4] Fixes --- examples/15_gui_in_scene.py | 2 ++ src/viser/_gui_handles.py | 2 +- src/viser/_viser.py | 10 +++++++--- src/viser/client/src/SceneTree.tsx | 6 ++---- src/viser/infra/_infra.py | 1 - 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/15_gui_in_scene.py b/examples/15_gui_in_scene.py index 233b83c9..bbbbc9a0 100644 --- a/examples/15_gui_in_scene.py +++ b/examples/15_gui_in_scene.py @@ -51,9 +51,11 @@ def _(_): f"/frame_{i}/gui" ) with displayed_3d_container: + print("hi") go_to = client.gui.add_button("Go to") randomize_orientation = client.gui.add_button("Randomize orientation") close = client.gui.add_button("Close GUI") + print("hello") @go_to.on_click def _(_) -> None: diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index dc3fb201..a2660202 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -53,7 +53,7 @@ GuiVector2Props, GuiVector3Props, ) -from ._scene_api import NoneOrCoroutine, _encode_image_binary +from ._scene_api import _encode_image_binary from .infra import ClientId if TYPE_CHECKING: diff --git a/src/viser/_viser.py b/src/viser/_viser.py index 94fe4d18..0f5effe0 100644 --- a/src/viser/_viser.py +++ b/src/viser/_viser.py @@ -25,7 +25,7 @@ from . import transforms as tf from ._gui_api import Color, GuiApi, _make_uuid from ._notification_handle import NotificationHandle, _NotificationHandleState -from ._scene_api import NoneOrCoroutine, SceneApi, cast_vector +from ._scene_api import SceneApi, cast_vector from ._tunnel import ViserTunnel from .infra._infra import RecordHandle @@ -336,9 +336,13 @@ def __init__( self._viser_server = server # Public attributes. - self.scene: SceneApi = SceneApi(self, thread_executor=server._thread_executor) + self.scene: SceneApi = SceneApi( + self, thread_executor=server._thread_executor, event_loop=server._event_loop + ) """Handle for interacting with the 3D scene.""" - self.gui: GuiApi = GuiApi(self, thread_executor=server._thread_executor) + self.gui: GuiApi = GuiApi( + self, thread_executor=server._thread_executor, event_loop=server._event_loop + ) """Handle for interacting with the GUI.""" self.client_id: int = conn.client_id """Unique ID for this client.""" diff --git a/src/viser/client/src/SceneTree.tsx b/src/viser/client/src/SceneTree.tsx index 6daa44f2..ceccfa93 100644 --- a/src/viser/client/src/SceneTree.tsx +++ b/src/viser/client/src/SceneTree.tsx @@ -370,11 +370,9 @@ function useObjectFactory(message: SceneNodeMessage | undefined): { return { makeObject: (ref) => { // We wrap with because Html doesn't implement - // THREE.Object3D. The initial position is intended to be - // off-screen; it will be overwritten with the actual position - // after the component is mounted. + // THREE.Object3D. return ( - + Date: Tue, 15 Oct 2024 13:44:47 -0700 Subject: [PATCH 3/4] Remove test prints --- examples/15_gui_in_scene.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/15_gui_in_scene.py b/examples/15_gui_in_scene.py index bbbbc9a0..233b83c9 100644 --- a/examples/15_gui_in_scene.py +++ b/examples/15_gui_in_scene.py @@ -51,11 +51,9 @@ def _(_): f"/frame_{i}/gui" ) with displayed_3d_container: - print("hi") go_to = client.gui.add_button("Go to") randomize_orientation = client.gui.add_button("Randomize orientation") close = client.gui.add_button("Close GUI") - print("hello") @go_to.on_click def _(_) -> None: From d6d1da884b786db826ede9d1f6f8737d85439c41 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 15 Oct 2024 13:47:29 -0700 Subject: [PATCH 4/4] Comments nits --- src/viser/_gui_handles.py | 8 ++++---- src/viser/_scene_handles.py | 4 ++-- src/viser/_viser.py | 27 ++++++++++++++++++++++++--- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index a2660202..4d8bbde0 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -268,7 +268,7 @@ def on_update( - If `func` is a regular function (defined with `def`), it will be executed in a thread pool. - If `func` is an async function (defined with `async def`), it will be executed in the event loop. - Using async functions may help reduce race conditions in certain scenarios. + Using async functions can be useful for reducing race conditions. """ self._impl.update_cb.append(func) return func @@ -414,7 +414,7 @@ def on_click( - If `func` is a regular function (defined with `def`), it will be executed in a thread pool. - If `func` is an async function (defined with `async def`), it will be executed in the event loop. - Using async functions may help reduce race conditions in certain scenarios. + Using async functions can be useful for reducing race conditions. """ self._impl.update_cb.append(func) return func @@ -450,7 +450,7 @@ def on_upload( - If `func` is a regular function (defined with `def`), it will be executed in a thread pool. - If `func` is an async function (defined with `async def`), it will be executed in the event loop. - Using async functions may help reduce race conditions in certain scenarios. + Using async functions can be useful for reducing race conditions. """ self._impl.update_cb.append(func) return func @@ -474,7 +474,7 @@ def on_click( - If `func` is a regular function (defined with `def`), it will be executed in a thread pool. - If `func` is an async function (defined with `async def`), it will be executed in the event loop. - Using async functions may help reduce race conditions in certain scenarios. + Using async functions can be useful for reducing race conditions. """ self._impl.update_cb.append(func) return func diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index 47929640..6ab49f47 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -287,7 +287,7 @@ def on_click( - Standard functions (def) will be executed in a threadpool. - Async functions (async def) will be executed in the event loop. - Using async functions may help reduce race conditions in certain scenarios. + Using async functions can be useful for reducing race conditions. """ self._impl.api._websock_interface.queue_message( _messages.SetSceneNodeClickableMessage(self._impl.name, True) @@ -570,7 +570,7 @@ def on_update( - Standard functions (def) will be executed in a threadpool. - Async functions (async def) will be executed in the event loop. - Using async functions may help reduce race conditions in certain scenarios. + Using async functions can be useful for reducing race conditions. """ self._impl_aux.update_cb.append(func) return func diff --git a/src/viser/_viser.py b/src/viser/_viser.py index 0f5effe0..cb14daea 100644 --- a/src/viser/_viser.py +++ b/src/viser/_viser.py @@ -254,7 +254,14 @@ def up_direction( def on_update( self, callback: Callable[[CameraHandle], NoneOrCoroutine] ) -> Callable[[CameraHandle], NoneOrCoroutine]: - """Attach a callback to run when a new camera message is received.""" + """Attach a callback to run when a new camera message is received. + + The callback can be either a standard function or an async function: + - Standard functions (def) will be executed in a threadpool. + - Async functions (async def) will be executed in the event loop. + + Using async functions can be useful for reducing race conditions. + """ self._state.camera_cb.append(callback) return callback @@ -717,7 +724,14 @@ def get_clients(self) -> dict[int, ClientHandle]: def on_client_connect( self, cb: Callable[[ClientHandle], NoneOrCoroutine] ) -> Callable[[ClientHandle], NoneOrCoroutine]: - """Attach a callback to run for newly connected clients.""" + """Attach a callback to run for newly connected clients. + + The callback can be either a standard function or an async function: + - Standard functions (def) will be executed in a threadpool. + - Async functions (async def) will be executed in the event loop. + + Using async functions can be useful for reducing race conditions. + """ with self._client_lock: clients = self._connected_clients.copy().values() self._client_connect_cb.append(cb) @@ -741,7 +755,14 @@ def on_client_connect( def on_client_disconnect( self, cb: Callable[[ClientHandle], NoneOrCoroutine] ) -> Callable[[ClientHandle], NoneOrCoroutine]: - """Attach a callback to run when clients disconnect.""" + """Attach a callback to run when clients disconnect. + + The callback can be either a standard function or an async function: + - Standard functions (def) will be executed in a threadpool. + - Async functions (async def) will be executed in the event loop. + + Using async functions can be useful for reducing race conditions. + """ self._client_disconnect_cb.append(cb) return cb