Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concurrency refactor, async callback support #304

Merged
merged 4 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/viser/_gui_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import asyncio
import builtins
import colorsys
import dataclasses
import functools
import threading
import time
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -183,12 +185,14 @@ def __init__(
self,
owner: ViserServer | ClientHandle, # Who do I belong to?
thread_executor: ThreadPoolExecutor,
event_loop: AbstractEventLoop,
) -> None:
from ._viser import ViserServer

self._owner = owner
"""Entity that owns this API."""
self._thread_executor = thread_executor
self._event_loop = event_loop

self._websock_interface = (
owner._websock_server
Expand Down Expand Up @@ -217,7 +221,7 @@ def __init__(
self._handle_file_transfer_part,
)

def _handle_gui_updates(
async def _handle_gui_updates(
self, client_id: ClientId, message: _messages.GuiUpdateMessage
) -> None:
"""Callback for handling GUI messages."""
Expand Down Expand Up @@ -273,7 +277,10 @@ def _handle_gui_updates(
else:
assert False

cb(GuiEvent(client, client_id, handle))
if asyncio.iscoroutinefunction(cb):
self._event_loop.create_task(cb(GuiEvent(client, client_id, handle)))
else:
self._thread_executor.submit(cb, GuiEvent(client, client_id, handle))

if handle_state.sync_cb is not None:
handle_state.sync_cb(client_id, updates_cast)
Expand Down Expand Up @@ -355,7 +362,10 @@ def _handle_file_transfer_part(
else:
assert False

cb(GuiEvent(client, client_id, handle))
if asyncio.iscoroutinefunction(cb):
self._event_loop.create_task(cb(GuiEvent(client, client_id, handle)))
else:
self._thread_executor.submit(cb, GuiEvent(client, client_id, handle))

def _get_container_uuid(self) -> str:
"""Get container ID associated with the current thread."""
Expand Down
73 changes: 52 additions & 21 deletions src/viser/_gui_handles.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import asyncio
import base64
import dataclasses
import re
import time
import uuid
import warnings
from collections.abc import Coroutine
from functools import cached_property
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -63,6 +65,7 @@

T = TypeVar("T")
TGuiHandle = TypeVar("TGuiHandle", bound="_GuiInputHandle")
NoneOrCoroutine = TypeVar("NoneOrCoroutine", None, Coroutine)


def _make_uuid() -> str:
Expand Down Expand Up @@ -96,7 +99,7 @@ class _GuiHandleState(Generic[T]):
"""Container that this GUI input was placed into."""

update_timestamp: float = 0.0
update_cb: list[Callable[[GuiEvent], None]] = dataclasses.field(
update_cb: list[Callable[[GuiEvent], None | Coroutine]] = dataclasses.field(
default_factory=list
)
"""Registered functions to call when this input is updated."""
Expand Down Expand Up @@ -220,17 +223,19 @@ def value(self, value: T | np.ndarray) -> None:

# Call update callbacks.
for cb in self._impl.update_cb:
# Pushing callbacks into separate threads helps prevent deadlocks when we
# have a lock in a callback. TODO: revisit other callbacks.
self._impl.gui_api._thread_executor.submit(
lambda: cb(
if asyncio.iscoroutinefunction(cb):
self._impl.gui_api._event_loop.create_task(
cb(GuiEvent(client_id=None, client=None, target=self))
)
else:
self._impl.gui_api._thread_executor.submit(
cb,
GuiEvent(
client_id=None,
client=None,
target=self,
)
),
)
)

@property
def update_timestamp(self) -> float:
Expand All @@ -255,11 +260,16 @@ class GuiInputHandle(_GuiInputHandle[T], Generic[T]):
"""

def on_update(
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], Any]
) -> Callable[[GuiEvent[TGuiHandle]], None]:
"""Attach a function to call when a GUI input is updated. Callbacks stack (need
to be manually removed via :meth:`remove_update_callback()`) and will be called
from a thread."""
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]
) -> Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]:
"""Attach a function to call when a GUI input is updated.

Note:
- If `func` is a regular function (defined with `def`), it will be executed in a thread pool.
- If `func` is an async function (defined with `async def`), it will be executed in the event loop.

Using async functions can be useful for reducing race conditions.
"""
self._impl.update_cb.append(func)
return func

Expand Down Expand Up @@ -396,9 +406,16 @@ class GuiButtonHandle(_GuiInputHandle[bool]):
"""

def on_click(
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None]
) -> Callable[[GuiEvent[TGuiHandle]], None]:
"""Attach a function to call when a button is pressed. Happens in a thread."""
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]
) -> Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]:
"""Attach a function to call when a button is pressed.

Note:
- If `func` is a regular function (defined with `def`), it will be executed in a thread pool.
- If `func` is an async function (defined with `async def`), it will be executed in the event loop.

Using async functions can be useful for reducing race conditions.
"""
self._impl.update_cb.append(func)
return func

Expand All @@ -425,9 +442,16 @@ class GuiUploadButtonHandle(_GuiInputHandle[UploadedFile]):
"""

def on_upload(
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None]
) -> Callable[[GuiEvent[TGuiHandle]], None]:
"""Attach a function to call when a button is pressed. Happens in a thread."""
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]
) -> Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]:
"""Attach a function to call when a file is uploaded.

Note:
- If `func` is a regular function (defined with `def`), it will be executed in a thread pool.
- If `func` is an async function (defined with `async def`), it will be executed in the event loop.

Using async functions can be useful for reducing race conditions.
"""
self._impl.update_cb.append(func)
return func

Expand All @@ -442,9 +466,16 @@ class GuiButtonGroupHandle(_GuiInputHandle[str], GuiButtonGroupProps):
"""

def on_click(
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None]
) -> Callable[[GuiEvent[TGuiHandle]], None]:
"""Attach a function to call when a button is pressed. Happens in a thread."""
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]
) -> Callable[[GuiEvent[TGuiHandle]], NoneOrCoroutine]:
"""Attach a function to call when a button in the group is clicked.

Note:
- If `func` is a regular function (defined with `def`), it will be executed in a thread pool.
- If `func` is an async function (defined with `async def`), it will be executed in the event loop.

Using async functions can be useful for reducing race conditions.
"""
self._impl.update_cb.append(func)
return func

Expand Down
49 changes: 35 additions & 14 deletions src/viser/_scene_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import asyncio
import io
import time
import warnings
from collections.abc import Coroutine
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Callable, Tuple, TypeVar, Union, cast, get_args

Expand Down Expand Up @@ -39,6 +41,7 @@
SplineCubicBezierHandle,
SpotLightHandle,
TransformControlsHandle,
_ClickableSceneNodeHandle,
_TransformControlsState,
colors_to_uint8,
)
Expand All @@ -57,6 +60,8 @@
Tuple[int, int, int], Tuple[float, float, float], np.ndarray
]

NoneOrCoroutine = TypeVar("NoneOrCoroutine", None, Coroutine)


def _encode_rgb(rgb: RgbTupleOrArray) -> tuple[int, int, int]:
if isinstance(rgb, np.ndarray):
Expand Down Expand Up @@ -115,9 +120,13 @@ def __init__(
self,
owner: ViserServer | ClientHandle, # Who do I belong to?
thread_executor: ThreadPoolExecutor,
event_loop: asyncio.AbstractEventLoop,
) -> None:
from ._viser import ViserServer

self._thread_executor = thread_executor
self._event_loop = event_loop

self._owner = owner
"""Entity that owns this API."""

Expand All @@ -133,8 +142,10 @@ def __init__(
] = {}
self._handle_from_node_name: dict[str, SceneNodeHandle] = {}

self._scene_pointer_cb: Callable[[ScenePointerEvent], None] | None = None
self._scene_pointer_done_cb: Callable[[], None] = lambda: None
self._scene_pointer_cb: (
Callable[[ScenePointerEvent], None | Coroutine] | None
) = None
self._scene_pointer_done_cb: Callable[[], None | Coroutine] = lambda: None
self._scene_pointer_event_type: _messages.ScenePointerEventType | None = None

# Set up world axes handle.
Expand All @@ -159,8 +170,6 @@ def __init__(
self._handle_scene_pointer_updates,
)

self._thread_executor = thread_executor

def set_up_direction(
self,
direction: Literal["+x", "+y", "+z", "-x", "-y", "-z"]
Expand Down Expand Up @@ -1543,7 +1552,7 @@ def _get_client_handle(self, client_id: ClientId) -> ClientHandle:
assert client_id == self._owner.client_id
return self._owner

def _handle_transform_controls_updates(
async def _handle_transform_controls_updates(
self, client_id: ClientId, message: _messages.TransformControlsUpdateMessage
) -> None:
"""Callback for handling transform gizmo messages."""
Expand All @@ -1560,11 +1569,14 @@ def _handle_transform_controls_updates(

# Trigger callbacks.
for cb in handle._impl_aux.update_cb:
cb(handle)
if asyncio.iscoroutinefunction(cb):
await cb(handle)
else:
self._thread_executor.submit(cb, handle)
if handle._impl_aux.sync_cb is not None:
handle._impl_aux.sync_cb(client_id, handle)

def _handle_node_click_updates(
async def _handle_node_click_updates(
self, client_id: ClientId, message: _messages.SceneNodeClickMessage
) -> None:
"""Callback for handling click messages."""
Expand All @@ -1576,15 +1588,18 @@ def _handle_node_click_updates(
client=self._get_client_handle(client_id),
client_id=client_id,
event="click",
target=handle,
target=cast(_ClickableSceneNodeHandle, handle),
ray_origin=message.ray_origin,
ray_direction=message.ray_direction,
screen_pos=message.screen_pos,
instance_index=message.instance_index,
)
cb(event) # type: ignore
if asyncio.iscoroutinefunction(cb):
await cb(event)
else:
self._thread_executor.submit(cb, event)

def _handle_scene_pointer_updates(
async def _handle_scene_pointer_updates(
self, client_id: ClientId, message: _messages.ScenePointerMessage
):
"""Callback for handling click messages."""
Expand All @@ -1599,7 +1614,10 @@ def _handle_scene_pointer_updates(
# Call the callback if it exists, and the after-run callback.
if self._scene_pointer_cb is None:
return
self._scene_pointer_cb(event)
if asyncio.iscoroutinefunction(self._scene_pointer_cb):
await self._scene_pointer_cb(event)
else:
self._thread_executor.submit(self._scene_pointer_cb, event)

def on_pointer_event(
self, event_type: Literal["click", "rect-select"]
Expand Down Expand Up @@ -1654,8 +1672,8 @@ def decorator(

def on_pointer_callback_removed(
self,
func: Callable[[], None],
) -> Callable[[], None]:
func: Callable[[], NoneOrCoroutine],
) -> Callable[[], NoneOrCoroutine]:
"""Add a callback to run automatically when the callback for a scene
pointer event is removed. This will be triggered exactly once, either
manually (via :meth:`remove_pointer_callback()`) or automatically (if
Expand Down Expand Up @@ -1690,7 +1708,10 @@ def remove_pointer_callback(
self._owner.flush()

# Run cleanup callback.
self._scene_pointer_done_cb()
if asyncio.iscoroutinefunction(self._scene_pointer_done_cb):
self._event_loop.create_task(self._scene_pointer_done_cb())
else:
self._scene_pointer_done_cb()

# Reset the callback and event type, on the python side.
self._scene_pointer_cb = None
Expand Down
Loading
Loading