From 42f3e10339a676f80f344f9046cc9266583b58a2 Mon Sep 17 00:00:00 2001 From: brentyi Date: Mon, 23 Sep 2024 22:58:24 -0700 Subject: [PATCH 01/15] Scene handle props cleanup --- src/viser/_scene_handles.py | 96 ++++++++++++++----------------------- 1 file changed, 36 insertions(+), 60 deletions(-) diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index 70505c687..e9d93738f 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -2,11 +2,11 @@ import copy import dataclasses +from functools import cached_property from typing import ( TYPE_CHECKING, Any, Callable, - ClassVar, Dict, Generic, Literal, @@ -40,46 +40,43 @@ def colors_to_uint8(colors: onp.ndarray) -> onpt.NDArray[onp.uint8]: return colors -class _OverridablePropScenePropSettersAndGetters: - def __setattr__(self, name: str, value: Any) -> None: - handle = cast(SceneNodeHandle, self) - # Get the value of the T TypeVar. - if name in self._PropHints: - # Help the user with some casting... - hint = self._PropHints[name] - if hint == onpt.NDArray[onp.float32]: - value = value.astype(onp.float32) - elif hint == onpt.NDArray[onp.uint8] and "color" in name: - value = colors_to_uint8(value) - - setattr(handle._impl.props, name, value) - handle._impl.api._websock_interface.queue_message( - _messages.SceneNodeUpdateMessage(handle.name, {name: value}) - ) - else: - return object.__setattr__(self, name, value) - - def __getattr__(self, name: str) -> Any: - if name in self._PropHints: - return getattr(self._impl.props, name) - else: - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) - - -class _OverridableScenePropApi( - _OverridablePropScenePropSettersAndGetters if not TYPE_CHECKING else object -): +class _OverridableScenePropApi: """Mixin that allows reading/assigning properties defined in each scene node message.""" - _PropHints: ClassVar[Dict[str, type]] - - def __init__(self) -> None: - assert False - - def __init_subclass__(cls, PropClass: type): - cls._PropHints = get_type_hints(PropClass) + if not TYPE_CHECKING: + + def __setattr__(self, name: str, value: Any) -> None: + if name == "_impl": + return object.__setattr__(self, name, value) + + handle = cast(SceneNodeHandle, self) + # Get the value of the T TypeVar. + if name in self._prop_hints: + # Help the user with some casting... + hint = self._prop_hints[name] + if hint == onpt.NDArray[onp.float32]: + value = value.astype(onp.float32) + elif hint == onpt.NDArray[onp.uint8] and "color" in name: + value = colors_to_uint8(value) + + setattr(handle._impl.props, name, value) + handle._impl.api._websock_interface.queue_message( + _messages.SceneNodeUpdateMessage(handle.name, {name: value}) + ) + else: + return object.__setattr__(self, name, value) + + def __getattr__(self, name: str) -> Any: + if name in self._prop_hints: + return getattr(self._impl.props, name) + else: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + @cached_property + def _prop_hints(self) -> Dict[str, Any]: + return get_type_hints(type(self._impl.props)) @dataclasses.dataclass(frozen=True) @@ -271,7 +268,6 @@ class CameraFrustumHandle( _ClickableSceneNodeHandle, _messages.CameraFrustumProps, _OverridableScenePropApi, - PropClass=_messages.CameraFrustumProps, ): """Handle for camera frustums.""" @@ -280,7 +276,6 @@ class DirectionalLightHandle( SceneNodeHandle, _messages.DirectionalLightProps, _OverridableScenePropApi, - PropClass=_messages.DirectionalLightProps, ): """Handle for directional lights.""" @@ -289,7 +284,6 @@ class AmbientLightHandle( SceneNodeHandle, _messages.AmbientLightProps, _OverridableScenePropApi, - PropClass=_messages.AmbientLightProps, ): """Handle for ambient lights.""" @@ -298,7 +292,6 @@ class HemisphereLightHandle( SceneNodeHandle, _messages.HemisphereLightProps, _OverridableScenePropApi, - PropClass=_messages.HemisphereLightProps, ): """Handle for hemisphere lights.""" @@ -307,7 +300,6 @@ class PointLightHandle( SceneNodeHandle, _messages.PointLightProps, _OverridableScenePropApi, - PropClass=_messages.PointLightProps, ): """Handle for point lights.""" @@ -316,7 +308,6 @@ class RectAreaLightHandle( SceneNodeHandle, _messages.RectAreaLightProps, _OverridableScenePropApi, - PropClass=_messages.RectAreaLightProps, ): """Handle for rectangular area lights.""" @@ -325,7 +316,6 @@ class SpotLightHandle( SceneNodeHandle, _messages.SpotLightProps, _OverridableScenePropApi, - PropClass=_messages.SpotLightProps, ): """Handle for spot lights.""" @@ -334,7 +324,6 @@ class PointCloudHandle( SceneNodeHandle, _messages.PointCloudProps, _OverridableScenePropApi, - PropClass=_messages.PointCloudProps, ): """Handle for point clouds. Does not support click events.""" @@ -343,7 +332,6 @@ class BatchedAxesHandle( _ClickableSceneNodeHandle, _messages.BatchedAxesProps, _OverridableScenePropApi, - PropClass=_messages.BatchedAxesProps, ): """Handle for batched coordinate frames.""" @@ -352,7 +340,6 @@ class FrameHandle( _ClickableSceneNodeHandle, _messages.FrameProps, _OverridableScenePropApi, - PropClass=_messages.FrameProps, ): """Handle for coordinate frames.""" @@ -361,7 +348,6 @@ class MeshHandle( _ClickableSceneNodeHandle, _messages.MeshProps, _OverridableScenePropApi, - PropClass=_messages.MeshProps, ): """Handle for mesh objects.""" @@ -370,7 +356,6 @@ class GaussianSplatHandle( _ClickableSceneNodeHandle, _messages.GaussianSplatsProps, _OverridableScenePropApi, - PropClass=_messages.GaussianSplatsProps, ): """Handle for Gaussian splatting objects. @@ -382,7 +367,6 @@ class MeshSkinnedHandle( _ClickableSceneNodeHandle, _messages.SkinnedMeshProps, _OverridableScenePropApi, - PropClass=_messages.SkinnedMeshProps, ): """Handle for skinned mesh objects.""" @@ -451,7 +435,6 @@ class GridHandle( SceneNodeHandle, _messages.GridProps, _OverridableScenePropApi, - PropClass=_messages.GridProps, ): """Handle for grid objects.""" @@ -460,7 +443,6 @@ class SplineCatmullRomHandle( SceneNodeHandle, _messages.CatmullRomSplineProps, _OverridableScenePropApi, - PropClass=_messages.CatmullRomSplineProps, ): """Handle for Catmull-Rom splines.""" @@ -469,7 +451,6 @@ class SplineCubicBezierHandle( SceneNodeHandle, _messages.CubicBezierSplineProps, _OverridableScenePropApi, - PropClass=_messages.CubicBezierSplineProps, ): """Handle for cubic Bezier splines.""" @@ -478,7 +459,6 @@ class GlbHandle( _ClickableSceneNodeHandle, _messages.GlbProps, _OverridableScenePropApi, - PropClass=_messages.GlbProps, ): """Handle for GLB objects.""" @@ -487,7 +467,6 @@ class ImageHandle( _ClickableSceneNodeHandle, _messages.ImageProps, _OverridableScenePropApi, - PropClass=_messages.ImageProps, ): """Handle for 2D images, rendered in 3D.""" @@ -496,7 +475,6 @@ class LabelHandle( SceneNodeHandle, _messages.LabelProps, _OverridableScenePropApi, - PropClass=_messages.LabelProps, ): """Handle for 2D label objects. Does not support click events.""" @@ -512,7 +490,6 @@ class TransformControlsHandle( _ClickableSceneNodeHandle, _messages.TransformControlsProps, _OverridableScenePropApi, - PropClass=_messages.TransformControlsProps, ): """Handle for interacting with transform control gizmos.""" @@ -536,7 +513,6 @@ class Gui3dContainerHandle( SceneNodeHandle, _messages.Gui3DProps, _OverridableScenePropApi, - PropClass=_messages.Gui3DProps, ): """Use as a context to place GUI elements into a 3D GUI container.""" From aac0788a3b6be27aed14f671d1de53b83a65ee4c Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 00:11:47 -0700 Subject: [PATCH 02/15] Start gui refactor --- src/viser/_gui_api.py | 482 +++++----- src/viser/_gui_handles.py | 448 +++------- src/viser/_messages.py | 210 ++++- src/viser/_scene_handles.py | 108 ++- .../client/src/ControlPanel/Generated.tsx | 36 +- .../client/src/ControlPanel/GuiState.tsx | 40 +- src/viser/client/src/MessageHandler.tsx | 4 +- src/viser/client/src/WebsocketMessages.ts | 827 +++++++++++++----- src/viser/client/src/components/Button.tsx | 9 +- .../client/src/components/ButtonGroup.tsx | 10 +- src/viser/client/src/components/Checkbox.tsx | 9 +- src/viser/client/src/components/Dropdown.tsx | 10 +- src/viser/client/src/components/Folder.tsx | 8 +- src/viser/client/src/components/Markdown.tsx | 7 +- .../client/src/components/MultiSlider.tsx | 24 +- .../client/src/components/NumberInput.tsx | 11 +- .../client/src/components/PlotlyComponent.tsx | 4 +- .../client/src/components/ProgressBar.tsx | 4 +- src/viser/client/src/components/Rgb.tsx | 9 +- src/viser/client/src/components/Rgba.tsx | 9 +- src/viser/client/src/components/Slider.tsx | 11 +- src/viser/client/src/components/TabGroup.tsx | 9 +- src/viser/client/src/components/TextInput.tsx | 9 +- .../client/src/components/UploadButton.tsx | 26 +- src/viser/client/src/components/Vector2.tsx | 11 +- src/viser/client/src/components/Vector3.tsx | 11 +- src/viser/infra/_typescript_interface_gen.py | 2 + 27 files changed, 1347 insertions(+), 1001 deletions(-) diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index dea8d468e..4a2781b17 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -8,7 +8,16 @@ import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import TYPE_CHECKING, Any, Sequence, Tuple, TypeVar, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Protocol, + Sequence, + Tuple, + TypeVar, + cast, + overload, +) import numpy as onp from typing_extensions import ( @@ -44,7 +53,7 @@ ) from ._icons import svg_from_icon from ._icons_enum import IconName -from ._messages import FileTransferPartAck, GuiSliderMark +from ._messages import FileTransferPartAck, GuiBaseProps, GuiSliderMark from ._scene_api import cast_vector if TYPE_CHECKING: @@ -53,7 +62,7 @@ from ._viser import ClientHandle, ViserServer from .infra import ClientId - +GuiInputPropsType = TypeVar("GuiInputPropsType", bound=GuiBaseProps) IntOrFloat = TypeVar("IntOrFloat", int, float) TString = TypeVar("TString", bound=str) TLiteralString = TypeVar("TLiteralString", bound=LiteralString) @@ -451,21 +460,31 @@ def add_folder( """ folder_container_id = _make_unique_id() order = _apply_default_order(order) + props = _messages.GuiFolderProps( + order=order, + label=label, + expand_by_default=expand_by_default, + visible=visible, + ) self._websock_interface.queue_message( - _messages.GuiAddFolderMessage( - order=order, + _messages.GuiFolderMessage( id=folder_container_id, - label=label, container_id=self._get_container_id(), - expand_by_default=expand_by_default, - visible=visible, + props=props, ) ) return GuiFolderHandle( - _gui_api=self, - _id=folder_container_id, - _parent_container_id=self._get_container_id(), - _order=order, + _GuiHandleState( + self, + None, + props=props, + update_timestamp=0.0, + update_cb=[], + is_button=False, + sync_cb=None, + id=folder_container_id, + parent_container_id=self._get_container_id(), + ) ) def add_modal( @@ -515,14 +534,16 @@ def add_tab_group( order = _apply_default_order(order) self._websock_interface.queue_message( - _messages.GuiAddTabGroupMessage( - order=order, + _messages.GuiTabGroupMessage( id=tab_group_id, container_id=self._get_container_id(), - tab_labels=(), - visible=visible, - tab_icons_html=(), - tab_container_ids=(), + props=_messages.GuiTabGroupProps( + order=order, + tab_labels=(), + visible=visible, + tab_icons_html=(), + tab_container_ids=(), + ), ) ) return GuiTabGroupHandle( @@ -553,23 +574,31 @@ def add_markdown( Returns: A handle that can be used to interact with the GUI element. """ - handle = GuiMarkdownHandle( - _gui_api=self, - _id=_make_unique_id(), - _visible=visible, - _parent_container_id=self._get_container_id(), - _order=_apply_default_order(order), - _image_root=image_root, - _content=None, - ) - self._websock_interface.queue_message( - _messages.GuiAddMarkdownMessage( - order=handle._order, - id=handle._id, + message = _messages.GuiMarkdownMessage( + id=_make_unique_id(), + container_id=self._get_container_id(), + props=_messages.GuiMarkdownProps( + order=_apply_default_order(order), markdown="", - container_id=handle._parent_container_id, visible=visible, - ) + ), + ) + self._websock_interface.queue_message(message) + + handle = GuiMarkdownHandle( + _GuiHandleState( + self, + None, + props=message.props, + update_timestamp=0.0, + update_cb=[], + is_button=False, + sync_cb=None, + id=message.id, + parent_container_id=message.container_id, + ), + _content=content, + _image_root=image_root, ) # Logic for processing markdown, handling images, etc is all in the @@ -596,15 +625,6 @@ def add_plotly( Returns: A handle that can be used to interact with the GUI element. """ - handle = GuiPlotlyHandle( - _gui_api=self, - _id=_make_unique_id(), - _visible=visible, - _parent_container_id=self._get_container_id(), - _order=_apply_default_order(order), - _figure=None, - _aspect=None, - ) # If plotly.min.js hasn't been sent to the client yet, the client won't be able # to render the plot. Send this large file now! (~3MB) @@ -636,21 +656,36 @@ def add_plotly( # After plotly.min.js has been sent, we can send the plotly figure. # Empty string for `plotly_json_str` is a signal to the client to render nothing. - self._websock_interface.queue_message( - _messages.GuiAddPlotlyMessage( - order=handle._order, - id=handle._id, + message = _messages.GuiPlotlyMessage( + id=_make_unique_id(), + container_id=self._get_container_id(), + props=_messages.GuiPlotlyProps( + order=_apply_default_order(order), plotly_json_str="", aspect=1.0, - container_id=handle._parent_container_id, visible=visible, - ) + ), + ) + self._websock_interface.queue_message(message) + + handle = GuiPlotlyHandle( + _GuiHandleState( + self, + None, + props=message.props, + update_timestamp=0.0, + update_cb=[], + is_button=False, + sync_cb=None, + id=message.id, + parent_container_id=message.container_id, + ), + _figure=figure, ) # Set the plotly handle properties. handle.figure = figure handle.aspect = aspect - return handle def add_button( @@ -685,17 +720,19 @@ def add_button( return GuiButtonHandle( self._create_gui_input( value=False, - message=_messages.GuiAddButtonMessage( - order=order, + message=_messages.GuiButtonMessage( + value=False, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - value=False, - color=color, - icon_html=None if icon is None else svg_from_icon(icon), - disabled=disabled, - visible=visible, + props=_messages.GuiButtonProps( + order=order, + label=label, + hint=hint, + color=color, + icon_html=None if icon is None else svg_from_icon(icon), + disabled=disabled, + visible=visible, + ), ), is_button=True, )._impl @@ -735,18 +772,19 @@ def add_upload_button( return GuiUploadButtonHandle( self._create_gui_input( value=UploadedFile("", b""), - message=_messages.GuiAddUploadButtonMessage( - value=None, - disabled=disabled, - visible=visible, - order=order, + message=_messages.GuiUploadButtonMessage( id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - color=color, - mime_type=mime_type, - icon_html=None if icon is None else svg_from_icon(icon), + props=_messages.GuiUploadButtonProps( + disabled=disabled, + visible=visible, + order=order, + label=label, + hint=hint, + color=color, + mime_type=mime_type, + icon_html=None if icon is None else svg_from_icon(icon), + ), ), is_button=True, )._impl @@ -807,16 +845,18 @@ def add_button_group( return GuiButtonGroupHandle( self._create_gui_input( value, - message=_messages.GuiAddButtonGroupMessage( - order=order, + message=_messages.GuiButtonGroupMessage( + value=value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - value=value, - options=tuple(options), - disabled=disabled, - visible=visible, + props=_messages.GuiButtonGroupProps( + order=order, + label=label, + hint=hint, + options=tuple(options), + disabled=disabled, + visible=visible, + ), ), )._impl, ) @@ -849,15 +889,17 @@ def add_checkbox( order = _apply_default_order(order) return self._create_gui_input( value, - message=_messages.GuiAddCheckboxMessage( - order=order, + message=_messages.GuiCheckboxMessage( + value=value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - value=value, - disabled=disabled, - visible=visible, + props=_messages.GuiCheckboxProps( + order=order, + label=label, + hint=hint, + disabled=disabled, + visible=visible, + ), ), ) @@ -889,15 +931,17 @@ def add_text( order = _apply_default_order(order) return self._create_gui_input( value, - message=_messages.GuiAddTextMessage( - order=order, + message=_messages.GuiTextMessage( + value=value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - value=value, - disabled=disabled, - visible=visible, + props=_messages.GuiTextProps( + order=order, + label=label, + hint=hint, + disabled=disabled, + visible=visible, + ), ), ) @@ -953,19 +997,21 @@ def add_number( order = _apply_default_order(order) return self._create_gui_input( value, - message=_messages.GuiAddNumberMessage( - order=order, + message=_messages.GuiNumberMessage( + value=value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - value=value, - min=min, - max=max, - precision=_compute_precision_digits(step), - step=step, - disabled=disabled, - visible=visible, + props=_messages.GuiNumberProps( + order=order, + label=label, + hint=hint, + min=min, + max=max, + precision=_compute_precision_digits(step), + step=step, + disabled=disabled, + visible=visible, + ), ), is_button=False, ) @@ -1016,19 +1062,21 @@ def add_vector2( return self._create_gui_input( value, - message=_messages.GuiAddVector2Message( - order=order, + message=_messages.GuiVector2Message( + value=value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - value=value, - min=min, - max=max, - step=step, - precision=_compute_precision_digits(step), - disabled=disabled, - visible=visible, + props=_messages.GuiVector2Props( + order=order, + label=label, + hint=hint, + min=min, + max=max, + step=step, + precision=_compute_precision_digits(step), + disabled=disabled, + visible=visible, + ), ), ) @@ -1078,19 +1126,21 @@ def add_vector3( return self._create_gui_input( value, - message=_messages.GuiAddVector3Message( - order=order, + message=_messages.GuiVector3Message( + value=value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - value=value, - min=min, - max=max, - step=step, - precision=_compute_precision_digits(step), - disabled=disabled, - visible=visible, + props=_messages.GuiVector3Props( + order=order, + label=label, + hint=hint, + min=min, + max=max, + step=step, + precision=_compute_precision_digits(step), + disabled=disabled, + visible=visible, + ), ), ) @@ -1151,19 +1201,20 @@ def add_dropdown( return GuiDropdownHandle( self._create_gui_input( value, - message=_messages.GuiAddDropdownMessage( - order=order, + message=_messages.GuiDropdownMessage( + value=value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - value=value, - options=tuple(options), - disabled=disabled, - visible=visible, + props=_messages.GuiDropdownProps( + order=order, + label=label, + hint=hint, + options=tuple(options), + disabled=disabled, + visible=visible, + ), ), )._impl, - _impl_options=tuple(options), ) def add_progress_bar( @@ -1187,25 +1238,30 @@ def add_progress_bar( A handle that can be used to interact with the GUI element. """ assert value >= 0 and value <= 100 - handle = GuiProgressBarHandle( - _gui_api=self, - _id=_make_unique_id(), - _visible=visible, - _animated=animated, - _parent_container_id=self._get_container_id(), - _order=_apply_default_order(order), - _value=value, - ) - self._websock_interface.queue_message( - _messages.GuiAddProgressBarMessage( - order=handle._order, - id=handle._id, - value=value, + message = _messages.GuiProgressBarMessage( + value=value, + id=_make_unique_id(), + container_id=self._get_container_id(), + props=_messages.GuiProgressBarProps( + order=_apply_default_order(order), animated=animated, color=color, - container_id=handle._parent_container_id, visible=visible, - ) + ), + ) + self._websock_interface.queue_message(message) + handle = GuiProgressBarHandle( + _GuiHandleState( + self, + value, + props=message.props, + update_timestamp=0.0, + update_cb=[], + is_button=False, + sync_cb=None, + id=message.id, + parent_container_id=message.container_id, + ), ) return handle @@ -1264,27 +1320,29 @@ def add_slider( order = _apply_default_order(order) return self._create_gui_input( value, - message=_messages.GuiAddSliderMessage( - order=order, + message=_messages.GuiSliderMessage( + value=value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - min=min, - max=max, - step=step, - value=value, - precision=_compute_precision_digits(step), - visible=visible, - disabled=disabled, - marks=tuple( - GuiSliderMark(value=float(x[0]), label=x[1]) - if isinstance(x, tuple) - else GuiSliderMark(value=x, label=None) - for x in marks - ) - if marks is not None - else None, + props=_messages.GuiSliderProps( + order=order, + label=label, + hint=hint, + min=min, + max=max, + step=step, + precision=_compute_precision_digits(step), + visible=visible, + disabled=disabled, + marks=tuple( + GuiSliderMark(value=float(x[0]), label=x[1]) + if isinstance(x, tuple) + else GuiSliderMark(value=x, label=None) + for x in marks + ) + if marks is not None + else None, + ), ), is_button=False, ) @@ -1345,29 +1403,31 @@ def add_multi_slider( order = _apply_default_order(order) return self._create_gui_input( value=initial_value, - message=_messages.GuiAddMultiSliderMessage( - order=order, + message=_messages.GuiMultiSliderMessage( + value=initial_value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - min=min, - min_range=min_range, - max=max, - step=step, - value=initial_value, - visible=visible, - disabled=disabled, - fixed_endpoints=fixed_endpoints, - precision=_compute_precision_digits(step), - marks=tuple( - GuiSliderMark(value=float(x[0]), label=x[1]) - if isinstance(x, tuple) - else GuiSliderMark(value=x, label=None) - for x in marks - ) - if marks is not None - else None, + props=_messages.GuiMultiSliderProps( + order=order, + label=label, + hint=hint, + min=min, + min_range=min_range, + max=max, + step=step, + visible=visible, + disabled=disabled, + fixed_endpoints=fixed_endpoints, + precision=_compute_precision_digits(step), + marks=tuple( + GuiSliderMark(value=float(x[0]), label=x[1]) + if isinstance(x, tuple) + else GuiSliderMark(value=x, label=None) + for x in marks + ) + if marks is not None + else None, + ), ), is_button=False, ) @@ -1400,15 +1460,17 @@ def add_rgb( order = _apply_default_order(order) return self._create_gui_input( value, - message=_messages.GuiAddRgbMessage( - order=order, + message=_messages.GuiRgbMessage( + value=value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - value=value, - disabled=disabled, - visible=visible, + props=_messages.GuiRgbProps( + order=order, + label=label, + hint=hint, + disabled=disabled, + visible=visible, + ), ), ) @@ -1439,33 +1501,39 @@ def add_rgba( order = _apply_default_order(order) return self._create_gui_input( value, - message=_messages.GuiAddRgbaMessage( - order=order, + message=_messages.GuiRgbaMessage( + value=value, id=id, - label=label, container_id=self._get_container_id(), - hint=hint, - value=value, - disabled=disabled, - visible=visible, + props=_messages.GuiRgbaProps( + order=order, + label=label, + hint=hint, + disabled=disabled, + visible=visible, + ), ), ) + class GuiMessage(Protocol[GuiInputPropsType]): + id: str + props: GuiInputPropsType + def _create_gui_input( self, value: T, - message: _messages._GuiAddInputBase, + message: GuiMessage, is_button: bool = False, ) -> GuiInputHandle[T]: """Private helper for adding a simple GUI element.""" # Send add GUI input message. + assert isinstance(message, _messages.Message) self._websock_interface.queue_message(message) # Construct handle. handle_state = _GuiHandleState( - label=message.label, - message_type=type(message), + props=message.props, gui_api=self, value=value, update_timestamp=time.time(), @@ -1473,11 +1541,7 @@ def _create_gui_input( update_cb=[], is_button=is_button, sync_cb=None, - disabled=message.disabled, - visible=message.visible, id=message.id, - order=message.order, - hint=message.hint, ) # For broadcasted GUI handles, we should synchronize all clients. diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 68a6880e4..6037f2b4b 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -6,16 +6,38 @@ import time import uuid import warnings +from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + Iterable, + TypeVar, + cast, + get_type_hints, +) import imageio.v3 as iio import numpy as onp from typing_extensions import Protocol +from . import _messages from ._icons import svg_from_icon from ._icons_enum import IconName -from ._messages import GuiCloseModalMessage, GuiRemoveMessage, GuiUpdateMessage, Message +from ._messages import ( + GuiBaseProps, + GuiCloseModalMessage, + GuiDropdownProps, + GuiFolderProps, + GuiMarkdownProps, + GuiPlotlyProps, + GuiProgressBarProps, + GuiRemoveMessage, + GuiUpdateMessage, +) from ._scene_api import _encode_image_binary from .infra import ClientId @@ -45,13 +67,17 @@ class SupportsRemoveProtocol(Protocol): def remove(self) -> None: ... +class GuiPropsProtocol(Protocol): + order: float + + @dataclasses.dataclass class _GuiHandleState(Generic[T]): """Internal API for GUI elements.""" - label: str gui_api: GuiApi value: T + props: GuiPropsProtocol update_timestamp: float parent_container_id: str @@ -66,43 +92,73 @@ class _GuiHandleState(Generic[T]): sync_cb: Callable[[ClientId, dict[str, Any]], None] | None """Callback for synchronizing inputs across clients.""" - disabled: bool - visible: bool - - order: float id: str - hint: str | None - message_type: type[Message] +class _OverridableGuiPropApi: + """Mixin that allows reading/assigning properties defined in each scene node message.""" -@dataclasses.dataclass -class _GuiInputHandle(Generic[T]): + def __setattr__(self, name: str, value: Any) -> None: + if name == "_impl": + return object.__setattr__(self, name, value) + + handle = cast(_GuiInputHandle, self) + # Get the value of the T TypeVar. + if name in self._prop_hints: + setattr(handle._impl.props, name, value) + handle._impl.gui_api._websock_interface.queue_message( + _messages.GuiUpdateMessage(handle._impl.id, {name: value}) + ) + else: + return object.__setattr__(self, name, value) + + def __getattr__(self, name: str) -> Any: + if name in self._prop_hints: + return getattr(self._impl.props, name) + else: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + @cached_property + def _prop_hints(self) -> Dict[str, Any]: + return get_type_hints(type(self._impl.props)) + + +class _GuiHandle( + Generic[T], + _OverridableGuiPropApi if not TYPE_CHECKING else object, +): # Let's shove private implementation details in here... - _impl: _GuiHandleState[T] - - # Should we use @property for get_value / set_value, set_hidden, etc? - # - # Benefits: - # @property is syntactically very nice. - # `gui.value = ...` is really tempting! - # Feels a bit more magical. - # - # Downsides: - # Consistency: not everything that can be written can be read, and not everything - # that can be read can be written. `get_`/`set_` makes this really clear. - # Clarity: some things that we read (like client mappings) are copied before - # they're returned. An attribute access obfuscates the overhead here. - # Flexibility: getter/setter types should match. https://github.com/python/mypy/issues/3004 - # Feels a bit more magical. - # - # Is this worth the tradeoff? + def __init__(self, _impl: _GuiHandleState[T]) -> None: + self._impl = _impl + parent = self._impl.gui_api._container_handle_from_id[ + self._impl.parent_container_id + ] + parent._children[self._impl.id] = self - @property - def order(self) -> float: - """Read-only order value, which dictates the position of the GUI element.""" - return self._impl.order + if isinstance(self, _GuiInputHandle): + self._impl.gui_api._gui_input_handle_from_id[self._impl.id] = self + + def remove(self) -> None: + """Permanently remove this GUI element from the visualizer.""" + self._impl.gui_api._websock_interface.queue_message( + GuiRemoveMessage(self._impl.id) + ) + parent = self._impl.gui_api._container_handle_from_id[ + self._impl.parent_container_id + ] + parent._children.pop(self._impl.id) + if isinstance(self, _GuiInputHandle): + self._impl.gui_api._gui_input_handle_from_id.pop(self._impl.id) + + +class _GuiInputHandle( + _GuiHandle[T], + Generic[T], + GuiBaseProps, +): @property def value(self) -> T: """Value of the GUI input. Synchronized automatically when assigned.""" @@ -144,56 +200,6 @@ def update_timestamp(self) -> float: """Read-only timestamp when this input was last updated.""" return self._impl.update_timestamp - @property - def disabled(self) -> bool: - """Allow/disallow user interaction with the input. Synchronized automatically - when assigned.""" - return self._impl.disabled - - @disabled.setter - def disabled(self, disabled: bool) -> None: - if disabled == self.disabled: - return - - self._impl.gui_api._websock_interface.queue_message( - GuiUpdateMessage(self._impl.id, {"disabled": disabled}) - ) - self._impl.disabled = disabled - - @property - def visible(self) -> bool: - """Temporarily show or hide this GUI element from the visualizer. Synchronized - automatically when assigned.""" - return self._impl.visible - - @visible.setter - def visible(self, visible: bool) -> None: - if visible == self.visible: - return - - self._impl.gui_api._websock_interface.queue_message( - GuiUpdateMessage(self._impl.id, {"visible": visible}) - ) - self._impl.visible = visible - - def __post_init__(self) -> None: - """We need to register ourself after construction for callbacks to work.""" - gui_api = self._impl.gui_api - - # TODO: the current way we track GUI handles and children is very manual + - # error-prone. We should revist this design. - gui_api._gui_input_handle_from_id[self._impl.id] = self - parent = gui_api._container_handle_from_id[self._impl.parent_container_id] - parent._children[self._impl.id] = self - - def remove(self) -> None: - """Permanently remove this GUI element from the visualizer.""" - gui_api = self._impl.gui_api - gui_api._websock_interface.queue_message(GuiRemoveMessage(self._impl.id)) - gui_api._gui_input_handle_from_id.pop(self._impl.id) - parent = gui_api._container_handle_from_id[self._impl.parent_container_id] - parent._children.pop(self._impl.id) - StringType = TypeVar("StringType", bound=str) @@ -202,7 +208,6 @@ def remove(self) -> None: # # We inherit from _GuiInputHandle to special-case buttons because the usage semantics # are slightly different: we have `on_click()` instead of `on_update()`. -@dataclasses.dataclass class GuiInputHandle(_GuiInputHandle[T], Generic[T]): """A handle is created for each GUI element that is added in `viser`. Handles can be used to read and write state. @@ -234,7 +239,6 @@ class GuiEvent(Generic[TGuiHandle]): """GUI element that was affected.""" -@dataclasses.dataclass class GuiButtonHandle(_GuiInputHandle[bool]): """Handle for a button input in our visualizer. @@ -258,7 +262,6 @@ class UploadedFile: """Contents of the file.""" -@dataclasses.dataclass class GuiUploadButtonHandle(_GuiInputHandle[UploadedFile]): """Handle for an upload file button in our visualizer. @@ -273,7 +276,6 @@ def on_upload( return func -@dataclasses.dataclass class GuiButtonGroupHandle(_GuiInputHandle[StringType], Generic[StringType]): """Handle for a button group input in our visualizer. @@ -292,19 +294,18 @@ def disabled(self) -> bool: return False @disabled.setter - def disabled(self, disabled: bool) -> None: + def disabled(self, disabled: bool) -> None: # type: ignore """Button groups cannot be disabled.""" assert not disabled, "Button groups cannot be disabled." -@dataclasses.dataclass -class GuiDropdownHandle(GuiInputHandle[StringType], Generic[StringType]): +class GuiDropdownHandle( + GuiInputHandle[StringType], Generic[StringType], GuiDropdownProps +): """Handle for a dropdown-style GUI input in our visualizer. Lets us get values, set values, and detect updates.""" - _impl_options: tuple[StringType, ...] - @property def options(self) -> tuple[StringType, ...]: """Options for our dropdown. Synchronized automatically when assigned. @@ -314,26 +315,29 @@ def options(self) -> tuple[StringType, ...]: inferred where possible when handles are instantiated; for the most flexibility, we can declare handles as `GuiDropdownHandle[str]`. """ - return self._impl_options + assert isinstance(self._impl.props, GuiDropdownProps) + return self._impl.props.options # type: ignore @options.setter - def options(self, options: Iterable[StringType]) -> None: - self._impl_options = tuple(options) + def options(self, options: Iterable[StringType]) -> None: # type: ignore + assert isinstance(self._impl.props, GuiDropdownProps) + options = tuple(options) + self._impl.props.options = options - need_to_overwrite_value = self.value not in self._impl_options + need_to_overwrite_value = self.value not in options if need_to_overwrite_value: self._impl.gui_api._websock_interface.queue_message( GuiUpdateMessage( self._impl.id, - {"options": self._impl_options, "value": self._impl_options[0]}, + {"options": options, "value": options}, ) ) - self._impl.value = self._impl_options[0] + self._impl.value = options[0] else: self._impl.gui_api._websock_interface.queue_message( GuiUpdateMessage( self._impl.id, - {"options": self._impl_options}, + {"options": options}, ) ) @@ -397,49 +401,42 @@ def _sync_with_client(self) -> None: ) -@dataclasses.dataclass -class GuiFolderHandle: +class GuiFolderHandle(_GuiHandle, GuiFolderProps): """Use as a context to place GUI elements into a folder.""" - _gui_api: GuiApi - _id: str # Used as container ID for children. - _order: float - _parent_container_id: str # Container ID of parent. - _container_id_restore: str | None = None - _children: dict[str, SupportsRemoveProtocol] = dataclasses.field( - default_factory=dict - ) - - @property - def order(self) -> float: - """Read-only order value, which dictates the position of the GUI element.""" - return self._order + def __init__(self, _impl: _GuiHandleState[None]) -> None: + super().__init__(_impl=_impl) + self._impl.gui_api._container_handle_from_id[self._impl.id] = self + self._children = {} + parent = self._impl.gui_api._container_handle_from_id[ + self._impl.parent_container_id + ] + parent._children[self._impl.id] = self def __enter__(self) -> GuiFolderHandle: - self._container_id_restore = self._gui_api._get_container_id() - self._gui_api._set_container_id(self._id) + self._container_id_restore = self._impl.gui_api._get_container_id() + self._impl.gui_api._set_container_id(self._impl.id) return self def __exit__(self, *args) -> None: del args assert self._container_id_restore is not None - self._gui_api._set_container_id(self._container_id_restore) + self._impl.gui_api._set_container_id(self._container_id_restore) self._container_id_restore = None - def __post_init__(self) -> None: - self._gui_api._container_handle_from_id[self._id] = self - parent = self._gui_api._container_handle_from_id[self._parent_container_id] - parent._children[self._id] = self - def remove(self) -> None: """Permanently remove this folder and all contained GUI elements from the visualizer.""" - self._gui_api._websock_interface.queue_message(GuiRemoveMessage(self._id)) + self._impl.gui_api._websock_interface.queue_message( + GuiRemoveMessage(self._impl.id) + ) for child in tuple(self._children.values()): child.remove() - parent = self._gui_api._container_handle_from_id[self._parent_container_id] - parent._children.pop(self._id) - self._gui_api._container_handle_from_id.pop(self._id) + parent = self._impl.gui_api._container_handle_from_id[ + self._impl.parent_container_id + ] + parent._children.pop(self._impl.id) + self._impl.gui_api._container_handle_from_id.pop(self._impl.id) @dataclasses.dataclass @@ -559,95 +556,17 @@ def _parse_markdown(markdown: str, image_root: Path | None) -> str: return markdown -@dataclasses.dataclass -class GuiProgressBarHandle: +class GuiProgressBarHandle(_GuiInputHandle[float], GuiProgressBarProps): """Use to remove markdown.""" - _gui_api: GuiApi - _id: str - _visible: bool - _animated: bool - _parent_container_id: str - _order: float - _value: float - - @property - def value(self) -> float: - """Current content of this progress bar element, 0 - 100. Synchronized - automatically when assigned.""" - return self._value - - @value.setter - def value(self, value: float) -> None: - assert value >= 0 and value <= 100 - self._value = value - self._gui_api._websock_interface.queue_message( - GuiUpdateMessage( - self._id, - {"value": value}, - ) - ) - - @property - def animated(self) -> bool: - """Show this progress bar as loading (animated, striped).""" - return self._animated - - @animated.setter - def animated(self, animated: bool) -> None: - self._animated = animated - self._gui_api._websock_interface.queue_message( - GuiUpdateMessage( - self._id, - {"animated": animated}, - ) - ) - - @property - def order(self) -> float: - """Read-only order value, which dictates the position of the GUI element.""" - return self._order - - @property - def visible(self) -> bool: - """Temporarily show or hide this GUI element from the visualizer. Synchronized - automatically when assigned.""" - return self._visible - - @visible.setter - def visible(self, visible: bool) -> None: - if visible == self.visible: - return - - self._gui_api._websock_interface.queue_message( - GuiUpdateMessage(self._id, {"visible": visible}) - ) - self._visible = visible - - def __post_init__(self) -> None: - """We need to register ourself after construction for callbacks to work.""" - parent = self._gui_api._container_handle_from_id[self._parent_container_id] - parent._children[self._id] = self - - def remove(self) -> None: - """Permanently remove this progress bar from the visualizer.""" - self._gui_api._websock_interface.queue_message(GuiRemoveMessage(self._id)) - parent = self._gui_api._container_handle_from_id[self._parent_container_id] - parent._children.pop(self._id) - - -@dataclasses.dataclass -class GuiMarkdownHandle: +class GuiMarkdownHandle(_GuiHandle[None], GuiMarkdownProps): """Use to remove markdown.""" - _gui_api: GuiApi - _id: str - _visible: bool - _parent_container_id: str - _order: float - _image_root: Path | None - _content: str | None + def __init__(self, _impl: _GuiHandleState, _content: str, _image_root: Path | None): + super().__init__(_impl=_impl) + self._content = _content + self._image_root = _image_root @property def content(self) -> str: @@ -658,58 +577,15 @@ def content(self) -> str: @content.setter def content(self, content: str) -> None: self._content = content - self._gui_api._websock_interface.queue_message( - GuiUpdateMessage( - self._id, - {"markdown": _parse_markdown(content, self._image_root)}, - ) - ) - - @property - def order(self) -> float: - """Read-only order value, which dictates the position of the GUI element.""" - return self._order - - @property - def visible(self) -> bool: - """Temporarily show or hide this GUI element from the visualizer. Synchronized - automatically when assigned.""" - return self._visible - - @visible.setter - def visible(self, visible: bool) -> None: - if visible == self.visible: - return - - self._gui_api._websock_interface.queue_message( - GuiUpdateMessage(self._id, {"visible": visible}) - ) - self._visible = visible - - def __post_init__(self) -> None: - """We need to register ourself after construction for callbacks to work.""" - parent = self._gui_api._container_handle_from_id[self._parent_container_id] - parent._children[self._id] = self - - def remove(self) -> None: - """Permanently remove this markdown from the visualizer.""" - self._gui_api._websock_interface.queue_message(GuiRemoveMessage(self._id)) - - parent = self._gui_api._container_handle_from_id[self._parent_container_id] - parent._children.pop(self._id) + self.markdown = _parse_markdown(content, self._image_root) -@dataclasses.dataclass -class GuiPlotlyHandle: +class GuiPlotlyHandle(_GuiHandle[None], GuiPlotlyProps): """Use to update or remove markdown elements.""" - _gui_api: GuiApi - _id: str - _visible: bool - _parent_container_id: str - _order: float - _figure: go.Figure | None - _aspect: float | None + def __init__(self, _impl: _GuiHandleState, _figure: go.Figure): + super().__init__(_impl=_impl) + self._figure = _figure @property def figure(self) -> go.Figure: @@ -723,58 +599,4 @@ def figure(self, figure: go.Figure) -> None: json_str = figure.to_json() assert isinstance(json_str, str) - - self._gui_api._websock_interface.queue_message( - GuiUpdateMessage( - self._id, - {"plotly_json_str": json_str}, - ) - ) - - @property - def aspect(self) -> float: - """Aspect ratio of the plotly figure, in the control panel.""" - assert self._aspect is not None - return self._aspect - - @aspect.setter - def aspect(self, aspect: float) -> None: - self._aspect = aspect - self._gui_api._websock_interface.queue_message( - GuiUpdateMessage( - self._id, - {"aspect": aspect}, - ) - ) - - @property - def order(self) -> float: - """Read-only order value, which dictates the position of the GUI element.""" - return self._order - - @property - def visible(self) -> bool: - """Temporarily show or hide this GUI element from the visualizer. Synchronized - automatically when assigned.""" - return self._visible - - @visible.setter - def visible(self, visible: bool) -> None: - if visible == self.visible: - return - - self._gui_api._websock_interface.queue_message( - GuiUpdateMessage(self._id, {"visible": visible}) - ) - self._visible = visible - - def __post_init__(self) -> None: - """We need to register ourself after construction for callbacks to work.""" - parent = self._gui_api._container_handle_from_id[self._parent_container_id] - parent._children[self._id] = self - - def remove(self) -> None: - """Permanently remove this figure from the visualizer.""" - self._gui_api._websock_interface.queue_message(GuiRemoveMessage(self._id)) - parent = self._gui_api._container_handle_from_id[self._parent_container_id] - parent._children.pop(self._id) + self.plotly_json_str = json_str diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 6956c63e7..9cac492db 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -65,7 +65,7 @@ def redundancy_key(self) -> str: @classmethod def __init_subclass__( - cls, tag: Literal[None, "GuiAddComponentMessage", "SceneNodeMessage"] = None + cls, tag: Literal[None, "GuiComponentMessage", "SceneNodeMessage"] = None ): """Tag will be used to create a union type in TypeScript.""" super().__init_subclass__() @@ -778,68 +778,90 @@ class ResetGuiMessage(Message): @dataclasses.dataclass -class GuiAddFolderMessage(Message, tag="GuiAddComponentMessage"): +class GuiBaseProps(Message): + """Base message type containing fields commonly used by GUI inputs.""" + order: float - id: str label: str - container_id: str - expand_by_default: bool + hint: Optional[str] visible: bool + disabled: bool @dataclasses.dataclass -class GuiAddMarkdownMessage(Message, tag="GuiAddComponentMessage"): +class GuiFolderProps: order: float + label: str + visible: bool + expand_by_default: bool + + +@dataclasses.dataclass +class GuiFolderMessage(Message, tag="GuiComponentMessage"): id: str - markdown: str container_id: str - visible: bool + props: GuiFolderProps @dataclasses.dataclass -class GuiAddProgressBarMessage(Message, tag="GuiAddComponentMessage"): +class GuiMarkdownProps: order: float + markdown: str + visible: bool + + +@dataclasses.dataclass +class GuiMarkdownMessage(Message, tag="GuiComponentMessage"): id: str - value: float + container_id: str + props: GuiMarkdownProps + + +@dataclasses.dataclass +class GuiProgressBarProps: + order: float animated: bool color: Optional[Color] - container_id: str visible: bool @dataclasses.dataclass -class GuiAddPlotlyMessage(Message, tag="GuiAddComponentMessage"): - order: float +class GuiProgressBarMessage(Message, tag="GuiComponentMessage"): + value: float id: str + container_id: str + props: GuiProgressBarProps + + +@dataclasses.dataclass +class GuiPlotlyProps: + order: float plotly_json_str: str aspect: float - container_id: str visible: bool @dataclasses.dataclass -class GuiAddTabGroupMessage(Message, tag="GuiAddComponentMessage"): - order: float +class GuiPlotlyMessage(Message, tag="GuiComponentMessage"): id: str container_id: str + props: GuiPlotlyProps + + +@dataclasses.dataclass +class GuiTabGroupProps: tab_labels: Tuple[str, ...] tab_icons_html: Tuple[Union[str, None], ...] tab_container_ids: Tuple[str, ...] + order: float visible: bool @dataclasses.dataclass -class _GuiAddInputBase(Message): - """Base message type containing fields commonly used by GUI inputs.""" - - order: float +class GuiTabGroupMessage(Message, tag="GuiComponentMessage"): id: str - label: str container_id: str - hint: Optional[str] - value: Any - visible: bool - disabled: bool + props: GuiTabGroupProps @dataclasses.dataclass @@ -855,33 +877,52 @@ class GuiCloseModalMessage(Message): @dataclasses.dataclass -class GuiAddButtonMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): - # All GUI elements currently need an `value` field. - # This makes our job on the frontend easier. - value: bool +class GuiButtonProps(GuiBaseProps): color: Optional[Color] icon_html: Optional[str] @dataclasses.dataclass -class GuiAddUploadButtonMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): +class GuiButtonMessage(Message, tag="GuiComponentMessage"): + value: bool + id: str + container_id: str + props: GuiButtonProps + + +@dataclasses.dataclass +class GuiUploadButtonProps(GuiBaseProps): color: Optional[Color] icon_html: Optional[str] mime_type: str @dataclasses.dataclass -class GuiAddSliderMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): +class GuiUploadButtonMessage(Message, tag="GuiComponentMessage"): + id: str + container_id: str + props: GuiUploadButtonProps + + +@dataclasses.dataclass +class GuiSliderProps(GuiBaseProps): min: float max: float step: Optional[float] - value: float precision: int marks: Optional[Tuple[GuiSliderMark, ...]] = None @dataclasses.dataclass -class GuiAddMultiSliderMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): +class GuiSliderMessage(Message, tag="GuiComponentMessage"): + value: float + id: str + container_id: str + props: GuiSliderProps + + +@dataclasses.dataclass +class GuiMultiSliderProps(GuiBaseProps): min: float max: float step: Optional[float] @@ -892,8 +933,15 @@ class GuiAddMultiSliderMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): @dataclasses.dataclass -class GuiAddNumberMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): - value: float +class GuiMultiSliderMessage(Message, tag="GuiComponentMessage"): + value: tuple[float, ...] + id: str + container_id: str + props: GuiMultiSliderProps + + +@dataclasses.dataclass +class GuiNumberProps(GuiBaseProps): precision: int step: float min: Optional[float] @@ -901,23 +949,54 @@ class GuiAddNumberMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): @dataclasses.dataclass -class GuiAddRgbMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): +class GuiNumberMessage(Message, tag="GuiComponentMessage"): + value: float + id: str + container_id: str + props: GuiNumberProps + + +@dataclasses.dataclass +class GuiRgbProps(GuiBaseProps): + pass + + +@dataclasses.dataclass +class GuiRgbMessage(Message, tag="GuiComponentMessage"): value: Tuple[int, int, int] + id: str + container_id: str + props: GuiRgbProps + + +@dataclasses.dataclass +class GuiRgbaProps(GuiBaseProps): + pass @dataclasses.dataclass -class GuiAddRgbaMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): +class GuiRgbaMessage(Message, tag="GuiComponentMessage"): value: Tuple[int, int, int, int] + id: str + container_id: str + props: GuiRgbaProps @dataclasses.dataclass -class GuiAddCheckboxMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): +class GuiCheckboxProps(GuiBaseProps): + pass + + +@dataclasses.dataclass +class GuiCheckboxMessage(Message, tag="GuiComponentMessage"): value: bool + id: str + container_id: str + props: GuiCheckboxProps @dataclasses.dataclass -class GuiAddVector2Message(_GuiAddInputBase, tag="GuiAddComponentMessage"): - value: Tuple[float, float] +class GuiVector2Props(GuiBaseProps): min: Optional[Tuple[float, float]] max: Optional[Tuple[float, float]] step: float @@ -925,8 +1004,15 @@ class GuiAddVector2Message(_GuiAddInputBase, tag="GuiAddComponentMessage"): @dataclasses.dataclass -class GuiAddVector3Message(_GuiAddInputBase, tag="GuiAddComponentMessage"): - value: Tuple[float, float, float] +class GuiVector2Message(Message, tag="GuiComponentMessage"): + value: Tuple[float, float] + id: str + container_id: str + props: GuiVector2Props + + +@dataclasses.dataclass +class GuiVector3Props(GuiBaseProps): min: Optional[Tuple[float, float, float]] max: Optional[Tuple[float, float, float]] step: float @@ -934,22 +1020,52 @@ class GuiAddVector3Message(_GuiAddInputBase, tag="GuiAddComponentMessage"): @dataclasses.dataclass -class GuiAddTextMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): - value: str +class GuiVector3Message(Message, tag="GuiComponentMessage"): + value: Tuple[float, float, float] + id: str + container_id: str + props: GuiVector3Props @dataclasses.dataclass -class GuiAddDropdownMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): +class GuiTextProps(GuiBaseProps): + pass + + +@dataclasses.dataclass +class GuiTextMessage(Message, tag="GuiComponentMessage"): value: str + id: str + container_id: str + props: GuiTextProps + + +@dataclasses.dataclass +class GuiDropdownProps(GuiBaseProps): options: Tuple[str, ...] @dataclasses.dataclass -class GuiAddButtonGroupMessage(_GuiAddInputBase, tag="GuiAddComponentMessage"): +class GuiDropdownMessage(Message, tag="GuiComponentMessage"): value: str + id: str + container_id: str + props: GuiDropdownProps + + +@dataclasses.dataclass +class GuiButtonGroupProps(GuiBaseProps): options: Tuple[str, ...] +@dataclasses.dataclass +class GuiButtonGroupMessage(Message, tag="GuiComponentMessage"): + value: str + id: str + container_id: str + props: GuiButtonGroupProps + + @dataclasses.dataclass class GuiRemoveMessage(Message): """Sent server->client to remove a GUI element.""" @@ -964,7 +1080,7 @@ class GuiUpdateMessage(Message): id: str updates: Annotated[ Dict[str, Any], - infra.TypeScriptAnnotationOverride("Partial"), + infra.TypeScriptAnnotationOverride("Partial"), ] """Mapping from property name to new value.""" diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index e9d93738f..f5ba05612 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -43,40 +43,38 @@ def colors_to_uint8(colors: onp.ndarray) -> onpt.NDArray[onp.uint8]: class _OverridableScenePropApi: """Mixin that allows reading/assigning properties defined in each scene node message.""" - if not TYPE_CHECKING: - - def __setattr__(self, name: str, value: Any) -> None: - if name == "_impl": - return object.__setattr__(self, name, value) - - handle = cast(SceneNodeHandle, self) - # Get the value of the T TypeVar. - if name in self._prop_hints: - # Help the user with some casting... - hint = self._prop_hints[name] - if hint == onpt.NDArray[onp.float32]: - value = value.astype(onp.float32) - elif hint == onpt.NDArray[onp.uint8] and "color" in name: - value = colors_to_uint8(value) - - setattr(handle._impl.props, name, value) - handle._impl.api._websock_interface.queue_message( - _messages.SceneNodeUpdateMessage(handle.name, {name: value}) - ) - else: - return object.__setattr__(self, name, value) - - def __getattr__(self, name: str) -> Any: - if name in self._prop_hints: - return getattr(self._impl.props, name) - else: - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) - - @cached_property - def _prop_hints(self) -> Dict[str, Any]: - return get_type_hints(type(self._impl.props)) + def __setattr__(self, name: str, value: Any) -> None: + if name == "_impl": + return object.__setattr__(self, name, value) + + handle = cast(SceneNodeHandle, self) + # Get the value of the T TypeVar. + if name in self._prop_hints: + # Help the user with some casting... + hint = self._prop_hints[name] + if hint == onpt.NDArray[onp.float32]: + value = value.astype(onp.float32) + elif hint == onpt.NDArray[onp.uint8] and "color" in name: + value = colors_to_uint8(value) + + setattr(handle._impl.props, name, value) + handle._impl.api._websock_interface.queue_message( + _messages.SceneNodeUpdateMessage(handle.name, {name: value}) + ) + else: + return object.__setattr__(self, name, value) + + def __getattr__(self, name: str) -> Any: + if name in self._prop_hints: + return getattr(self._impl.props, name) + else: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + @cached_property + def _prop_hints(self) -> Dict[str, Any]: + return get_type_hints(type(self._impl.props)) @dataclasses.dataclass(frozen=True) @@ -267,7 +265,7 @@ def on_click( class CameraFrustumHandle( _ClickableSceneNodeHandle, _messages.CameraFrustumProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for camera frustums.""" @@ -275,7 +273,7 @@ class CameraFrustumHandle( class DirectionalLightHandle( SceneNodeHandle, _messages.DirectionalLightProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for directional lights.""" @@ -283,7 +281,7 @@ class DirectionalLightHandle( class AmbientLightHandle( SceneNodeHandle, _messages.AmbientLightProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for ambient lights.""" @@ -291,7 +289,7 @@ class AmbientLightHandle( class HemisphereLightHandle( SceneNodeHandle, _messages.HemisphereLightProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for hemisphere lights.""" @@ -299,7 +297,7 @@ class HemisphereLightHandle( class PointLightHandle( SceneNodeHandle, _messages.PointLightProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for point lights.""" @@ -307,7 +305,7 @@ class PointLightHandle( class RectAreaLightHandle( SceneNodeHandle, _messages.RectAreaLightProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for rectangular area lights.""" @@ -315,7 +313,7 @@ class RectAreaLightHandle( class SpotLightHandle( SceneNodeHandle, _messages.SpotLightProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for spot lights.""" @@ -323,7 +321,7 @@ class SpotLightHandle( class PointCloudHandle( SceneNodeHandle, _messages.PointCloudProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for point clouds. Does not support click events.""" @@ -331,7 +329,7 @@ class PointCloudHandle( class BatchedAxesHandle( _ClickableSceneNodeHandle, _messages.BatchedAxesProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for batched coordinate frames.""" @@ -339,7 +337,7 @@ class BatchedAxesHandle( class FrameHandle( _ClickableSceneNodeHandle, _messages.FrameProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for coordinate frames.""" @@ -347,7 +345,7 @@ class FrameHandle( class MeshHandle( _ClickableSceneNodeHandle, _messages.MeshProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for mesh objects.""" @@ -355,7 +353,7 @@ class MeshHandle( class GaussianSplatHandle( _ClickableSceneNodeHandle, _messages.GaussianSplatsProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for Gaussian splatting objects. @@ -366,7 +364,7 @@ class GaussianSplatHandle( class MeshSkinnedHandle( _ClickableSceneNodeHandle, _messages.SkinnedMeshProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for skinned mesh objects.""" @@ -434,7 +432,7 @@ def position(self, position: tuple[float, float, float] | onp.ndarray) -> None: class GridHandle( SceneNodeHandle, _messages.GridProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for grid objects.""" @@ -442,7 +440,7 @@ class GridHandle( class SplineCatmullRomHandle( SceneNodeHandle, _messages.CatmullRomSplineProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for Catmull-Rom splines.""" @@ -450,7 +448,7 @@ class SplineCatmullRomHandle( class SplineCubicBezierHandle( SceneNodeHandle, _messages.CubicBezierSplineProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for cubic Bezier splines.""" @@ -458,7 +456,7 @@ class SplineCubicBezierHandle( class GlbHandle( _ClickableSceneNodeHandle, _messages.GlbProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for GLB objects.""" @@ -466,7 +464,7 @@ class GlbHandle( class ImageHandle( _ClickableSceneNodeHandle, _messages.ImageProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for 2D images, rendered in 3D.""" @@ -474,7 +472,7 @@ class ImageHandle( class LabelHandle( SceneNodeHandle, _messages.LabelProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for 2D label objects. Does not support click events.""" @@ -489,7 +487,7 @@ class _TransformControlsState: class TransformControlsHandle( _ClickableSceneNodeHandle, _messages.TransformControlsProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Handle for interacting with transform control gizmos.""" @@ -512,7 +510,7 @@ def on_update( class Gui3dContainerHandle( SceneNodeHandle, _messages.Gui3DProps, - _OverridableScenePropApi, + _OverridableScenePropApi if not TYPE_CHECKING else object, ): """Use as a context to place GUI elements into a 3D GUI container.""" diff --git a/src/viser/client/src/ControlPanel/Generated.tsx b/src/viser/client/src/ControlPanel/Generated.tsx index 4f30fb410..3fe72a8b8 100644 --- a/src/viser/client/src/ControlPanel/Generated.tsx +++ b/src/viser/client/src/ControlPanel/Generated.tsx @@ -86,41 +86,41 @@ function GeneratedInput(props: { guiId: string }) { const viewer = React.useContext(ViewerContext)!; const conf = viewer.useGui((state) => state.guiConfigFromId[props.guiId]); switch (conf.type) { - case "GuiAddFolderMessage": + case "GuiFolderMessage": return ; - case "GuiAddTabGroupMessage": + case "GuiTabGroupMessage": return ; - case "GuiAddMarkdownMessage": + case "GuiMarkdownMessage": return ; - case "GuiAddPlotlyMessage": + case "GuiPlotlyMessage": return ; - case "GuiAddButtonMessage": + case "GuiButtonMessage": return ; - case "GuiAddUploadButtonMessage": + case "GuiUploadButtonMessage": return ; - case "GuiAddSliderMessage": + case "GuiSliderMessage": return ; - case "GuiAddMultiSliderMessage": + case "GuiMultiSliderMessage": return ; - case "GuiAddNumberMessage": + case "GuiNumberMessage": return ; - case "GuiAddTextMessage": + case "GuiTextMessage": return ; - case "GuiAddCheckboxMessage": + case "GuiCheckboxMessage": return ; - case "GuiAddVector2Message": + case "GuiVector2Message": return ; - case "GuiAddVector3Message": + case "GuiVector3Message": return ; - case "GuiAddDropdownMessage": + case "GuiDropdownMessage": return ; - case "GuiAddRgbMessage": + case "GuiRgbMessage": return ; - case "GuiAddRgbaMessage": + case "GuiRgbaMessage": return ; - case "GuiAddButtonGroupMessage": + case "GuiButtonGroupMessage": return ; - case "GuiAddProgressBarMessage": + case "GuiProgressBarMessage": return ; default: assertNeverType(conf); diff --git a/src/viser/client/src/ControlPanel/GuiState.tsx b/src/viser/client/src/ControlPanel/GuiState.tsx index d233145c2..e8a61a4ad 100644 --- a/src/viser/client/src/ControlPanel/GuiState.tsx +++ b/src/viser/client/src/ControlPanel/GuiState.tsx @@ -1,14 +1,16 @@ -import * as Messages from "../WebsocketMessages"; import React from "react"; import { create } from "zustand"; import { ColorTranslator } from "colortranslator"; import { immer } from "zustand/middleware/immer"; - -export type GuiConfig = Messages.GuiAddComponentMessage; +import { + GuiComponentMessage, + GuiModalMessage, + ThemeConfigurationMessage, +} from "../WebsocketMessages"; interface GuiState { - theme: Messages.ThemeConfigurationMessage; + theme: ThemeConfigurationMessage; label: string; server: string; shareUrl: string | null; @@ -17,9 +19,9 @@ interface GuiState { guiIdSetFromContainerId: { [containerId: string]: { [id: string]: true } | undefined; }; - modals: Messages.GuiModalMessage[]; + modals: GuiModalMessage[]; guiOrderFromId: { [id: string]: number }; - guiConfigFromId: { [id: string]: GuiConfig }; + guiConfigFromId: { [id: string]: GuiComponentMessage }; uploadsInProgress: { [id: string]: { notificationId: string; @@ -31,10 +33,10 @@ interface GuiState { } interface GuiActions { - setTheme: (theme: Messages.ThemeConfigurationMessage) => void; + setTheme: (theme: ThemeConfigurationMessage) => void; setShareUrl: (share_url: string | null) => void; - addGui: (config: GuiConfig) => void; - addModal: (config: Messages.GuiModalMessage) => void; + addGui: (config: GuiComponentMessage) => void; + addModal: (config: GuiModalMessage) => void; removeModal: (id: string) => void; updateGuiProps: (id: string, updates: { [key: string]: any }) => void; removeGui: (id: string) => void; @@ -98,7 +100,7 @@ export function useGuiState(initialServer: string) { }), addGui: (guiConfig) => set((state) => { - state.guiOrderFromId[guiConfig.id] = guiConfig.order; + state.guiOrderFromId[guiConfig.id] = guiConfig.props.order; state.guiConfigFromId[guiConfig.id] = guiConfig; if (!(guiConfig.container_id in state.guiIdSetFromContainerId)) { state.guiIdSetFromContainerId[guiConfig.container_id] = {}; @@ -160,18 +162,18 @@ export function useGuiState(initialServer: string) { return; } - // Double-check that key exists. - Object.keys(updates).forEach((key) => { - if (!(key in config)) + // Iterate over key/value pairs. + for (const [key, value] of Object.entries(updates)) { + if (key === "value") { + state.guiConfigFromId[id].value = value; + } else if (!(key in config.props)) { console.error( `Tried to update nonexistent property '${key}' of GUI element ${id}!`, ); - }); - - state.guiConfigFromId[id] = { - ...config, - ...updates, - } as GuiConfig; + } else { + state.guiConfigFromId[id].props[key] = value; + } + } }); }, })), diff --git a/src/viser/client/src/MessageHandler.tsx b/src/viser/client/src/MessageHandler.tsx index 6196f053b..e6cd4a89d 100644 --- a/src/viser/client/src/MessageHandler.tsx +++ b/src/viser/client/src/MessageHandler.tsx @@ -10,7 +10,7 @@ import { FileTransferStart, Message, SceneNodeMessage, - isGuiAddComponentMessage, + isGuiComponentMessage, isSceneNodeMessage, } from "./WebsocketMessages"; import { isTexture } from "./WebsocketFunctions"; @@ -69,7 +69,7 @@ function useMessageHandler() { // Return message handler. return (message: Message) => { - if (isGuiAddComponentMessage(message)) { + if (isGuiComponentMessage(message)) { addGui(message); return; } diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index a78ba2053..498b70a5a 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -543,114 +543,27 @@ export interface ResetSceneMessage { export interface ResetGuiMessage { type: "ResetGuiMessage"; } -/** GuiAddFolderMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', expand_by_default: 'bool', visible: 'bool') - * - * (automatically generated) - */ -export interface GuiAddFolderMessage { - type: "GuiAddFolderMessage"; - order: number; - id: string; - label: string; - container_id: string; - expand_by_default: boolean; - visible: boolean; -} -/** GuiAddMarkdownMessage(order: 'float', id: 'str', markdown: 'str', container_id: 'str', visible: 'bool') - * - * (automatically generated) - */ -export interface GuiAddMarkdownMessage { - type: "GuiAddMarkdownMessage"; - order: number; - id: string; - markdown: string; - container_id: string; - visible: boolean; -} -/** GuiAddProgressBarMessage(order: 'float', id: 'str', value: 'float', animated: 'bool', color: 'Optional[Color]', container_id: 'str', visible: 'bool') - * - * (automatically generated) - */ -export interface GuiAddProgressBarMessage { - type: "GuiAddProgressBarMessage"; - order: number; - id: string; - value: number; - animated: boolean; - color: - | "dark" - | "gray" - | "red" - | "pink" - | "grape" - | "violet" - | "indigo" - | "blue" - | "cyan" - | "green" - | "lime" - | "yellow" - | "orange" - | "teal" - | null; - container_id: string; - visible: boolean; -} -/** GuiAddPlotlyMessage(order: 'float', id: 'str', plotly_json_str: 'str', aspect: 'float', container_id: 'str', visible: 'bool') - * - * (automatically generated) - */ -export interface GuiAddPlotlyMessage { - type: "GuiAddPlotlyMessage"; - order: number; - id: string; - plotly_json_str: string; - aspect: number; - container_id: string; - visible: boolean; -} -/** GuiAddTabGroupMessage(order: 'float', id: 'str', container_id: 'str', tab_labels: 'Tuple[str, ...]', tab_icons_html: 'Tuple[Union[str, None], ...]', tab_container_ids: 'Tuple[str, ...]', visible: 'bool') - * - * (automatically generated) - */ -export interface GuiAddTabGroupMessage { - type: "GuiAddTabGroupMessage"; - order: number; - id: string; - container_id: string; - tab_labels: string[]; - tab_icons_html: (string | null)[]; - tab_container_ids: string[]; - visible: boolean; -} /** Base message type containing fields commonly used by GUI inputs. * * (automatically generated) */ -export interface _GuiAddInputBase { - type: "_GuiAddInputBase"; +export interface GuiBaseProps { + type: "GuiBaseProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: any; visible: boolean; disabled: boolean; } -/** GuiAddButtonMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'bool', visible: 'bool', disabled: 'bool', color: 'Optional[Color]', icon_html: 'Optional[str]') +/** GuiButtonProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', color: 'Optional[Color]', icon_html: 'Optional[str]') * * (automatically generated) */ -export interface GuiAddButtonMessage { - type: "GuiAddButtonMessage"; +export interface GuiButtonProps { + type: "GuiButtonProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: boolean; visible: boolean; disabled: boolean; color: @@ -671,18 +584,15 @@ export interface GuiAddButtonMessage { | null; icon_html: string | null; } -/** GuiAddUploadButtonMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Any', visible: 'bool', disabled: 'bool', color: 'Optional[Color]', icon_html: 'Optional[str]', mime_type: 'str') +/** GuiUploadButtonProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', color: 'Optional[Color]', icon_html: 'Optional[str]', mime_type: 'str') * * (automatically generated) */ -export interface GuiAddUploadButtonMessage { - type: "GuiAddUploadButtonMessage"; +export interface GuiUploadButtonProps { + type: "GuiUploadButtonProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: any; visible: boolean; disabled: boolean; color: @@ -704,18 +614,15 @@ export interface GuiAddUploadButtonMessage { icon_html: string | null; mime_type: string; } -/** GuiAddSliderMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'float', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', precision: 'int', marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) +/** GuiSliderProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', precision: 'int', marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) * * (automatically generated) */ -export interface GuiAddSliderMessage { - type: "GuiAddSliderMessage"; +export interface GuiSliderProps { + type: "GuiSliderProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: number; visible: boolean; disabled: boolean; min: number; @@ -724,18 +631,15 @@ export interface GuiAddSliderMessage { precision: number; marks: { value: number; label: string | null }[] | null; } -/** GuiAddMultiSliderMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Any', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', min_range: 'Optional[float]', precision: 'int', fixed_endpoints: 'bool' = False, marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) +/** GuiMultiSliderProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', min_range: 'Optional[float]', precision: 'int', fixed_endpoints: 'bool' = False, marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) * * (automatically generated) */ -export interface GuiAddMultiSliderMessage { - type: "GuiAddMultiSliderMessage"; +export interface GuiMultiSliderProps { + type: "GuiMultiSliderProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: any; visible: boolean; disabled: boolean; min: number; @@ -746,18 +650,15 @@ export interface GuiAddMultiSliderMessage { fixed_endpoints: boolean; marks: { value: number; label: string | null }[] | null; } -/** GuiAddNumberMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'float', visible: 'bool', disabled: 'bool', precision: 'int', step: 'float', min: 'Optional[float]', max: 'Optional[float]') +/** GuiNumberProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', precision: 'int', step: 'float', min: 'Optional[float]', max: 'Optional[float]') * * (automatically generated) */ -export interface GuiAddNumberMessage { - type: "GuiAddNumberMessage"; +export interface GuiNumberProps { + type: "GuiNumberProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: number; visible: boolean; disabled: boolean; precision: number; @@ -765,63 +666,51 @@ export interface GuiAddNumberMessage { min: number | null; max: number | null; } -/** GuiAddRgbMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[int, int, int]', visible: 'bool', disabled: 'bool') +/** GuiRgbProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool') * * (automatically generated) */ -export interface GuiAddRgbMessage { - type: "GuiAddRgbMessage"; +export interface GuiRgbProps { + type: "GuiRgbProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: [number, number, number]; visible: boolean; disabled: boolean; } -/** GuiAddRgbaMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[int, int, int, int]', visible: 'bool', disabled: 'bool') +/** GuiRgbaProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool') * * (automatically generated) */ -export interface GuiAddRgbaMessage { - type: "GuiAddRgbaMessage"; +export interface GuiRgbaProps { + type: "GuiRgbaProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: [number, number, number, number]; visible: boolean; disabled: boolean; } -/** GuiAddCheckboxMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'bool', visible: 'bool', disabled: 'bool') +/** GuiCheckboxProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool') * * (automatically generated) */ -export interface GuiAddCheckboxMessage { - type: "GuiAddCheckboxMessage"; +export interface GuiCheckboxProps { + type: "GuiCheckboxProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: boolean; visible: boolean; disabled: boolean; } -/** GuiAddVector2Message(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[float, float]', visible: 'bool', disabled: 'bool', min: 'Optional[Tuple[float, float]]', max: 'Optional[Tuple[float, float]]', step: 'float', precision: 'int') +/** GuiVector2Props(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'Optional[Tuple[float, float]]', max: 'Optional[Tuple[float, float]]', step: 'float', precision: 'int') * * (automatically generated) */ -export interface GuiAddVector2Message { - type: "GuiAddVector2Message"; +export interface GuiVector2Props { + type: "GuiVector2Props"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: [number, number]; visible: boolean; disabled: boolean; min: [number, number] | null; @@ -829,18 +718,15 @@ export interface GuiAddVector2Message { step: number; precision: number; } -/** GuiAddVector3Message(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[float, float, float]', visible: 'bool', disabled: 'bool', min: 'Optional[Tuple[float, float, float]]', max: 'Optional[Tuple[float, float, float]]', step: 'float', precision: 'int') +/** GuiVector3Props(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'Optional[Tuple[float, float, float]]', max: 'Optional[Tuple[float, float, float]]', step: 'float', precision: 'int') * * (automatically generated) */ -export interface GuiAddVector3Message { - type: "GuiAddVector3Message"; +export interface GuiVector3Props { + type: "GuiVector3Props"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: [number, number, number]; visible: boolean; disabled: boolean; min: [number, number, number] | null; @@ -848,53 +734,242 @@ export interface GuiAddVector3Message { step: number; precision: number; } -/** GuiAddTextMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'str', visible: 'bool', disabled: 'bool') +/** GuiTextProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool') * * (automatically generated) */ -export interface GuiAddTextMessage { - type: "GuiAddTextMessage"; +export interface GuiTextProps { + type: "GuiTextProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: string; visible: boolean; disabled: boolean; } -/** GuiAddDropdownMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'str', visible: 'bool', disabled: 'bool', options: 'Tuple[str, ...]') +/** GuiDropdownProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', options: 'Tuple[str, ...]') * * (automatically generated) */ -export interface GuiAddDropdownMessage { - type: "GuiAddDropdownMessage"; +export interface GuiDropdownProps { + type: "GuiDropdownProps"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: string; visible: boolean; disabled: boolean; options: string[]; } -/** GuiAddButtonGroupMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'str', visible: 'bool', disabled: 'bool', options: 'Tuple[str, ...]') +/** Handle for a dropdown-style GUI input in our visualizer. + * + * Lets us get values, set values, and detect updates. * * (automatically generated) */ -export interface GuiAddButtonGroupMessage { - type: "GuiAddButtonGroupMessage"; +export interface GuiDropdownHandle { + type: "GuiDropdownHandle"; + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + options: string[]; +} +/** GuiButtonGroupProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', options: 'Tuple[str, ...]') + * + * (automatically generated) + */ +export interface GuiButtonGroupProps { + type: "GuiButtonGroupProps"; + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + options: string[]; +} +export interface _GuiInputHandle { + type: "_GuiInputHandle"; + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; +} +/** A handle is created for each GUI element that is added in `viser`. + * Handles can be used to read and write state. + * + * When a GUI element is added via :attr:`ViserServer.gui`, state is + * synchronized between all connected clients. When a GUI element is added via + * :attr:`ClientHandle.gui`, state is local to a specific client. + * + * + * (automatically generated) + */ +export interface GuiInputHandle { + type: "GuiInputHandle"; + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; +} +/** Handle for a dropdown-style GUI input in our visualizer. + * + * Lets us get values, set values, and detect updates. + * + * (automatically generated) + */ +export interface GuiDropdownHandle { + type: "GuiDropdownHandle"; order: number; - id: string; label: string; - container_id: string; hint: string | null; - value: string; visible: boolean; disabled: boolean; options: string[]; } +/** Handle for a button input in our visualizer. + * + * Lets us detect clicks. + * + * (automatically generated) + */ +export interface GuiButtonHandle { + type: "GuiButtonHandle"; + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; +} +/** Handle for an upload file button in our visualizer. + * + * The `.value` attribute will be updated with the contents of uploaded files. + * + * + * (automatically generated) + */ +export interface GuiUploadButtonHandle { + type: "GuiUploadButtonHandle"; + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; +} +/** Handle for a button group input in our visualizer. + * + * Lets us detect clicks. + * + * (automatically generated) + */ +export interface GuiButtonGroupHandle { + type: "GuiButtonGroupHandle"; + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; +} +/** Use to remove markdown. + * + * (automatically generated) + */ +export interface GuiProgressBarHandle { + type: "GuiProgressBarHandle"; + order: number; + visible: boolean; + label: string; + hint: string | null; + disabled: boolean; +} +/** GuiFolderMessage(id: 'str', container_id: 'str', props: 'GuiFolderProps') + * + * (automatically generated) + */ +export interface GuiFolderMessage { + type: "GuiFolderMessage"; + id: string; + container_id: string; + props: { + order: number; + label: string; + visible: boolean; + expand_by_default: boolean; + }; +} +/** GuiMarkdownMessage(id: 'str', container_id: 'str', props: 'GuiMarkdownProps') + * + * (automatically generated) + */ +export interface GuiMarkdownMessage { + type: "GuiMarkdownMessage"; + id: string; + container_id: string; + props: { order: number; markdown: string; visible: boolean }; +} +/** GuiProgressBarMessage(value: 'float', id: 'str', container_id: 'str', props: 'GuiProgressBarProps') + * + * (automatically generated) + */ +export interface GuiProgressBarMessage { + type: "GuiProgressBarMessage"; + value: number; + id: string; + container_id: string; + props: { + order: number; + animated: boolean; + color: + | "dark" + | "gray" + | "red" + | "pink" + | "grape" + | "violet" + | "indigo" + | "blue" + | "cyan" + | "green" + | "lime" + | "yellow" + | "orange" + | "teal" + | null; + visible: boolean; + }; +} +/** GuiPlotlyMessage(id: 'str', container_id: 'str', props: 'GuiPlotlyProps') + * + * (automatically generated) + */ +export interface GuiPlotlyMessage { + type: "GuiPlotlyMessage"; + id: string; + container_id: string; + props: { + order: number; + plotly_json_str: string; + aspect: number; + visible: boolean; + }; +} +/** GuiTabGroupMessage(id: 'str', container_id: 'str', props: 'GuiTabGroupProps') + * + * (automatically generated) + */ +export interface GuiTabGroupMessage { + type: "GuiTabGroupMessage"; + id: string; + container_id: string; + props: { + tab_labels: string[]; + tab_icons_html: (string | null)[]; + tab_container_ids: string[]; + order: number; + visible: boolean; + }; +} /** GuiModalMessage(order: 'float', id: 'str', title: 'str') * * (automatically generated) @@ -913,6 +988,287 @@ export interface GuiCloseModalMessage { type: "GuiCloseModalMessage"; id: string; } +/** GuiButtonMessage(value: 'bool', id: 'str', container_id: 'str', props: 'GuiButtonProps') + * + * (automatically generated) + */ +export interface GuiButtonMessage { + type: "GuiButtonMessage"; + value: boolean; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + color: + | "dark" + | "gray" + | "red" + | "pink" + | "grape" + | "violet" + | "indigo" + | "blue" + | "cyan" + | "green" + | "lime" + | "yellow" + | "orange" + | "teal" + | null; + icon_html: string | null; + }; +} +/** GuiUploadButtonMessage(id: 'str', container_id: 'str', props: 'GuiUploadButtonProps') + * + * (automatically generated) + */ +export interface GuiUploadButtonMessage { + type: "GuiUploadButtonMessage"; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + color: + | "dark" + | "gray" + | "red" + | "pink" + | "grape" + | "violet" + | "indigo" + | "blue" + | "cyan" + | "green" + | "lime" + | "yellow" + | "orange" + | "teal" + | null; + icon_html: string | null; + mime_type: string; + }; +} +/** GuiSliderMessage(value: 'float', id: 'str', container_id: 'str', props: 'GuiSliderProps') + * + * (automatically generated) + */ +export interface GuiSliderMessage { + type: "GuiSliderMessage"; + value: number; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + min: number; + max: number; + step: number | null; + precision: number; + marks: { value: number; label: string | null }[] | null; + }; +} +/** GuiMultiSliderMessage(value: 'tuple[float, ...]', id: 'str', container_id: 'str', props: 'GuiMultiSliderProps') + * + * (automatically generated) + */ +export interface GuiMultiSliderMessage { + type: "GuiMultiSliderMessage"; + value: number[]; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + min: number; + max: number; + step: number | null; + min_range: number | null; + precision: number; + fixed_endpoints: boolean; + marks: { value: number; label: string | null }[] | null; + }; +} +/** GuiNumberMessage(value: 'float', id: 'str', container_id: 'str', props: 'GuiNumberProps') + * + * (automatically generated) + */ +export interface GuiNumberMessage { + type: "GuiNumberMessage"; + value: number; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + precision: number; + step: number; + min: number | null; + max: number | null; + }; +} +/** GuiRgbMessage(value: 'Tuple[int, int, int]', id: 'str', container_id: 'str', props: 'GuiRgbProps') + * + * (automatically generated) + */ +export interface GuiRgbMessage { + type: "GuiRgbMessage"; + value: [number, number, number]; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + }; +} +/** GuiRgbaMessage(value: 'Tuple[int, int, int, int]', id: 'str', container_id: 'str', props: 'GuiRgbaProps') + * + * (automatically generated) + */ +export interface GuiRgbaMessage { + type: "GuiRgbaMessage"; + value: [number, number, number, number]; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + }; +} +/** GuiCheckboxMessage(value: 'bool', id: 'str', container_id: 'str', props: 'GuiCheckboxProps') + * + * (automatically generated) + */ +export interface GuiCheckboxMessage { + type: "GuiCheckboxMessage"; + value: boolean; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + }; +} +/** GuiVector2Message(value: 'Tuple[float, float]', id: 'str', container_id: 'str', props: 'GuiVector2Props') + * + * (automatically generated) + */ +export interface GuiVector2Message { + type: "GuiVector2Message"; + value: [number, number]; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + min: [number, number] | null; + max: [number, number] | null; + step: number; + precision: number; + }; +} +/** GuiVector3Message(value: 'Tuple[float, float, float]', id: 'str', container_id: 'str', props: 'GuiVector3Props') + * + * (automatically generated) + */ +export interface GuiVector3Message { + type: "GuiVector3Message"; + value: [number, number, number]; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + min: [number, number, number] | null; + max: [number, number, number] | null; + step: number; + precision: number; + }; +} +/** GuiTextMessage(value: 'str', id: 'str', container_id: 'str', props: 'GuiTextProps') + * + * (automatically generated) + */ +export interface GuiTextMessage { + type: "GuiTextMessage"; + value: string; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + }; +} +/** GuiDropdownMessage(value: 'str', id: 'str', container_id: 'str', props: 'GuiDropdownProps') + * + * (automatically generated) + */ +export interface GuiDropdownMessage { + type: "GuiDropdownMessage"; + value: string; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + options: string[]; + }; +} +/** GuiButtonGroupMessage(value: 'str', id: 'str', container_id: 'str', props: 'GuiButtonGroupProps') + * + * (automatically generated) + */ +export interface GuiButtonGroupMessage { + type: "GuiButtonGroupMessage"; + value: string; + id: string; + container_id: string; + props: { + order: number; + label: string; + hint: string | null; + visible: boolean; + disabled: boolean; + options: string[]; + }; +} /** Sent server->client to remove a GUI element. * * (automatically generated) @@ -928,7 +1284,7 @@ export interface GuiRemoveMessage { export interface GuiUpdateMessage { type: "GuiUpdateMessage"; id: string; - updates: Partial; + updates: Partial; } /** Sent client<->server when any property of a scene node is changed. * @@ -1149,27 +1505,48 @@ export type Message = | SceneNodeClickMessage | ResetSceneMessage | ResetGuiMessage - | GuiAddFolderMessage - | GuiAddMarkdownMessage - | GuiAddProgressBarMessage - | GuiAddPlotlyMessage - | GuiAddTabGroupMessage - | _GuiAddInputBase - | GuiAddButtonMessage - | GuiAddUploadButtonMessage - | GuiAddSliderMessage - | GuiAddMultiSliderMessage - | GuiAddNumberMessage - | GuiAddRgbMessage - | GuiAddRgbaMessage - | GuiAddCheckboxMessage - | GuiAddVector2Message - | GuiAddVector3Message - | GuiAddTextMessage - | GuiAddDropdownMessage - | GuiAddButtonGroupMessage + | GuiBaseProps + | GuiButtonProps + | GuiUploadButtonProps + | GuiSliderProps + | GuiMultiSliderProps + | GuiNumberProps + | GuiRgbProps + | GuiRgbaProps + | GuiCheckboxProps + | GuiVector2Props + | GuiVector3Props + | GuiTextProps + | GuiDropdownProps + | GuiDropdownHandle + | GuiButtonGroupProps + | _GuiInputHandle + | GuiInputHandle + | GuiDropdownHandle + | GuiButtonHandle + | GuiUploadButtonHandle + | GuiButtonGroupHandle + | GuiProgressBarHandle + | GuiFolderMessage + | GuiMarkdownMessage + | GuiProgressBarMessage + | GuiPlotlyMessage + | GuiTabGroupMessage | GuiModalMessage | GuiCloseModalMessage + | GuiButtonMessage + | GuiUploadButtonMessage + | GuiSliderMessage + | GuiMultiSliderMessage + | GuiNumberMessage + | GuiRgbMessage + | GuiRgbaMessage + | GuiCheckboxMessage + | GuiVector2Message + | GuiVector3Message + | GuiTextMessage + | GuiDropdownMessage + | GuiButtonGroupMessage | GuiRemoveMessage | GuiUpdateMessage | SceneNodeUpdateMessage @@ -1208,25 +1585,25 @@ export type SceneNodeMessage = | CatmullRomSplineMessage | CubicBezierSplineMessage | GaussianSplatsMessage; -export type GuiAddComponentMessage = - | GuiAddFolderMessage - | GuiAddMarkdownMessage - | GuiAddProgressBarMessage - | GuiAddPlotlyMessage - | GuiAddTabGroupMessage - | GuiAddButtonMessage - | GuiAddUploadButtonMessage - | GuiAddSliderMessage - | GuiAddMultiSliderMessage - | GuiAddNumberMessage - | GuiAddRgbMessage - | GuiAddRgbaMessage - | GuiAddCheckboxMessage - | GuiAddVector2Message - | GuiAddVector3Message - | GuiAddTextMessage - | GuiAddDropdownMessage - | GuiAddButtonGroupMessage; +export type GuiComponentMessage = + | GuiFolderMessage + | GuiMarkdownMessage + | GuiProgressBarMessage + | GuiPlotlyMessage + | GuiTabGroupMessage + | GuiButtonMessage + | GuiUploadButtonMessage + | GuiSliderMessage + | GuiMultiSliderMessage + | GuiNumberMessage + | GuiRgbMessage + | GuiRgbaMessage + | GuiCheckboxMessage + | GuiVector2Message + | GuiVector3Message + | GuiTextMessage + | GuiDropdownMessage + | GuiButtonGroupMessage; const typeSetSceneNodeMessage = new Set([ "CameraFrustumMessage", "GlbMessage", @@ -1255,28 +1632,28 @@ export function isSceneNodeMessage( ): message is SceneNodeMessage { return typeSetSceneNodeMessage.has(message.type); } -const typeSetGuiAddComponentMessage = new Set([ - "GuiAddFolderMessage", - "GuiAddMarkdownMessage", - "GuiAddProgressBarMessage", - "GuiAddPlotlyMessage", - "GuiAddTabGroupMessage", - "GuiAddButtonMessage", - "GuiAddUploadButtonMessage", - "GuiAddSliderMessage", - "GuiAddMultiSliderMessage", - "GuiAddNumberMessage", - "GuiAddRgbMessage", - "GuiAddRgbaMessage", - "GuiAddCheckboxMessage", - "GuiAddVector2Message", - "GuiAddVector3Message", - "GuiAddTextMessage", - "GuiAddDropdownMessage", - "GuiAddButtonGroupMessage", +const typeSetGuiComponentMessage = new Set([ + "GuiFolderMessage", + "GuiMarkdownMessage", + "GuiProgressBarMessage", + "GuiPlotlyMessage", + "GuiTabGroupMessage", + "GuiButtonMessage", + "GuiUploadButtonMessage", + "GuiSliderMessage", + "GuiMultiSliderMessage", + "GuiNumberMessage", + "GuiRgbMessage", + "GuiRgbaMessage", + "GuiCheckboxMessage", + "GuiVector2Message", + "GuiVector3Message", + "GuiTextMessage", + "GuiDropdownMessage", + "GuiButtonGroupMessage", ]); -export function isGuiAddComponentMessage( +export function isGuiComponentMessage( message: Message, -): message is GuiAddComponentMessage { - return typeSetGuiAddComponentMessage.has(message.type); +): message is GuiComponentMessage { + return typeSetGuiComponentMessage.has(message.type); } diff --git a/src/viser/client/src/components/Button.tsx b/src/viser/client/src/components/Button.tsx index 88bc7de60..3e37aa120 100644 --- a/src/viser/client/src/components/Button.tsx +++ b/src/viser/client/src/components/Button.tsx @@ -1,4 +1,4 @@ -import { GuiAddButtonMessage } from "../WebsocketMessages"; +import { GuiButtonMessage } from "../WebsocketMessages"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { Box } from "@mantine/core"; @@ -8,11 +8,8 @@ import { htmlIconWrapper } from "./ComponentStyles.css"; export default function ButtonComponent({ id, - visible, - disabled, - label, - ...otherProps -}: GuiAddButtonMessage) { + props: { visible, disabled, label, ...otherProps }, +}: GuiButtonMessage) { const { messageSender } = React.useContext(GuiComponentContext)!; const { color, icon_html } = otherProps; if (!(visible ?? true)) return <>; diff --git a/src/viser/client/src/components/ButtonGroup.tsx b/src/viser/client/src/components/ButtonGroup.tsx index e605b3d13..434d44e32 100644 --- a/src/viser/client/src/components/ButtonGroup.tsx +++ b/src/viser/client/src/components/ButtonGroup.tsx @@ -1,17 +1,13 @@ import * as React from "react"; import { Button, Flex } from "@mantine/core"; import { ViserInputComponent } from "./common"; -import { GuiAddButtonGroupMessage } from "../WebsocketMessages"; +import { GuiButtonGroupMessage } from "../WebsocketMessages"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; export default function ButtonGroupComponent({ id, - hint, - label, - visible, - disabled, - options, -}: GuiAddButtonGroupMessage) { + props: { hint, label, visible, disabled, options }, +}: GuiButtonGroupMessage) { const { messageSender } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/Checkbox.tsx b/src/viser/client/src/components/Checkbox.tsx index 57be70c8c..805943af6 100644 --- a/src/viser/client/src/components/Checkbox.tsx +++ b/src/viser/client/src/components/Checkbox.tsx @@ -1,17 +1,14 @@ import * as React from "react"; import { ViserInputComponent } from "./common"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; -import { GuiAddCheckboxMessage } from "../WebsocketMessages"; +import { GuiCheckboxMessage } from "../WebsocketMessages"; import { Box, Checkbox, Tooltip } from "@mantine/core"; export default function CheckboxComponent({ id, - disabled, - visible, - hint, - label, value, -}: GuiAddCheckboxMessage) { + props: { disabled, visible, hint, label }, +}: GuiCheckboxMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; let input = ( diff --git a/src/viser/client/src/components/Dropdown.tsx b/src/viser/client/src/components/Dropdown.tsx index 9a027ac6d..b0d9faf41 100644 --- a/src/viser/client/src/components/Dropdown.tsx +++ b/src/viser/client/src/components/Dropdown.tsx @@ -1,18 +1,14 @@ import * as React from "react"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { ViserInputComponent } from "./common"; -import { GuiAddDropdownMessage } from "../WebsocketMessages"; +import { GuiDropdownMessage } from "../WebsocketMessages"; import { Select } from "@mantine/core"; export default function DropdownComponent({ id, - hint, - label, value, - disabled, - visible, - options, -}: GuiAddDropdownMessage) { + props: { hint, label, disabled, visible, options }, +}: GuiDropdownMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/Folder.tsx b/src/viser/client/src/components/Folder.tsx index 69887b880..25811181b 100644 --- a/src/viser/client/src/components/Folder.tsx +++ b/src/viser/client/src/components/Folder.tsx @@ -1,6 +1,6 @@ import * as React from "react"; import { useDisclosure } from "@mantine/hooks"; -import { GuiAddFolderMessage } from "../WebsocketMessages"; +import { GuiFolderMessage } from "../WebsocketMessages"; import { IconChevronDown, IconChevronUp } from "@tabler/icons-react"; import { Box, Collapse, Paper } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; @@ -9,10 +9,8 @@ import { folderLabel, folderToggleIcon, folderWrapper } from "./Folder.css"; export default function FolderComponent({ id, - label, - visible, - expand_by_default, -}: GuiAddFolderMessage) { + props: { label, visible, expand_by_default }, +}: GuiFolderMessage) { const viewer = React.useContext(ViewerContext)!; const [opened, { toggle }] = useDisclosure(expand_by_default); const guiIdSet = viewer.useGui((state) => state.guiIdSetFromContainerId[id]); diff --git a/src/viser/client/src/components/Markdown.tsx b/src/viser/client/src/components/Markdown.tsx index bdb4d2739..c48063f9d 100644 --- a/src/viser/client/src/components/Markdown.tsx +++ b/src/viser/client/src/components/Markdown.tsx @@ -1,12 +1,11 @@ import { Box, Text } from "@mantine/core"; import Markdown from "../Markdown"; import { ErrorBoundary } from "react-error-boundary"; -import { GuiAddMarkdownMessage } from "../WebsocketMessages"; +import { GuiMarkdownMessage } from "../WebsocketMessages"; export default function MarkdownComponent({ - visible, - markdown, -}: GuiAddMarkdownMessage) { + props: { visible, markdown }, +}: GuiMarkdownMessage) { if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/MultiSlider.tsx b/src/viser/client/src/components/MultiSlider.tsx index 7f4cf1f6f..e90ee9012 100644 --- a/src/viser/client/src/components/MultiSlider.tsx +++ b/src/viser/client/src/components/MultiSlider.tsx @@ -1,5 +1,5 @@ import React from "react"; -import { GuiAddMultiSliderMessage } from "../WebsocketMessages"; +import { GuiMultiSliderMessage } from "../WebsocketMessages"; import { Box, useMantineColorScheme } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { ViserInputComponent } from "./common"; @@ -8,18 +8,24 @@ import { sliderDefaultMarks } from "./ComponentStyles.css"; export default function MultiSliderComponent({ id, - label, - hint, - visible, - disabled, value, - ...otherProps -}: GuiAddMultiSliderMessage) { + props: { + label, + hint, + visible, + disabled, + min, + max, + precision, + step, + marks, + fixed_endpoints, + min_range, + }, +}: GuiMultiSliderMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; const updateValue = (value: number[]) => setValue(id, value); - const { min, max, precision, step, marks, fixed_endpoints, min_range } = - otherProps; const colorScheme = useMantineColorScheme().colorScheme; const input = ( diff --git a/src/viser/client/src/components/NumberInput.tsx b/src/viser/client/src/components/NumberInput.tsx index ecf65d744..d2b312665 100644 --- a/src/viser/client/src/components/NumberInput.tsx +++ b/src/viser/client/src/components/NumberInput.tsx @@ -1,20 +1,15 @@ import * as React from "react"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; -import { GuiAddNumberMessage } from "../WebsocketMessages"; +import { GuiNumberMessage } from "../WebsocketMessages"; import { ViserInputComponent } from "./common"; import { NumberInput } from "@mantine/core"; export default function NumberInputComponent({ - visible, id, - label, - hint, value, - disabled, - ...otherProps -}: GuiAddNumberMessage) { + props: { visible, label, hint, disabled, precision, min, max, step }, +}: GuiNumberMessage) { const { setValue } = React.useContext(GuiComponentContext)!; - const { precision, min, max, step } = otherProps; if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/PlotlyComponent.tsx b/src/viser/client/src/components/PlotlyComponent.tsx index 39cec0071..455486ac1 100644 --- a/src/viser/client/src/components/PlotlyComponent.tsx +++ b/src/viser/client/src/components/PlotlyComponent.tsx @@ -1,5 +1,5 @@ import React from "react"; -import { GuiAddPlotlyMessage } from "../WebsocketMessages"; +import { GuiPlotlyMessage } from "../WebsocketMessages"; import { useDisclosure } from "@mantine/hooks"; import { Modal, Box, Paper, Tooltip } from "@mantine/core"; import { useElementSize } from "@mantine/hooks"; @@ -85,7 +85,7 @@ export default function PlotlyComponent({ visible, plotly_json_str, aspect, -}: GuiAddPlotlyMessage) { +}: GuiPlotlyMessage) { if (!visible) return <>; // Create a modal with the plot, and a button to open it. diff --git a/src/viser/client/src/components/ProgressBar.tsx b/src/viser/client/src/components/ProgressBar.tsx index 0b27db29d..944f4ecf6 100644 --- a/src/viser/client/src/components/ProgressBar.tsx +++ b/src/viser/client/src/components/ProgressBar.tsx @@ -1,12 +1,12 @@ import { Box, Progress } from "@mantine/core"; -import { GuiAddProgressBarMessage } from "../WebsocketMessages"; +import { GuiProgressBarMessage } from "../WebsocketMessages"; export default function ProgressBarComponent({ visible, color, value, animated, -}: GuiAddProgressBarMessage) { +}: GuiProgressBarMessage) { if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/Rgb.tsx b/src/viser/client/src/components/Rgb.tsx index 53d2760c5..5f1ff8f0c 100644 --- a/src/viser/client/src/components/Rgb.tsx +++ b/src/viser/client/src/components/Rgb.tsx @@ -3,16 +3,13 @@ import { ColorInput } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { rgbToHex, hexToRgb } from "./utils"; import { ViserInputComponent } from "./common"; -import { GuiAddRgbMessage } from "../WebsocketMessages"; +import { GuiRgbMessage } from "../WebsocketMessages"; export default function RgbComponent({ id, - label, - hint, value, - disabled, - visible, -}: GuiAddRgbMessage) { + props: { label, hint, disabled, visible }, +}: GuiRgbMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/Rgba.tsx b/src/viser/client/src/components/Rgba.tsx index 2c5e073ce..59922426c 100644 --- a/src/viser/client/src/components/Rgba.tsx +++ b/src/viser/client/src/components/Rgba.tsx @@ -3,16 +3,13 @@ import { ColorInput } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { rgbaToHex, hexToRgba } from "./utils"; import { ViserInputComponent } from "./common"; -import { GuiAddRgbaMessage } from "../WebsocketMessages"; +import { GuiRgbaMessage } from "../WebsocketMessages"; export default function RgbaComponent({ id, - label, - hint, value, - disabled, - visible, -}: GuiAddRgbaMessage) { + props: { label, hint, disabled, visible }, +}: GuiRgbaMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/Slider.tsx b/src/viser/client/src/components/Slider.tsx index 648073a9d..b6901dbd1 100644 --- a/src/viser/client/src/components/Slider.tsx +++ b/src/viser/client/src/components/Slider.tsx @@ -1,5 +1,5 @@ import React from "react"; -import { GuiAddSliderMessage } from "../WebsocketMessages"; +import { GuiSliderMessage } from "../WebsocketMessages"; import { Slider, Flex, @@ -12,17 +12,12 @@ import { sliderDefaultMarks } from "./ComponentStyles.css"; export default function SliderComponent({ id, - label, - hint, - visible, - disabled, value, - ...otherProps -}: GuiAddSliderMessage) { + props: { label, hint, visible, disabled, min, max, precision, step, marks }, +}: GuiSliderMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; const updateValue = (value: number) => setValue(id, value); - const { min, max, precision, step, marks } = otherProps; const colorScheme = useMantineColorScheme().colorScheme; const input = ( diff --git a/src/viser/client/src/components/TabGroup.tsx b/src/viser/client/src/components/TabGroup.tsx index 8abc1d500..863546883 100644 --- a/src/viser/client/src/components/TabGroup.tsx +++ b/src/viser/client/src/components/TabGroup.tsx @@ -1,15 +1,12 @@ import * as React from "react"; -import { GuiAddTabGroupMessage } from "../WebsocketMessages"; +import { GuiTabGroupMessage } from "../WebsocketMessages"; import { Tabs } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { htmlIconWrapper } from "./ComponentStyles.css"; export default function TabGroupComponent({ - tab_labels, - tab_icons_html, - tab_container_ids, - visible, -}: GuiAddTabGroupMessage) { + props: { tab_labels, tab_icons_html, tab_container_ids, visible }, +}: GuiTabGroupMessage) { const { GuiContainer } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/TextInput.tsx b/src/viser/client/src/components/TextInput.tsx index 1d4002b02..b18477ff7 100644 --- a/src/viser/client/src/components/TextInput.tsx +++ b/src/viser/client/src/components/TextInput.tsx @@ -1,11 +1,14 @@ import * as React from "react"; import { TextInput } from "@mantine/core"; import { ViserInputComponent } from "./common"; -import { GuiAddTextMessage } from "../WebsocketMessages"; +import { GuiTextMessage } from "../WebsocketMessages"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; -export default function TextInputComponent(props: GuiAddTextMessage) { - const { id, hint, label, value, disabled, visible } = props; +export default function TextInputComponent({ + id, + value, + props: { hint, label, disabled, visible }, +}: GuiTextMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/UploadButton.tsx b/src/viser/client/src/components/UploadButton.tsx index 23da1f010..34feb3390 100644 --- a/src/viser/client/src/components/UploadButton.tsx +++ b/src/viser/client/src/components/UploadButton.tsx @@ -1,4 +1,4 @@ -import { GuiAddUploadButtonMessage } from "../WebsocketMessages"; +import { GuiUploadButtonMessage } from "../WebsocketMessages"; import { v4 as uuid } from "uuid"; import { Box, Progress } from "@mantine/core"; @@ -9,24 +9,26 @@ import { IconCheck } from "@tabler/icons-react"; import { notifications } from "@mantine/notifications"; import { htmlIconWrapper } from "./ComponentStyles.css"; -export default function UploadButtonComponent(conf: GuiAddUploadButtonMessage) { +export default function UploadButtonComponent({ + id, + props: { disabled, mime_type, color, icon_html, label }, +}: GuiUploadButtonMessage) { // Handle GUI input types. const viewer = useContext(ViewerContext)!; const fileUploadRef = React.useRef(null); const { isUploading, upload } = useFileUpload({ viewer, - componentId: conf.id, + componentId: id, }); - const disabled = conf.disabled || isUploading; return ( { const input = e.target as HTMLInputElement; @@ -35,27 +37,27 @@ export default function UploadButtonComponent(conf: GuiAddUploadButtonMessage) { }} /> ); diff --git a/src/viser/client/src/components/Vector2.tsx b/src/viser/client/src/components/Vector2.tsx index 089d0dc4d..9971d3bc2 100644 --- a/src/viser/client/src/components/Vector2.tsx +++ b/src/viser/client/src/components/Vector2.tsx @@ -1,18 +1,13 @@ import * as React from "react"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; -import { GuiAddVector2Message } from "../WebsocketMessages"; +import { GuiVector2Message } from "../WebsocketMessages"; import { VectorInput, ViserInputComponent } from "./common"; export default function Vector2Component({ id, - hint, - label, - visible, - disabled, value, - ...otherProps -}: GuiAddVector2Message) { - const { min, max, step, precision } = otherProps; + props: { hint, label, visible, disabled, min, max, step, precision }, +}: GuiVector2Message) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/Vector3.tsx b/src/viser/client/src/components/Vector3.tsx index 4b20219f8..e297ca097 100644 --- a/src/viser/client/src/components/Vector3.tsx +++ b/src/viser/client/src/components/Vector3.tsx @@ -1,18 +1,13 @@ import * as React from "react"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; -import { GuiAddVector3Message } from "../WebsocketMessages"; +import { GuiVector3Message } from "../WebsocketMessages"; import { VectorInput, ViserInputComponent } from "./common"; export default function Vector3Component({ id, - hint, - label, - visible, - disabled, value, - ...otherProps -}: GuiAddVector3Message) { - const { min, max, step, precision } = otherProps; + props: { hint, label, visible, disabled, min, max, step, precision }, +}: GuiVector3Message) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; return ( diff --git a/src/viser/infra/_typescript_interface_gen.py b/src/viser/infra/_typescript_interface_gen.py index 728075f91..1a86d72b5 100644 --- a/src/viser/infra/_typescript_interface_gen.py +++ b/src/viser/infra/_typescript_interface_gen.py @@ -85,6 +85,8 @@ def _get_ts_type(typ: Type[Any]) -> str: return "{ [key: " + _get_ts_type(args[0]) + "]: " + _get_ts_type(args[1]) + " }" elif is_typeddict(typ) or dataclasses.is_dataclass(typ): hints = get_type_hints(typ) + if dataclasses.is_dataclass(typ): + hints = {field.name: hints[field.name] for field in dataclasses.fields(typ)} optional_keys = getattr(typ, "__optional_keys__", []) def fmt(key): From 0ad346f28e65b56b354fb59b18cff501578f2a85 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 00:38:02 -0700 Subject: [PATCH 03/15] Private props --- src/viser/_gui_api.py | 8 ++++---- src/viser/_gui_handles.py | 4 ++-- src/viser/_messages.py | 8 ++++---- src/viser/client/src/WebsocketMessages.ts | 16 ++++++++-------- src/viser/client/src/components/Markdown.tsx | 2 +- src/viser/client/src/components/MultiSlider.tsx | 2 +- .../client/src/components/PlotlyComponent.tsx | 4 +--- src/viser/client/src/components/Slider.tsx | 12 +++++++++++- 8 files changed, 32 insertions(+), 24 deletions(-) diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 4a2781b17..a8b380998 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -579,7 +579,7 @@ def add_markdown( container_id=self._get_container_id(), props=_messages.GuiMarkdownProps( order=_apply_default_order(order), - markdown="", + _markdown="", visible=visible, ), ) @@ -661,7 +661,7 @@ def add_plotly( container_id=self._get_container_id(), props=_messages.GuiPlotlyProps( order=_apply_default_order(order), - plotly_json_str="", + _plotly_json_str="", aspect=1.0, visible=visible, ), @@ -1334,7 +1334,7 @@ def add_slider( precision=_compute_precision_digits(step), visible=visible, disabled=disabled, - marks=tuple( + _marks=tuple( GuiSliderMark(value=float(x[0]), label=x[1]) if isinstance(x, tuple) else GuiSliderMark(value=x, label=None) @@ -1419,7 +1419,7 @@ def add_multi_slider( disabled=disabled, fixed_endpoints=fixed_endpoints, precision=_compute_precision_digits(step), - marks=tuple( + _marks=tuple( GuiSliderMark(value=float(x[0]), label=x[1]) if isinstance(x, tuple) else GuiSliderMark(value=x, label=None) diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 6037f2b4b..87cf7bf90 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -577,7 +577,7 @@ def content(self) -> str: @content.setter def content(self, content: str) -> None: self._content = content - self.markdown = _parse_markdown(content, self._image_root) + self._markdown = _parse_markdown(content, self._image_root) class GuiPlotlyHandle(_GuiHandle[None], GuiPlotlyProps): @@ -599,4 +599,4 @@ def figure(self, figure: go.Figure) -> None: json_str = figure.to_json() assert isinstance(json_str, str) - self.plotly_json_str = json_str + self._plotly_json_str = json_str diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 9cac492db..c8d0eaae1 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -806,7 +806,7 @@ class GuiFolderMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiMarkdownProps: order: float - markdown: str + _markdown: str visible: bool @@ -836,7 +836,7 @@ class GuiProgressBarMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiPlotlyProps: order: float - plotly_json_str: str + _plotly_json_str: str aspect: float visible: bool @@ -910,7 +910,7 @@ class GuiSliderProps(GuiBaseProps): max: float step: Optional[float] precision: int - marks: Optional[Tuple[GuiSliderMark, ...]] = None + _marks: Optional[Tuple[GuiSliderMark, ...]] = None @dataclasses.dataclass @@ -929,7 +929,7 @@ class GuiMultiSliderProps(GuiBaseProps): min_range: Optional[float] precision: int fixed_endpoints: bool = False - marks: Optional[Tuple[GuiSliderMark, ...]] = None + _marks: Optional[Tuple[GuiSliderMark, ...]] = None @dataclasses.dataclass diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index 498b70a5a..2a164b996 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -614,7 +614,7 @@ export interface GuiUploadButtonProps { icon_html: string | null; mime_type: string; } -/** GuiSliderProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', precision: 'int', marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) +/** GuiSliderProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', precision: 'int', _marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) * * (automatically generated) */ @@ -629,9 +629,9 @@ export interface GuiSliderProps { max: number; step: number | null; precision: number; - marks: { value: number; label: string | null }[] | null; + _marks: { value: number; label: string | null }[] | null; } -/** GuiMultiSliderProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', min_range: 'Optional[float]', precision: 'int', fixed_endpoints: 'bool' = False, marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) +/** GuiMultiSliderProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', min_range: 'Optional[float]', precision: 'int', fixed_endpoints: 'bool' = False, _marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) * * (automatically generated) */ @@ -648,7 +648,7 @@ export interface GuiMultiSliderProps { min_range: number | null; precision: number; fixed_endpoints: boolean; - marks: { value: number; label: string | null }[] | null; + _marks: { value: number; label: string | null }[] | null; } /** GuiNumberProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', precision: 'int', step: 'float', min: 'Optional[float]', max: 'Optional[float]') * @@ -906,7 +906,7 @@ export interface GuiMarkdownMessage { type: "GuiMarkdownMessage"; id: string; container_id: string; - props: { order: number; markdown: string; visible: boolean }; + props: { order: number; _markdown: string; visible: boolean }; } /** GuiProgressBarMessage(value: 'float', id: 'str', container_id: 'str', props: 'GuiProgressBarProps') * @@ -949,7 +949,7 @@ export interface GuiPlotlyMessage { container_id: string; props: { order: number; - plotly_json_str: string; + _plotly_json_str: string; aspect: number; visible: boolean; }; @@ -1075,7 +1075,7 @@ export interface GuiSliderMessage { max: number; step: number | null; precision: number; - marks: { value: number; label: string | null }[] | null; + _marks: { value: number; label: string | null }[] | null; }; } /** GuiMultiSliderMessage(value: 'tuple[float, ...]', id: 'str', container_id: 'str', props: 'GuiMultiSliderProps') @@ -1099,7 +1099,7 @@ export interface GuiMultiSliderMessage { min_range: number | null; precision: number; fixed_endpoints: boolean; - marks: { value: number; label: string | null }[] | null; + _marks: { value: number; label: string | null }[] | null; }; } /** GuiNumberMessage(value: 'float', id: 'str', container_id: 'str', props: 'GuiNumberProps') diff --git a/src/viser/client/src/components/Markdown.tsx b/src/viser/client/src/components/Markdown.tsx index c48063f9d..8097a3dc3 100644 --- a/src/viser/client/src/components/Markdown.tsx +++ b/src/viser/client/src/components/Markdown.tsx @@ -4,7 +4,7 @@ import { ErrorBoundary } from "react-error-boundary"; import { GuiMarkdownMessage } from "../WebsocketMessages"; export default function MarkdownComponent({ - props: { visible, markdown }, + props: { visible, _markdown: markdown }, }: GuiMarkdownMessage) { if (!visible) return <>; return ( diff --git a/src/viser/client/src/components/MultiSlider.tsx b/src/viser/client/src/components/MultiSlider.tsx index e90ee9012..b472c6cbf 100644 --- a/src/viser/client/src/components/MultiSlider.tsx +++ b/src/viser/client/src/components/MultiSlider.tsx @@ -18,7 +18,7 @@ export default function MultiSliderComponent({ max, precision, step, - marks, + _marks: marks, fixed_endpoints, min_range, }, diff --git a/src/viser/client/src/components/PlotlyComponent.tsx b/src/viser/client/src/components/PlotlyComponent.tsx index 455486ac1..983a4b230 100644 --- a/src/viser/client/src/components/PlotlyComponent.tsx +++ b/src/viser/client/src/components/PlotlyComponent.tsx @@ -82,9 +82,7 @@ const PlotWithAspect = React.memo(function PlotWithAspect({ }); export default function PlotlyComponent({ - visible, - plotly_json_str, - aspect, + props: { visible, _plotly_json_str: plotly_json_str, aspect }, }: GuiPlotlyMessage) { if (!visible) return <>; diff --git a/src/viser/client/src/components/Slider.tsx b/src/viser/client/src/components/Slider.tsx index b6901dbd1..c730e97cc 100644 --- a/src/viser/client/src/components/Slider.tsx +++ b/src/viser/client/src/components/Slider.tsx @@ -13,7 +13,17 @@ import { sliderDefaultMarks } from "./ComponentStyles.css"; export default function SliderComponent({ id, value, - props: { label, hint, visible, disabled, min, max, precision, step, marks }, + props: { + label, + hint, + visible, + disabled, + min, + max, + precision, + step, + _marks: marks, + }, }: GuiSliderMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>; From a2c3ccf044fff5190654f8c165811428585d61e9 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 00:50:54 -0700 Subject: [PATCH 04/15] More handles --- docs/source/gui_handles.md | 20 ++ src/viser/__init__.py | 9 + src/viser/_gui_api.py | 374 ++++++++++++++++++++----------------- src/viser/_gui_handles.py | 40 ++++ 4 files changed, 268 insertions(+), 175 deletions(-) diff --git a/docs/source/gui_handles.md b/docs/source/gui_handles.md index 8d5b3669d..16d33f2c0 100644 --- a/docs/source/gui_handles.md +++ b/docs/source/gui_handles.md @@ -20,4 +20,24 @@ .. autoclass:: viser.GuiTabHandle() +.. autoclass:: viser.GuiCheckboxHandle() + +.. autoclass:: viser.GuiEvent() + +.. autoclass:: viser.GuiMultiSliderHandle() + +.. autoclass:: viser.GuiNumberHandle() + +.. autoclass:: viser.GuiRgbaHandle() + +.. autoclass:: viser.GuiRgbHandle() + +.. autoclass:: viser.GuiSliderHandle() + +.. autoclass:: viser.GuiTextHandle() + +.. autoclass:: viser.GuiVector2Handle() + +.. autoclass:: viser.GuiVector3Handle() + diff --git a/src/viser/__init__.py b/src/viser/__init__.py index 99a6f38a3..168da8e99 100644 --- a/src/viser/__init__.py +++ b/src/viser/__init__.py @@ -1,14 +1,23 @@ from ._gui_api import GuiApi as GuiApi from ._gui_handles import GuiButtonGroupHandle as GuiButtonGroupHandle from ._gui_handles import GuiButtonHandle as GuiButtonHandle +from ._gui_handles import GuiCheckboxHandle as GuiCheckboxHandle from ._gui_handles import GuiDropdownHandle as GuiDropdownHandle from ._gui_handles import GuiEvent as GuiEvent from ._gui_handles import GuiFolderHandle as GuiFolderHandle from ._gui_handles import GuiInputHandle as GuiInputHandle from ._gui_handles import GuiMarkdownHandle as GuiMarkdownHandle +from ._gui_handles import GuiMultiSliderHandle as GuiMultiSliderHandle +from ._gui_handles import GuiNumberHandle as GuiNumberHandle from ._gui_handles import GuiPlotlyHandle as GuiPlotlyHandle +from ._gui_handles import GuiRgbaHandle as GuiRgbaHandle +from ._gui_handles import GuiRgbHandle as GuiRgbHandle +from ._gui_handles import GuiSliderHandle as GuiSliderHandle from ._gui_handles import GuiTabGroupHandle as GuiTabGroupHandle from ._gui_handles import GuiTabHandle as GuiTabHandle +from ._gui_handles import GuiTextHandle as GuiTextHandle +from ._gui_handles import GuiVector2Handle as GuiVector2Handle +from ._gui_handles import GuiVector3Handle as GuiVector3Handle from ._icons_enum import Icon as Icon from ._icons_enum import IconName as IconName from ._notification_handle import NotificationHandle as NotificationHandle diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index a8b380998..712f6ba4f 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -34,17 +34,25 @@ from ._gui_handles import ( GuiButtonGroupHandle, GuiButtonHandle, + GuiCheckboxHandle, GuiContainerProtocol, GuiDropdownHandle, GuiEvent, GuiFolderHandle, - GuiInputHandle, GuiMarkdownHandle, GuiModalHandle, + GuiMultiSliderHandle, + GuiNumberHandle, GuiPlotlyHandle, GuiProgressBarHandle, + GuiRgbaHandle, + GuiRgbHandle, + GuiSliderHandle, GuiTabGroupHandle, + GuiTextHandle, GuiUploadButtonHandle, + GuiVector2Handle, + GuiVector3Handle, SupportsRemoveProtocol, UploadedFile, _GuiHandleState, @@ -735,7 +743,7 @@ def add_button( ), ), is_button=True, - )._impl + ) ) def add_upload_button( @@ -787,7 +795,7 @@ def add_upload_button( ), ), is_button=True, - )._impl + ) ) # The TLiteralString overload tells pyright to resolve the value type to a Literal @@ -858,7 +866,7 @@ def add_button_group( visible=visible, ), ), - )._impl, + ) ) def add_checkbox( @@ -869,7 +877,7 @@ def add_checkbox( visible: bool = True, hint: str | None = None, order: float | None = None, - ) -> GuiInputHandle[bool]: + ) -> GuiCheckboxHandle: """Add a checkbox to the GUI. Args: @@ -887,20 +895,22 @@ def add_checkbox( assert isinstance(value, bool) id = _make_unique_id() order = _apply_default_order(order) - return self._create_gui_input( - value, - message=_messages.GuiCheckboxMessage( - value=value, - id=id, - container_id=self._get_container_id(), - props=_messages.GuiCheckboxProps( - order=order, - label=label, - hint=hint, - disabled=disabled, - visible=visible, + return GuiCheckboxHandle( + self._create_gui_input( + value, + message=_messages.GuiCheckboxMessage( + value=value, + id=id, + container_id=self._get_container_id(), + props=_messages.GuiCheckboxProps( + order=order, + label=label, + hint=hint, + disabled=disabled, + visible=visible, + ), ), - ), + ) ) def add_text( @@ -911,7 +921,7 @@ def add_text( visible: bool = True, hint: str | None = None, order: float | None = None, - ) -> GuiInputHandle[str]: + ) -> GuiTextHandle: """Add a text input to the GUI. Args: @@ -929,20 +939,22 @@ def add_text( assert isinstance(value, str) id = _make_unique_id() order = _apply_default_order(order) - return self._create_gui_input( - value, - message=_messages.GuiTextMessage( - value=value, - id=id, - container_id=self._get_container_id(), - props=_messages.GuiTextProps( - order=order, - label=label, - hint=hint, - disabled=disabled, - visible=visible, + return GuiTextHandle( + self._create_gui_input( + value, + message=_messages.GuiTextMessage( + value=value, + id=id, + container_id=self._get_container_id(), + props=_messages.GuiTextProps( + order=order, + label=label, + hint=hint, + disabled=disabled, + visible=visible, + ), ), - ), + ) ) def add_number( @@ -956,7 +968,7 @@ def add_number( visible: bool = True, hint: str | None = None, order: float | None = None, - ) -> GuiInputHandle[IntOrFloat]: + ) -> GuiNumberHandle[IntOrFloat]: """Add a number input to the GUI, with user-specifiable bound and precision parameters. Args: @@ -995,25 +1007,27 @@ def add_number( id = _make_unique_id() order = _apply_default_order(order) - return self._create_gui_input( - value, - message=_messages.GuiNumberMessage( - value=value, - id=id, - container_id=self._get_container_id(), - props=_messages.GuiNumberProps( - order=order, - label=label, - hint=hint, - min=min, - max=max, - precision=_compute_precision_digits(step), - step=step, - disabled=disabled, - visible=visible, + return GuiNumberHandle( + self._create_gui_input( + value, + message=_messages.GuiNumberMessage( + value=value, + id=id, + container_id=self._get_container_id(), + props=_messages.GuiNumberProps( + order=order, + label=label, + hint=hint, + min=min, + max=max, + precision=_compute_precision_digits(step), + step=step, + disabled=disabled, + visible=visible, + ), ), - ), - is_button=False, + is_button=False, + ) ) def add_vector2( @@ -1027,7 +1041,7 @@ def add_vector2( visible: bool = True, hint: str | None = None, order: float | None = None, - ) -> GuiInputHandle[tuple[float, float]]: + ) -> GuiVector2Handle: """Add a length-2 vector input to the GUI. Args: @@ -1060,24 +1074,26 @@ def add_vector2( possible_steps.extend([_compute_step(x) for x in max]) step = float(onp.min(possible_steps)) - return self._create_gui_input( - value, - message=_messages.GuiVector2Message( - value=value, - id=id, - container_id=self._get_container_id(), - props=_messages.GuiVector2Props( - order=order, - label=label, - hint=hint, - min=min, - max=max, - step=step, - precision=_compute_precision_digits(step), - disabled=disabled, - visible=visible, + return GuiVector2Handle( + self._create_gui_input( + value, + message=_messages.GuiVector2Message( + value=value, + id=id, + container_id=self._get_container_id(), + props=_messages.GuiVector2Props( + order=order, + label=label, + hint=hint, + min=min, + max=max, + step=step, + precision=_compute_precision_digits(step), + disabled=disabled, + visible=visible, + ), ), - ), + ) ) def add_vector3( @@ -1091,7 +1107,7 @@ def add_vector3( visible: bool = True, hint: str | None = None, order: float | None = None, - ) -> GuiInputHandle[tuple[float, float, float]]: + ) -> GuiVector3Handle: """Add a length-3 vector input to the GUI. Args: @@ -1124,24 +1140,26 @@ def add_vector3( possible_steps.extend([_compute_step(x) for x in max]) step = float(onp.min(possible_steps)) - return self._create_gui_input( - value, - message=_messages.GuiVector3Message( - value=value, - id=id, - container_id=self._get_container_id(), - props=_messages.GuiVector3Props( - order=order, - label=label, - hint=hint, - min=min, - max=max, - step=step, - precision=_compute_precision_digits(step), - disabled=disabled, - visible=visible, + return GuiVector3Handle( + self._create_gui_input( + value, + message=_messages.GuiVector3Message( + value=value, + id=id, + container_id=self._get_container_id(), + props=_messages.GuiVector3Props( + order=order, + label=label, + hint=hint, + min=min, + max=max, + step=step, + precision=_compute_precision_digits(step), + disabled=disabled, + visible=visible, + ), ), - ), + ) ) # See add_dropdown for notes on overloads. @@ -1214,7 +1232,7 @@ def add_dropdown( visible=visible, ), ), - )._impl, + ), ) def add_progress_bar( @@ -1277,7 +1295,7 @@ def add_slider( visible: bool = True, hint: str | None = None, order: float | None = None, - ) -> GuiInputHandle[IntOrFloat]: + ) -> GuiSliderHandle[IntOrFloat]: """Add a slider to the GUI. Types of the min, max, step, and initial value should match. Args: @@ -1318,33 +1336,35 @@ def add_slider( id = _make_unique_id() order = _apply_default_order(order) - return self._create_gui_input( - value, - message=_messages.GuiSliderMessage( - value=value, - id=id, - container_id=self._get_container_id(), - props=_messages.GuiSliderProps( - order=order, - label=label, - hint=hint, - min=min, - max=max, - step=step, - precision=_compute_precision_digits(step), - visible=visible, - disabled=disabled, - _marks=tuple( - GuiSliderMark(value=float(x[0]), label=x[1]) - if isinstance(x, tuple) - else GuiSliderMark(value=x, label=None) - for x in marks - ) - if marks is not None - else None, + return GuiSliderHandle( + self._create_gui_input( + value, + message=_messages.GuiSliderMessage( + value=value, + id=id, + container_id=self._get_container_id(), + props=_messages.GuiSliderProps( + order=order, + label=label, + hint=hint, + min=min, + max=max, + step=step, + precision=_compute_precision_digits(step), + visible=visible, + disabled=disabled, + _marks=tuple( + GuiSliderMark(value=float(x[0]), label=x[1]) + if isinstance(x, tuple) + else GuiSliderMark(value=x, label=None) + for x in marks + ) + if marks is not None + else None, + ), ), - ), - is_button=False, + is_button=False, + ) ) def add_multi_slider( @@ -1361,7 +1381,7 @@ def add_multi_slider( visible: bool = True, hint: str | None = None, order: float | None = None, - ) -> GuiInputHandle[tuple[IntOrFloat, ...]]: + ) -> GuiMultiSliderHandle[tuple[IntOrFloat, ...]]: """Add a multi slider to the GUI. Types of the min, max, step, and initial value should match. Args: @@ -1401,35 +1421,37 @@ def add_multi_slider( id = _make_unique_id() order = _apply_default_order(order) - return self._create_gui_input( - value=initial_value, - message=_messages.GuiMultiSliderMessage( + return GuiMultiSliderHandle( + self._create_gui_input( value=initial_value, - id=id, - container_id=self._get_container_id(), - props=_messages.GuiMultiSliderProps( - order=order, - label=label, - hint=hint, - min=min, - min_range=min_range, - max=max, - step=step, - visible=visible, - disabled=disabled, - fixed_endpoints=fixed_endpoints, - precision=_compute_precision_digits(step), - _marks=tuple( - GuiSliderMark(value=float(x[0]), label=x[1]) - if isinstance(x, tuple) - else GuiSliderMark(value=x, label=None) - for x in marks - ) - if marks is not None - else None, + message=_messages.GuiMultiSliderMessage( + value=initial_value, + id=id, + container_id=self._get_container_id(), + props=_messages.GuiMultiSliderProps( + order=order, + label=label, + hint=hint, + min=min, + min_range=min_range, + max=max, + step=step, + visible=visible, + disabled=disabled, + fixed_endpoints=fixed_endpoints, + precision=_compute_precision_digits(step), + _marks=tuple( + GuiSliderMark(value=float(x[0]), label=x[1]) + if isinstance(x, tuple) + else GuiSliderMark(value=x, label=None) + for x in marks + ) + if marks is not None + else None, + ), ), - ), - is_button=False, + is_button=False, + ) ) def add_rgb( @@ -1440,7 +1462,7 @@ def add_rgb( visible: bool = True, hint: str | None = None, order: float | None = None, - ) -> GuiInputHandle[tuple[int, int, int]]: + ) -> GuiRgbHandle: """Add an RGB picker to the GUI. Args: @@ -1458,20 +1480,22 @@ def add_rgb( value = initial_value id = _make_unique_id() order = _apply_default_order(order) - return self._create_gui_input( - value, - message=_messages.GuiRgbMessage( - value=value, - id=id, - container_id=self._get_container_id(), - props=_messages.GuiRgbProps( - order=order, - label=label, - hint=hint, - disabled=disabled, - visible=visible, + return GuiRgbHandle( + self._create_gui_input( + value, + message=_messages.GuiRgbMessage( + value=value, + id=id, + container_id=self._get_container_id(), + props=_messages.GuiRgbProps( + order=order, + label=label, + hint=hint, + disabled=disabled, + visible=visible, + ), ), - ), + ) ) def add_rgba( @@ -1482,7 +1506,7 @@ def add_rgba( visible: bool = True, hint: str | None = None, order: float | None = None, - ) -> GuiInputHandle[tuple[int, int, int, int]]: + ) -> GuiRgbaHandle: """Add an RGBA picker to the GUI. Args: @@ -1499,20 +1523,22 @@ def add_rgba( value = initial_value id = _make_unique_id() order = _apply_default_order(order) - return self._create_gui_input( - value, - message=_messages.GuiRgbaMessage( - value=value, - id=id, - container_id=self._get_container_id(), - props=_messages.GuiRgbaProps( - order=order, - label=label, - hint=hint, - disabled=disabled, - visible=visible, + return GuiRgbaHandle( + self._create_gui_input( + value, + message=_messages.GuiRgbaMessage( + value=value, + id=id, + container_id=self._get_container_id(), + props=_messages.GuiRgbaProps( + order=order, + label=label, + hint=hint, + disabled=disabled, + visible=visible, + ), ), - ), + ) ) class GuiMessage(Protocol[GuiInputPropsType]): @@ -1524,7 +1550,7 @@ def _create_gui_input( value: T, message: GuiMessage, is_button: bool = False, - ) -> GuiInputHandle[T]: + ) -> _GuiHandleState[T]: """Private helper for adding a simple GUI element.""" # Send add GUI input message. @@ -1557,6 +1583,4 @@ def sync_other_clients( handle_state.sync_cb = sync_other_clients - handle = GuiInputHandle(handle_state) - - return handle + return handle_state diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 87cf7bf90..48f66ae2d 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -15,6 +15,7 @@ Dict, Generic, Iterable, + Tuple, TypeVar, cast, get_type_hints, @@ -29,14 +30,23 @@ from ._icons_enum import IconName from ._messages import ( GuiBaseProps, + GuiCheckboxProps, GuiCloseModalMessage, GuiDropdownProps, GuiFolderProps, GuiMarkdownProps, + GuiMultiSliderProps, + GuiNumberProps, GuiPlotlyProps, GuiProgressBarProps, GuiRemoveMessage, + GuiRgbaProps, + GuiRgbProps, + GuiSliderProps, + GuiTextProps, GuiUpdateMessage, + GuiVector2Props, + GuiVector3Props, ) from ._scene_api import _encode_image_binary from .infra import ClientId @@ -225,6 +235,36 @@ def on_update( return func +class GuiCheckboxHandle(GuiInputHandle[bool], GuiCheckboxProps): ... + + +class GuiTextHandle(GuiInputHandle[str], GuiTextProps): ... + + +IntOrFloat = TypeVar("IntOrFloat", int, float) + + +class GuiNumberHandle(GuiInputHandle[T], Generic[T], GuiNumberProps): ... + + +class GuiSliderHandle(GuiInputHandle[T], Generic[T], GuiSliderProps): ... + + +class GuiMultiSliderHandle(GuiInputHandle[T], Generic[T], GuiMultiSliderProps): ... + + +class GuiRgbHandle(GuiInputHandle[Tuple[int, int, int]], GuiRgbProps): ... + + +class GuiRgbaHandle(GuiInputHandle[Tuple[int, int, int, int]], GuiRgbaProps): ... + + +class GuiVector2Handle(GuiInputHandle[Tuple[float, float]], GuiVector2Props): ... + + +class GuiVector3Handle(GuiInputHandle[Tuple[float, float, float]], GuiVector3Props): ... + + @dataclasses.dataclass(frozen=True) class GuiEvent(Generic[TGuiHandle]): """Information associated with a GUI event, such as an update or click. From 1cf296122d418399f953554e75fd8fbcdd78a684 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 00:58:46 -0700 Subject: [PATCH 05/15] Minor fixes --- examples/02_gui.py | 1 - examples/25_smpl_visualizer_skinned.py | 4 +++- src/viser/_messages.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/02_gui.py b/examples/02_gui.py index e48e63fd7..e13228359 100644 --- a/examples/02_gui.py +++ b/examples/02_gui.py @@ -5,7 +5,6 @@ import time import numpy as onp - import viser diff --git a/examples/25_smpl_visualizer_skinned.py b/examples/25_smpl_visualizer_skinned.py index 00f10923d..db6c412e2 100644 --- a/examples/25_smpl_visualizer_skinned.py +++ b/examples/25_smpl_visualizer_skinned.py @@ -21,7 +21,6 @@ import numpy as np import numpy as onp import tyro - import viser import viser.transforms as tf @@ -122,6 +121,9 @@ def main(model_path: Path) -> None: gui_elements.changed = False + # Render as wireframe? + skinned_handle.wireframe = gui_elements.gui_wireframe.value + # Compute SMPL outputs. smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), diff --git a/src/viser/_messages.py b/src/viser/_messages.py index c8d0eaae1..a412a9901 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -778,7 +778,7 @@ class ResetGuiMessage(Message): @dataclasses.dataclass -class GuiBaseProps(Message): +class GuiBaseProps: """Base message type containing fields commonly used by GUI inputs.""" order: float From 33759ce705a81c3292baddbe83832d0ae4116416 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 01:01:52 -0700 Subject: [PATCH 06/15] ruff --- examples/02_gui.py | 1 + examples/25_smpl_visualizer_skinned.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/02_gui.py b/examples/02_gui.py index e13228359..e48e63fd7 100644 --- a/examples/02_gui.py +++ b/examples/02_gui.py @@ -5,6 +5,7 @@ import time import numpy as onp + import viser diff --git a/examples/25_smpl_visualizer_skinned.py b/examples/25_smpl_visualizer_skinned.py index db6c412e2..d7b5b4e43 100644 --- a/examples/25_smpl_visualizer_skinned.py +++ b/examples/25_smpl_visualizer_skinned.py @@ -21,6 +21,7 @@ import numpy as np import numpy as onp import tyro + import viser import viser.transforms as tf From b271c086f929eba94010a38805910b6f821b4091 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 01:21:47 -0700 Subject: [PATCH 07/15] Docstrings --- src/viser/_gui_api.py | 35 +- src/viser/_gui_handles.py | 3 +- src/viser/_messages.py | 269 ++++++++----- src/viser/client/src/WebsocketMessages.ts | 372 +----------------- src/viser/client/src/components/Button.tsx | 3 +- .../client/src/components/UploadButton.tsx | 2 +- 6 files changed, 177 insertions(+), 507 deletions(-) diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 712f6ba4f..c2c00f845 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -737,7 +737,7 @@ def add_button( label=label, hint=hint, color=color, - icon_html=None if icon is None else svg_from_icon(icon), + _icon_html=None if icon is None else svg_from_icon(icon), disabled=disabled, visible=visible, ), @@ -791,49 +791,22 @@ def add_upload_button( hint=hint, color=color, mime_type=mime_type, - icon_html=None if icon is None else svg_from_icon(icon), + _icon_html=None if icon is None else svg_from_icon(icon), ), ), is_button=True, ) ) - # The TLiteralString overload tells pyright to resolve the value type to a Literal - # whenever possible. - # - # TString is helpful when the input types are generic (could be str, could be - # Literal). - @overload - def add_button_group( - self, - label: str, - options: Sequence[TLiteralString], - visible: bool = True, - disabled: bool = False, - hint: str | None = None, - order: float | None = None, - ) -> GuiButtonGroupHandle[TLiteralString]: ... - - @overload def add_button_group( self, label: str, - options: Sequence[TString], - visible: bool = True, - disabled: bool = False, - hint: str | None = None, - order: float | None = None, - ) -> GuiButtonGroupHandle[TString]: ... - - def add_button_group( - self, - label: str, - options: Sequence[TLiteralString] | Sequence[TString], + options: Sequence[str], visible: bool = True, disabled: bool = False, hint: str | None = None, order: float | None = None, - ) -> GuiButtonGroupHandle[Any]: # Return types are specified in overloads. + ) -> GuiButtonGroupHandle: """Add a button group to the GUI. Args: diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 48f66ae2d..0bc96e82b 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -30,6 +30,7 @@ from ._icons_enum import IconName from ._messages import ( GuiBaseProps, + GuiButtonGroupProps, GuiCheckboxProps, GuiCloseModalMessage, GuiDropdownProps, @@ -316,7 +317,7 @@ def on_upload( return func -class GuiButtonGroupHandle(_GuiInputHandle[StringType], Generic[StringType]): +class GuiButtonGroupHandle(_GuiInputHandle[str], GuiButtonGroupProps): """Handle for a button group input in our visualizer. Lets us detect clicks.""" diff --git a/src/viser/_messages.py b/src/viser/_messages.py index a412a9901..45c8f2450 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -102,17 +102,17 @@ class NotificationMessage(Message): @dataclasses.dataclass class NotificationProps: title: str - """Title of the notification. For handles, synchronized automatically when assigned.""" + """Title of the notification. Synchronized automatically when assigned.""" body: str - """Body text of the notification. For handles, synchronized automatically when assigned.""" + """Body text of the notification. Synchronized automatically when assigned.""" loading: bool - """Whether to show a loading indicator. For handles, synchronized automatically when assigned.""" + """Whether to show a loading indicator. Synchronized automatically when assigned.""" with_close_button: bool - """Whether to show a close button. For handles, synchronized automatically when assigned.""" + """Whether to show a close button. Synchronized automatically when assigned.""" auto_close: Union[int, Literal[False]] - """Time in milliseconds after which the notification should auto-close, or False to disable auto-close. For handles, synchronized automatically when assigned.""" + """Time in milliseconds after which the notification should auto-close, or False to disable auto-close. Synchronized automatically when assigned.""" color: Optional[Color] - """Color of the notification. For handles, synchronized automatically when assigned.""" + """Color of the notification. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -180,17 +180,17 @@ class CameraFrustumMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class CameraFrustumProps: fov: float - """Field of view of the camera (in radians). For handles, synchronized automatically when assigned.""" + """Field of view of the camera (in radians). Synchronized automatically when assigned.""" aspect: float - """Aspect ratio of the camera (width over height). For handles, synchronized automatically when assigned.""" + """Aspect ratio of the camera (width over height). Synchronized automatically when assigned.""" scale: float - """Scale factor for the size of the frustum. For handles, synchronized automatically when assigned.""" + """Scale factor for the size of the frustum. Synchronized automatically when assigned.""" color: Tuple[int, int, int] - """Color of the frustum as RGB integers. For handles, synchronized automatically when assigned.""" + """Color of the frustum as RGB integers. Synchronized automatically when assigned.""" image_media_type: Optional[Literal["image/jpeg", "image/png"]] - """Format of the provided image ('image/jpeg' or 'image/png'). For handles, synchronized automatically when assigned.""" + """Format of the provided image ('image/jpeg' or 'image/png'). Synchronized automatically when assigned.""" image_binary: Optional[bytes] - """Optional image to be displayed on the frustum. For handles, synchronized automatically when assigned.""" + """Optional image to be displayed on the frustum. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -204,9 +204,9 @@ class GlbMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class GlbProps: glb_data: bytes - """A binary payload containing the GLB data. For handles, synchronized automatically when assigned.""" + """A binary payload containing the GLB data. Synchronized automatically when assigned.""" scale: float - """A scale for resizing the GLB asset. For handles, synchronized automatically when assigned.""" + """A scale for resizing the GLB asset. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -220,13 +220,13 @@ class FrameMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class FrameProps: show_axes: bool - """Boolean to indicate whether to show the frame as a set of axes + origin sphere. For handles, synchronized automatically when assigned.""" + """Boolean to indicate whether to show the frame as a set of axes + origin sphere. Synchronized automatically when assigned.""" axes_length: float - """Length of each axis. For handles, synchronized automatically when assigned.""" + """Length of each axis. Synchronized automatically when assigned.""" axes_radius: float - """Radius of each axis. For handles, synchronized automatically when assigned.""" + """Radius of each axis. Synchronized automatically when assigned.""" origin_radius: float - """Radius of the origin sphere. For handles, synchronized automatically when assigned.""" + """Radius of the origin sphere. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -243,13 +243,13 @@ class BatchedAxesMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class BatchedAxesProps: wxyzs_batched: onpt.NDArray[onp.float32] - """Float array of shape (N,4) representing quaternion rotations. For handles, synchronized automatically when assigned.""" + """Float array of shape (N,4) representing quaternion rotations. Synchronized automatically when assigned.""" positions_batched: onpt.NDArray[onp.float32] - """Float array of shape (N,3) representing positions. For handles, synchronized automatically when assigned.""" + """Float array of shape (N,3) representing positions. Synchronized automatically when assigned.""" axes_length: float - """Length of each axis. For handles, synchronized automatically when assigned.""" + """Length of each axis. Synchronized automatically when assigned.""" axes_radius: float - """Radius of each axis. For handles, synchronized automatically when assigned.""" + """Radius of each axis. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -263,27 +263,27 @@ class GridMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class GridProps: width: float - """Width of the grid. For handles, synchronized automatically when assigned.""" + """Width of the grid. Synchronized automatically when assigned.""" height: float - """Height of the grid. For handles, synchronized automatically when assigned.""" + """Height of the grid. Synchronized automatically when assigned.""" width_segments: int - """Number of segments along the width. For handles, synchronized automatically when assigned.""" + """Number of segments along the width. Synchronized automatically when assigned.""" height_segments: int - """Number of segments along the height. For handles, synchronized automatically when assigned.""" + """Number of segments along the height. Synchronized automatically when assigned.""" plane: Literal["xz", "xy", "yx", "yz", "zx", "zy"] - """The plane in which the grid is oriented. For handles, synchronized automatically when assigned.""" + """The plane in which the grid is oriented. Synchronized automatically when assigned.""" cell_color: Tuple[int, int, int] - """Color of the grid cells as RGB integers. For handles, synchronized automatically when assigned.""" + """Color of the grid cells as RGB integers. Synchronized automatically when assigned.""" cell_thickness: float - """Thickness of the grid lines. For handles, synchronized automatically when assigned.""" + """Thickness of the grid lines. Synchronized automatically when assigned.""" cell_size: float - """Size of each cell in the grid. For handles, synchronized automatically when assigned.""" + """Size of each cell in the grid. Synchronized automatically when assigned.""" section_color: Tuple[int, int, int] - """Color of the grid sections as RGB integers. For handles, synchronized automatically when assigned.""" + """Color of the grid sections as RGB integers. Synchronized automatically when assigned.""" section_thickness: float - """Thickness of the section lines. For handles, synchronized automatically when assigned.""" + """Thickness of the section lines. Synchronized automatically when assigned.""" section_size: float - """Size of each section in the grid. For handles, synchronized automatically when assigned.""" + """Size of each section in the grid. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -297,7 +297,7 @@ class LabelMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class LabelProps: text: str - """Text content of the label. For handles, synchronized automatically when assigned.""" + """Text content of the label. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -311,9 +311,9 @@ class Gui3DMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class Gui3DProps: order: float - """Order value for arranging GUI elements. For handles, synchronized automatically when assigned.""" + """Order value for arranging GUI elements. Synchronized automatically when assigned.""" container_id: str - """Identifier for the container. For handles, synchronized automatically when assigned.""" + """Identifier for the container. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -332,13 +332,13 @@ class PointCloudMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class PointCloudProps: points: onpt.NDArray[onp.float32] - """Location of points. Should have shape (N, 3). For handles, synchronized automatically when assigned.""" + """Location of points. Should have shape (N, 3). Synchronized automatically when assigned.""" colors: onpt.NDArray[onp.uint8] - """Colors of points. Should have shape (N, 3) or (3,). For handles, synchronized automatically when assigned.""" + """Colors of points. Should have shape (N, 3) or (3,). Synchronized automatically when assigned.""" point_size: float - """Size of each point. For handles, synchronized automatically when assigned.""" + """Size of each point. Synchronized automatically when assigned.""" point_ball_norm: float - """Norm value determining the shape of each point. For handles, synchronized automatically when assigned.""" + """Norm value determining the shape of each point. Synchronized automatically when assigned.""" def __post_init__(self): # Check shapes. @@ -362,9 +362,9 @@ class DirectionalLightMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class DirectionalLightProps: color: Tuple[int, int, int] - """Color of the directional light. For handles, synchronized automatically when assigned.""" + """Color of the directional light. Synchronized automatically when assigned.""" intensity: float - """Intensity of the directional light. For handles, synchronized automatically when assigned.""" + """Intensity of the directional light. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -378,9 +378,9 @@ class AmbientLightMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class AmbientLightProps: color: Tuple[int, int, int] - """Color of the ambient light. For handles, synchronized automatically when assigned.""" + """Color of the ambient light. Synchronized automatically when assigned.""" intensity: float - """Intensity of the ambient light. For handles, synchronized automatically when assigned.""" + """Intensity of the ambient light. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -394,11 +394,11 @@ class HemisphereLightMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class HemisphereLightProps: sky_color: Tuple[int, int, int] - """Sky color of the hemisphere light. For handles, synchronized automatically when assigned.""" + """Sky color of the hemisphere light. Synchronized automatically when assigned.""" ground_color: Tuple[int, int, int] - """Ground color of the hemisphere light. For handles, synchronized automatically when assigned.""" + """Ground color of the hemisphere light. Synchronized automatically when assigned.""" intensity: float - """Intensity of the hemisphere light. For handles, synchronized automatically when assigned.""" + """Intensity of the hemisphere light. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -412,13 +412,13 @@ class PointLightMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class PointLightProps: color: Tuple[int, int, int] - """Color of the point light. For handles, synchronized automatically when assigned.""" + """Color of the point light. Synchronized automatically when assigned.""" intensity: float - """Intensity of the point light. For handles, synchronized automatically when assigned.""" + """Intensity of the point light. Synchronized automatically when assigned.""" distance: float - """Distance of the point light. For handles, synchronized automatically when assigned.""" + """Distance of the point light. Synchronized automatically when assigned.""" decay: float - """Decay of the point light. For handles, synchronized automatically when assigned.""" + """Decay of the point light. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -432,13 +432,13 @@ class RectAreaLightMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class RectAreaLightProps: color: Tuple[int, int, int] - """Color of the rectangular area light. For handles, synchronized automatically when assigned.""" + """Color of the rectangular area light. Synchronized automatically when assigned.""" intensity: float - """Intensity of the rectangular area light. For handles, synchronized automatically when assigned.""" + """Intensity of the rectangular area light. Synchronized automatically when assigned.""" width: float - """Width of the rectangular area light. For handles, synchronized automatically when assigned.""" + """Width of the rectangular area light. Synchronized automatically when assigned.""" height: float - """Height of the rectangular area light. For handles, synchronized automatically when assigned.""" + """Height of the rectangular area light. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -452,17 +452,17 @@ class SpotLightMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class SpotLightProps: color: Tuple[int, int, int] - """Color of the spot light. For handles, synchronized automatically when assigned.""" + """Color of the spot light. Synchronized automatically when assigned.""" intensity: float - """Intensity of the spot light. For handles, synchronized automatically when assigned.""" + """Intensity of the spot light. Synchronized automatically when assigned.""" distance: float - """Distance of the spot light. For handles, synchronized automatically when assigned.""" + """Distance of the spot light. Synchronized automatically when assigned.""" angle: float - """Angle of the spot light. For handles, synchronized automatically when assigned.""" + """Angle of the spot light. Synchronized automatically when assigned.""" penumbra: float - """Penumbra of the spot light. For handles, synchronized automatically when assigned.""" + """Penumbra of the spot light. Synchronized automatically when assigned.""" decay: float - """Decay of the spot light. For handles, synchronized automatically when assigned.""" + """Decay of the spot light. Synchronized automatically when assigned.""" def __post_init__(self): assert self.angle <= onp.pi / 2 @@ -515,23 +515,23 @@ class MeshMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class MeshProps: vertices: onpt.NDArray[onp.float32] - """A numpy array of vertex positions. Should have shape (V, 3). For handles, synchronized automatically when assigned.""" + """A numpy array of vertex positions. Should have shape (V, 3). Synchronized automatically when assigned.""" faces: onpt.NDArray[onp.uint32] - """A numpy array of faces, where each face is represented by indices of vertices. Should have shape (F, 3). For handles, synchronized automatically when assigned.""" + """A numpy array of faces, where each face is represented by indices of vertices. Should have shape (F, 3). Synchronized automatically when assigned.""" color: Optional[Tuple[int, int, int]] - """Color of the mesh as RGB integers. For handles, synchronized automatically when assigned.""" + """Color of the mesh as RGB integers. Synchronized automatically when assigned.""" vertex_colors: Optional[onpt.NDArray[onp.uint8]] - """Optional array of vertex colors. For handles, synchronized automatically when assigned.""" + """Optional array of vertex colors. Synchronized automatically when assigned.""" wireframe: bool - """Boolean indicating if the mesh should be rendered as a wireframe. For handles, synchronized automatically when assigned.""" + """Boolean indicating if the mesh should be rendered as a wireframe. Synchronized automatically when assigned.""" opacity: Optional[float] - """Opacity of the mesh. None means opaque. For handles, synchronized automatically when assigned.""" + """Opacity of the mesh. None means opaque. Synchronized automatically when assigned.""" flat_shading: bool - """Whether to do flat shading. For handles, synchronized automatically when assigned.""" + """Whether to do flat shading. Synchronized automatically when assigned.""" side: Literal["front", "back", "double"] - """Side of the surface to render. For handles, synchronized automatically when assigned.""" + """Side of the surface to render. Synchronized automatically when assigned.""" material: Literal["standard", "toon3", "toon5"] - """Material type of the mesh. For handles, synchronized automatically when assigned.""" + """Material type of the mesh. Synchronized automatically when assigned.""" def __post_init__(self): # Check shapes. @@ -554,13 +554,13 @@ class SkinnedMeshProps(MeshProps): Vertices are internally canonicalized to float32, faces to uint32.""" bone_wxyzs: Tuple[Tuple[float, float, float, float], ...] - """Tuple of quaternions representing bone orientations. For handles, synchronized automatically when assigned.""" + """Tuple of quaternions representing bone orientations. Synchronized automatically when assigned.""" bone_positions: Tuple[Tuple[float, float, float], ...] - """Tuple of positions representing bone positions. For handles, synchronized automatically when assigned.""" + """Tuple of positions representing bone positions. Synchronized automatically when assigned.""" skin_indices: onpt.NDArray[onp.uint16] - """Array of skin indices. Should have shape (V, 4). For handles, synchronized automatically when assigned.""" + """Array of skin indices. Should have shape (V, 4). Synchronized automatically when assigned.""" skin_weights: onpt.NDArray[onp.float32] - """Array of skin weights. Should have shape (V, 4). For handles, synchronized automatically when assigned.""" + """Array of skin weights. Should have shape (V, 4). Synchronized automatically when assigned.""" def __post_init__(self): # Check shapes. @@ -615,33 +615,33 @@ class TransformControlsMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class TransformControlsProps: scale: float - """Scale of the transform controls. For handles, synchronized automatically when assigned.""" + """Scale of the transform controls. Synchronized automatically when assigned.""" line_width: float - """Width of the lines used in the gizmo. For handles, synchronized automatically when assigned.""" + """Width of the lines used in the gizmo. Synchronized automatically when assigned.""" fixed: bool - """Boolean indicating if the gizmo should be fixed in position. For handles, synchronized automatically when assigned.""" + """Boolean indicating if the gizmo should be fixed in position. Synchronized automatically when assigned.""" auto_transform: bool - """Whether the transform should be applied automatically. For handles, synchronized automatically when assigned.""" + """Whether the transform should be applied automatically. Synchronized automatically when assigned.""" active_axes: Tuple[bool, bool, bool] - """Tuple of booleans indicating active axes. For handles, synchronized automatically when assigned.""" + """Tuple of booleans indicating active axes. Synchronized automatically when assigned.""" disable_axes: bool - """Boolean to disable axes interaction. For handles, synchronized automatically when assigned.""" + """Boolean to disable axes interaction. Synchronized automatically when assigned.""" disable_sliders: bool - """Boolean to disable slider interaction. For handles, synchronized automatically when assigned.""" + """Boolean to disable slider interaction. Synchronized automatically when assigned.""" disable_rotations: bool - """Boolean to disable rotation interaction. For handles, synchronized automatically when assigned.""" + """Boolean to disable rotation interaction. Synchronized automatically when assigned.""" translation_limits: Tuple[ Tuple[float, float], Tuple[float, float], Tuple[float, float] ] - """Limits for translation. For handles, synchronized automatically when assigned.""" + """Limits for translation. Synchronized automatically when assigned.""" rotation_limits: Tuple[ Tuple[float, float], Tuple[float, float], Tuple[float, float] ] - """Limits for rotation. For handles, synchronized automatically when assigned.""" + """Limits for rotation. Synchronized automatically when assigned.""" depth_test: bool - """Boolean indicating if depth testing should be used when rendering. For handles, synchronized automatically when assigned.""" + """Boolean indicating if depth testing should be used when rendering. Synchronized automatically when assigned.""" opacity: float - """Opacity of the gizmo. For handles, synchronized automatically when assigned.""" + """Opacity of the gizmo. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -723,13 +723,13 @@ class ImageMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class ImageProps: media_type: Literal["image/jpeg", "image/png"] - """Format of the provided image ('image/jpeg' or 'image/png'). For handles, synchronized automatically when assigned.""" + """Format of the provided image ('image/jpeg' or 'image/png'). Synchronized automatically when assigned.""" data: bytes - """Binary data of the image. For handles, synchronized automatically when assigned.""" + """Binary data of the image. Synchronized automatically when assigned.""" render_width: float - """Width at which the image should be rendered in the scene. For handles, synchronized automatically when assigned.""" + """Width at which the image should be rendered in the scene. Synchronized automatically when assigned.""" render_height: float - """Height at which the image should be rendered in the scene. For handles, synchronized automatically when assigned.""" + """Height at which the image should be rendered in the scene. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -782,18 +782,27 @@ class GuiBaseProps: """Base message type containing fields commonly used by GUI inputs.""" order: float + """Order value for arranging GUI elements. Synchronized automatically when assigned.""" label: str + """Label text for the GUI element. Synchronized automatically when assigned.""" hint: Optional[str] + """Optional hint text for the GUI element. Synchronized automatically when assigned.""" visible: bool + """Visibility state of the GUI element. Synchronized automatically when assigned.""" disabled: bool + """Disabled state of the GUI element. Synchronized automatically when assigned.""" @dataclasses.dataclass class GuiFolderProps: order: float + """Order value for arranging GUI elements. Synchronized automatically when assigned.""" label: str + """Label text for the GUI folder. Synchronized automatically when assigned.""" visible: bool + """Visibility state of the GUI folder. Synchronized automatically when assigned.""" expand_by_default: bool + """Whether the folder should be expanded by default. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -806,8 +815,11 @@ class GuiFolderMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiMarkdownProps: order: float + """Order value for arranging GUI elements. Synchronized automatically when assigned.""" _markdown: str + """(Private) Markdown content to be displayed. Synchronized automatically when assigned.""" visible: bool + """Visibility state of the markdown element. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -820,9 +832,17 @@ class GuiMarkdownMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiProgressBarProps: order: float + """Order value for arranging GUI elements. Synchronized automatically when assigned.""" animated: bool + """Whether the progress bar should be animated. Synchronized automatically when assigned.""" color: Optional[Color] + """Color of the progress bar. Synchronized automatically when assigned.""" visible: bool + """Visibility state of the progress bar. Synchronized automatically when assigned.""" + label: str + """Label text for the progress bar. Synchronized automatically when assigned.""" + value: float + """Current value of the progress bar. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -836,9 +856,13 @@ class GuiProgressBarMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiPlotlyProps: order: float + """Order value for arranging GUI elements. Synchronized automatically when assigned.""" _plotly_json_str: str + """(Private) JSON string representation of the Plotly figure. Synchronized automatically when assigned.""" aspect: float + """Aspect ratio of the plot. Synchronized automatically when assigned.""" visible: bool + """Visibility state of the plot. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -850,6 +874,7 @@ class GuiPlotlyMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiTabGroupProps: + # Note: for tab groups we currently don't expose properties automatically. tab_labels: Tuple[str, ...] tab_icons_html: Tuple[Union[str, None], ...] tab_container_ids: Tuple[str, ...] @@ -879,7 +904,9 @@ class GuiCloseModalMessage(Message): @dataclasses.dataclass class GuiButtonProps(GuiBaseProps): color: Optional[Color] - icon_html: Optional[str] + """Color of the button. Synchronized automatically when assigned.""" + _icon_html: Optional[str] + """(Private) HTML string for the icon to be displayed on the button. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -893,8 +920,11 @@ class GuiButtonMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiUploadButtonProps(GuiBaseProps): color: Optional[Color] - icon_html: Optional[str] + """Color of the upload button. Synchronized automatically when assigned.""" + _icon_html: Optional[str] + """(Private) HTML string for the icon to be displayed on the upload button. Synchronized automatically when assigned.""" mime_type: str + """MIME type of the files that can be uploaded. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -907,10 +937,15 @@ class GuiUploadButtonMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiSliderProps(GuiBaseProps): min: float + """Minimum value for the slider. Synchronized automatically when assigned.""" max: float - step: Optional[float] + """Maximum value for the slider. Synchronized automatically when assigned.""" + step: float + """Step size for the slider. Synchronized automatically when assigned.""" precision: int + """Number of decimal places to display for the slider value. Synchronized automatically when assigned.""" _marks: Optional[Tuple[GuiSliderMark, ...]] = None + """(Private) Optional tuple of GuiSliderMark objects to display custom marks on the slider. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -924,12 +959,19 @@ class GuiSliderMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiMultiSliderProps(GuiBaseProps): min: float + """Minimum value for the multi-slider. Synchronized automatically when assigned.""" max: float - step: Optional[float] + """Maximum value for the multi-slider. Synchronized automatically when assigned.""" + step: float + """Step size for the multi-slider. Synchronized automatically when assigned.""" min_range: Optional[float] + """Minimum allowed range between slider handles. Synchronized automatically when assigned.""" precision: int + """Number of decimal places to display for the multi-slider values. Synchronized automatically when assigned.""" fixed_endpoints: bool = False + """If True, the first and last handles cannot be moved. Synchronized automatically when assigned.""" _marks: Optional[Tuple[GuiSliderMark, ...]] = None + """(Private) Optional tuple of GuiSliderMark objects to display custom marks on the multi-slider. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -943,9 +985,13 @@ class GuiMultiSliderMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiNumberProps(GuiBaseProps): precision: int + """Number of decimal places to display for the number value. Synchronized automatically when assigned.""" step: float + """Step size for incrementing/decrementing the number value. Synchronized automatically when assigned.""" min: Optional[float] + """Minimum allowed value for the number input. Synchronized automatically when assigned.""" max: Optional[float] + """Maximum allowed value for the number input. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -998,9 +1044,13 @@ class GuiCheckboxMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiVector2Props(GuiBaseProps): min: Optional[Tuple[float, float]] + """Minimum allowed values for each component of the vector. Synchronized automatically when assigned.""" max: Optional[Tuple[float, float]] + """Maximum allowed values for each component of the vector. Synchronized automatically when assigned.""" step: float + """Step size for incrementing/decrementing each component of the vector. Synchronized automatically when assigned.""" precision: int + """Number of decimal places to display for each component of the vector. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -1014,9 +1064,13 @@ class GuiVector2Message(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiVector3Props(GuiBaseProps): min: Optional[Tuple[float, float, float]] + """Minimum allowed values for each component of the vector. Synchronized automatically when assigned.""" max: Optional[Tuple[float, float, float]] + """Maximum allowed values for each component of the vector. Synchronized automatically when assigned.""" step: float + """Step size for incrementing/decrementing each component of the vector. Synchronized automatically when assigned.""" precision: int + """Number of decimal places to display for each component of the vector. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -1042,7 +1096,9 @@ class GuiTextMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiDropdownProps(GuiBaseProps): + # This will actually be manually overridden for better types. options: Tuple[str, ...] + """Tuple of options for the dropdown. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -1056,6 +1112,7 @@ class GuiDropdownMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiButtonGroupProps(GuiBaseProps): options: Tuple[str, ...] + """Tuple of buttons for the button group. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -1141,19 +1198,19 @@ class CatmullRomSplineMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class CatmullRomSplineProps: positions: Tuple[Tuple[float, float, float], ...] - """A tuple of 3D positions (x, y, z) defining the spline's path. For handles, synchronized automatically when assigned.""" + """A tuple of 3D positions (x, y, z) defining the spline's path. Synchronized automatically when assigned.""" curve_type: Literal["centripetal", "chordal", "catmullrom"] - """Type of the curve ('centripetal', 'chordal', 'catmullrom'). For handles, synchronized automatically when assigned.""" + """Type of the curve ('centripetal', 'chordal', 'catmullrom'). Synchronized automatically when assigned.""" tension: float - """Tension of the curve. Affects the tightness of the curve. For handles, synchronized automatically when assigned.""" + """Tension of the curve. Affects the tightness of the curve. Synchronized automatically when assigned.""" closed: bool - """Boolean indicating if the spline is closed (forms a loop). For handles, synchronized automatically when assigned.""" + """Boolean indicating if the spline is closed (forms a loop). Synchronized automatically when assigned.""" line_width: float - """Width of the spline line. For handles, synchronized automatically when assigned.""" + """Width of the spline line. Synchronized automatically when assigned.""" color: Tuple[int, int, int] - """Color of the spline as RGB integers. For handles, synchronized automatically when assigned.""" + """Color of the spline as RGB integers. Synchronized automatically when assigned.""" segments: Optional[int] - """Number of segments to divide the spline into. For handles, synchronized automatically when assigned.""" + """Number of segments to divide the spline into. Synchronized automatically when assigned.""" @dataclasses.dataclass @@ -1167,15 +1224,15 @@ class CubicBezierSplineMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class CubicBezierSplineProps: positions: Tuple[Tuple[float, float, float], ...] - """A tuple of 3D positions (x, y, z) defining the spline's key points. For handles, synchronized automatically when assigned.""" + """A tuple of 3D positions (x, y, z) defining the spline's key points. Synchronized automatically when assigned.""" control_points: Tuple[Tuple[float, float, float], ...] - """A tuple of control points for Bezier curve shaping. For handles, synchronized automatically when assigned.""" + """A tuple of control points for Bezier curve shaping. Synchronized automatically when assigned.""" line_width: float - """Width of the spline line. For handles, synchronized automatically when assigned.""" + """Width of the spline line. Synchronized automatically when assigned.""" color: Tuple[int, int, int] - """Color of the spline as RGB integers. For handles, synchronized automatically when assigned.""" + """Color of the spline as RGB integers. Synchronized automatically when assigned.""" segments: Optional[int] - """Number of segments to divide the spline into. For handles, synchronized automatically when assigned.""" + """Number of segments to divide the spline into. Synchronized automatically when assigned.""" @dataclasses.dataclass diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index 2a164b996..27b15ac35 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -543,346 +543,6 @@ export interface ResetSceneMessage { export interface ResetGuiMessage { type: "ResetGuiMessage"; } -/** Base message type containing fields commonly used by GUI inputs. - * - * (automatically generated) - */ -export interface GuiBaseProps { - type: "GuiBaseProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; -} -/** GuiButtonProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', color: 'Optional[Color]', icon_html: 'Optional[str]') - * - * (automatically generated) - */ -export interface GuiButtonProps { - type: "GuiButtonProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - color: - | "dark" - | "gray" - | "red" - | "pink" - | "grape" - | "violet" - | "indigo" - | "blue" - | "cyan" - | "green" - | "lime" - | "yellow" - | "orange" - | "teal" - | null; - icon_html: string | null; -} -/** GuiUploadButtonProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', color: 'Optional[Color]', icon_html: 'Optional[str]', mime_type: 'str') - * - * (automatically generated) - */ -export interface GuiUploadButtonProps { - type: "GuiUploadButtonProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - color: - | "dark" - | "gray" - | "red" - | "pink" - | "grape" - | "violet" - | "indigo" - | "blue" - | "cyan" - | "green" - | "lime" - | "yellow" - | "orange" - | "teal" - | null; - icon_html: string | null; - mime_type: string; -} -/** GuiSliderProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', precision: 'int', _marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) - * - * (automatically generated) - */ -export interface GuiSliderProps { - type: "GuiSliderProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - min: number; - max: number; - step: number | null; - precision: number; - _marks: { value: number; label: string | null }[] | null; -} -/** GuiMultiSliderProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', min_range: 'Optional[float]', precision: 'int', fixed_endpoints: 'bool' = False, _marks: 'Optional[Tuple[GuiSliderMark, ...]]' = None) - * - * (automatically generated) - */ -export interface GuiMultiSliderProps { - type: "GuiMultiSliderProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - min: number; - max: number; - step: number | null; - min_range: number | null; - precision: number; - fixed_endpoints: boolean; - _marks: { value: number; label: string | null }[] | null; -} -/** GuiNumberProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', precision: 'int', step: 'float', min: 'Optional[float]', max: 'Optional[float]') - * - * (automatically generated) - */ -export interface GuiNumberProps { - type: "GuiNumberProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - precision: number; - step: number; - min: number | null; - max: number | null; -} -/** GuiRgbProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool') - * - * (automatically generated) - */ -export interface GuiRgbProps { - type: "GuiRgbProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; -} -/** GuiRgbaProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool') - * - * (automatically generated) - */ -export interface GuiRgbaProps { - type: "GuiRgbaProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; -} -/** GuiCheckboxProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool') - * - * (automatically generated) - */ -export interface GuiCheckboxProps { - type: "GuiCheckboxProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; -} -/** GuiVector2Props(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'Optional[Tuple[float, float]]', max: 'Optional[Tuple[float, float]]', step: 'float', precision: 'int') - * - * (automatically generated) - */ -export interface GuiVector2Props { - type: "GuiVector2Props"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - min: [number, number] | null; - max: [number, number] | null; - step: number; - precision: number; -} -/** GuiVector3Props(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', min: 'Optional[Tuple[float, float, float]]', max: 'Optional[Tuple[float, float, float]]', step: 'float', precision: 'int') - * - * (automatically generated) - */ -export interface GuiVector3Props { - type: "GuiVector3Props"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - min: [number, number, number] | null; - max: [number, number, number] | null; - step: number; - precision: number; -} -/** GuiTextProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool') - * - * (automatically generated) - */ -export interface GuiTextProps { - type: "GuiTextProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; -} -/** GuiDropdownProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', options: 'Tuple[str, ...]') - * - * (automatically generated) - */ -export interface GuiDropdownProps { - type: "GuiDropdownProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - options: string[]; -} -/** Handle for a dropdown-style GUI input in our visualizer. - * - * Lets us get values, set values, and detect updates. - * - * (automatically generated) - */ -export interface GuiDropdownHandle { - type: "GuiDropdownHandle"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - options: string[]; -} -/** GuiButtonGroupProps(order: 'float', label: 'str', hint: 'Optional[str]', visible: 'bool', disabled: 'bool', options: 'Tuple[str, ...]') - * - * (automatically generated) - */ -export interface GuiButtonGroupProps { - type: "GuiButtonGroupProps"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - options: string[]; -} -export interface _GuiInputHandle { - type: "_GuiInputHandle"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; -} -/** A handle is created for each GUI element that is added in `viser`. - * Handles can be used to read and write state. - * - * When a GUI element is added via :attr:`ViserServer.gui`, state is - * synchronized between all connected clients. When a GUI element is added via - * :attr:`ClientHandle.gui`, state is local to a specific client. - * - * - * (automatically generated) - */ -export interface GuiInputHandle { - type: "GuiInputHandle"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; -} -/** Handle for a dropdown-style GUI input in our visualizer. - * - * Lets us get values, set values, and detect updates. - * - * (automatically generated) - */ -export interface GuiDropdownHandle { - type: "GuiDropdownHandle"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; - options: string[]; -} -/** Handle for a button input in our visualizer. - * - * Lets us detect clicks. - * - * (automatically generated) - */ -export interface GuiButtonHandle { - type: "GuiButtonHandle"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; -} -/** Handle for an upload file button in our visualizer. - * - * The `.value` attribute will be updated with the contents of uploaded files. - * - * - * (automatically generated) - */ -export interface GuiUploadButtonHandle { - type: "GuiUploadButtonHandle"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; -} -/** Handle for a button group input in our visualizer. - * - * Lets us detect clicks. - * - * (automatically generated) - */ -export interface GuiButtonGroupHandle { - type: "GuiButtonGroupHandle"; - order: number; - label: string; - hint: string | null; - visible: boolean; - disabled: boolean; -} -/** Use to remove markdown. - * - * (automatically generated) - */ -export interface GuiProgressBarHandle { - type: "GuiProgressBarHandle"; - order: number; - visible: boolean; - label: string; - hint: string | null; - disabled: boolean; -} /** GuiFolderMessage(id: 'str', container_id: 'str', props: 'GuiFolderProps') * * (automatically generated) @@ -937,6 +597,8 @@ export interface GuiProgressBarMessage { | "teal" | null; visible: boolean; + label: string; + value: number; }; } /** GuiPlotlyMessage(id: 'str', container_id: 'str', props: 'GuiPlotlyProps') @@ -1019,7 +681,7 @@ export interface GuiButtonMessage { | "orange" | "teal" | null; - icon_html: string | null; + _icon_html: string | null; }; } /** GuiUploadButtonMessage(id: 'str', container_id: 'str', props: 'GuiUploadButtonProps') @@ -1052,7 +714,7 @@ export interface GuiUploadButtonMessage { | "orange" | "teal" | null; - icon_html: string | null; + _icon_html: string | null; mime_type: string; }; } @@ -1073,7 +735,7 @@ export interface GuiSliderMessage { disabled: boolean; min: number; max: number; - step: number | null; + step: number; precision: number; _marks: { value: number; label: string | null }[] | null; }; @@ -1095,7 +757,7 @@ export interface GuiMultiSliderMessage { disabled: boolean; min: number; max: number; - step: number | null; + step: number; min_range: number | null; precision: number; fixed_endpoints: boolean; @@ -1505,28 +1167,6 @@ export type Message = | SceneNodeClickMessage | ResetSceneMessage | ResetGuiMessage - | GuiBaseProps - | GuiButtonProps - | GuiUploadButtonProps - | GuiSliderProps - | GuiMultiSliderProps - | GuiNumberProps - | GuiRgbProps - | GuiRgbaProps - | GuiCheckboxProps - | GuiVector2Props - | GuiVector3Props - | GuiTextProps - | GuiDropdownProps - | GuiDropdownHandle - | GuiButtonGroupProps - | _GuiInputHandle - | GuiInputHandle - | GuiDropdownHandle - | GuiButtonHandle - | GuiUploadButtonHandle - | GuiButtonGroupHandle - | GuiProgressBarHandle | GuiFolderMessage | GuiMarkdownMessage | GuiProgressBarMessage diff --git a/src/viser/client/src/components/Button.tsx b/src/viser/client/src/components/Button.tsx index 3e37aa120..caa924564 100644 --- a/src/viser/client/src/components/Button.tsx +++ b/src/viser/client/src/components/Button.tsx @@ -8,10 +8,9 @@ import { htmlIconWrapper } from "./ComponentStyles.css"; export default function ButtonComponent({ id, - props: { visible, disabled, label, ...otherProps }, + props: { visible, disabled, label, color, _icon_html: icon_html }, }: GuiButtonMessage) { const { messageSender } = React.useContext(GuiComponentContext)!; - const { color, icon_html } = otherProps; if (!(visible ?? true)) return <>; return ( diff --git a/src/viser/client/src/components/UploadButton.tsx b/src/viser/client/src/components/UploadButton.tsx index 34feb3390..2b2e39718 100644 --- a/src/viser/client/src/components/UploadButton.tsx +++ b/src/viser/client/src/components/UploadButton.tsx @@ -11,7 +11,7 @@ import { htmlIconWrapper } from "./ComponentStyles.css"; export default function UploadButtonComponent({ id, - props: { disabled, mime_type, color, icon_html, label }, + props: { disabled, mime_type, color, _icon_html: icon_html, label }, }: GuiUploadButtonMessage) { // Handle GUI input types. const viewer = useContext(ViewerContext)!; From abaa97ddd516c056927c2bd4553c91c22d40d2b7 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 01:28:58 -0700 Subject: [PATCH 08/15] tsc --- src/viser/_messages.py | 2 +- src/viser/client/src/ControlPanel/GuiState.tsx | 7 +++++-- src/viser/client/src/WebsocketMessages.ts | 2 +- src/viser/client/src/components/ProgressBar.tsx | 5 +---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 45c8f2450..4accdcdd5 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -1137,7 +1137,7 @@ class GuiUpdateMessage(Message): id: str updates: Annotated[ Dict[str, Any], - infra.TypeScriptAnnotationOverride("Partial"), + infra.TypeScriptAnnotationOverride("{[key: string]: any}"), ] """Mapping from property name to new value.""" diff --git a/src/viser/client/src/ControlPanel/GuiState.tsx b/src/viser/client/src/ControlPanel/GuiState.tsx index e8a61a4ad..b7784ed1b 100644 --- a/src/viser/client/src/ControlPanel/GuiState.tsx +++ b/src/viser/client/src/ControlPanel/GuiState.tsx @@ -164,14 +164,17 @@ export function useGuiState(initialServer: string) { // Iterate over key/value pairs. for (const [key, value] of Object.entries(updates)) { + // We don't put `value` in the props object to make types + // stronger in the user-facing Python API. This results in some + // nastiness here, we should revisit... if (key === "value") { - state.guiConfigFromId[id].value = value; + (state.guiConfigFromId[id] as any).value = value; } else if (!(key in config.props)) { console.error( `Tried to update nonexistent property '${key}' of GUI element ${id}!`, ); } else { - state.guiConfigFromId[id].props[key] = value; + (state.guiConfigFromId[id].props as any)[key] = value; } } }); diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index 27b15ac35..31e6800db 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -946,7 +946,7 @@ export interface GuiRemoveMessage { export interface GuiUpdateMessage { type: "GuiUpdateMessage"; id: string; - updates: Partial; + updates: { [key: string]: any }; } /** Sent client<->server when any property of a scene node is changed. * diff --git a/src/viser/client/src/components/ProgressBar.tsx b/src/viser/client/src/components/ProgressBar.tsx index 944f4ecf6..f4d5655e9 100644 --- a/src/viser/client/src/components/ProgressBar.tsx +++ b/src/viser/client/src/components/ProgressBar.tsx @@ -2,10 +2,7 @@ import { Box, Progress } from "@mantine/core"; import { GuiProgressBarMessage } from "../WebsocketMessages"; export default function ProgressBarComponent({ - visible, - color, - value, - animated, + props: { visible, color, value, animated }, }: GuiProgressBarMessage) { if (!visible) return <>; return ( From 52fe6a965ffe1bb3e9721d4ce59ebef59fb01788 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 01:31:40 -0700 Subject: [PATCH 09/15] Support typescript gen from Dict[K, V] --- src/viser/_messages.py | 12 +++--------- src/viser/infra/_typescript_interface_gen.py | 4 ++++ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 4accdcdd5..73af9dd87 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -9,7 +9,7 @@ import numpy as onp import numpy.typing as onpt -from typing_extensions import Annotated, Literal, override +from typing_extensions import Literal, override from . import infra, theme @@ -1135,10 +1135,7 @@ class GuiUpdateMessage(Message): """Sent client<->server when any property of a GUI component is changed.""" id: str - updates: Annotated[ - Dict[str, Any], - infra.TypeScriptAnnotationOverride("{[key: string]: any}"), - ] + updates: Dict[str, Any] """Mapping from property name to new value.""" @override @@ -1157,10 +1154,7 @@ class SceneNodeUpdateMessage(Message): """Sent client<->server when any property of a scene node is changed.""" name: str - updates: Annotated[ - Dict[str, Any], - infra.TypeScriptAnnotationOverride("{[key: string]: any}"), - ] + updates: Dict[str, Any] """Mapping from property name to new value.""" @override diff --git a/src/viser/infra/_typescript_interface_gen.py b/src/viser/infra/_typescript_interface_gen.py index 1a86d72b5..b0ea5f634 100644 --- a/src/viser/infra/_typescript_interface_gen.py +++ b/src/viser/infra/_typescript_interface_gen.py @@ -58,6 +58,10 @@ def _get_ts_type(typ: Type[Any]) -> str: args = get_args(typ) assert len(args) == 1 return _get_ts_type(args[0]) + "[]" + elif origin_typ is dict: + args = get_args(typ) + assert len(args) == 2 + return "{[key: " + _get_ts_type(args[0]) + "]: " + _get_ts_type(args[1]) + "}" elif origin_typ in (Literal, LiteralAlt): return " | ".join( map( From fda295f69f3c00e399d19f0d979e035b05812f87 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 01:37:09 -0700 Subject: [PATCH 10/15] Fix progress bars --- src/viser/_messages.py | 4 ---- src/viser/client/src/WebsocketMessages.ts | 2 -- src/viser/client/src/components/ProgressBar.tsx | 3 ++- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 73af9dd87..2c9b1fbc1 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -839,10 +839,6 @@ class GuiProgressBarProps: """Color of the progress bar. Synchronized automatically when assigned.""" visible: bool """Visibility state of the progress bar. Synchronized automatically when assigned.""" - label: str - """Label text for the progress bar. Synchronized automatically when assigned.""" - value: float - """Current value of the progress bar. Synchronized automatically when assigned.""" @dataclasses.dataclass diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index 31e6800db..742ad688a 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -597,8 +597,6 @@ export interface GuiProgressBarMessage { | "teal" | null; visible: boolean; - label: string; - value: number; }; } /** GuiPlotlyMessage(id: 'str', container_id: 'str', props: 'GuiPlotlyProps') diff --git a/src/viser/client/src/components/ProgressBar.tsx b/src/viser/client/src/components/ProgressBar.tsx index f4d5655e9..ac315793b 100644 --- a/src/viser/client/src/components/ProgressBar.tsx +++ b/src/viser/client/src/components/ProgressBar.tsx @@ -2,7 +2,8 @@ import { Box, Progress } from "@mantine/core"; import { GuiProgressBarMessage } from "../WebsocketMessages"; export default function ProgressBarComponent({ - props: { visible, color, value, animated }, + value, + props: { visible, color, animated }, }: GuiProgressBarMessage) { if (!visible) return <>; return ( From 35b0a8af1266313146c35ed96d65e1d94e43b7bb Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 02:27:02 -0700 Subject: [PATCH 11/15] Docs --- docs/source/gui_handles.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/gui_handles.md b/docs/source/gui_handles.md index 16d33f2c0..c65d46acc 100644 --- a/docs/source/gui_handles.md +++ b/docs/source/gui_handles.md @@ -22,8 +22,6 @@ .. autoclass:: viser.GuiCheckboxHandle() -.. autoclass:: viser.GuiEvent() - .. autoclass:: viser.GuiMultiSliderHandle() .. autoclass:: viser.GuiNumberHandle() From 7ac568d791473c255a2e171a5b3ca377617fb09d Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 03:35:03 -0700 Subject: [PATCH 12/15] Docs cleanup --- docs/requirements.txt | 4 +- docs/source/conf.py | 2 +- docs/source/gui_handles.md | 4 + src/viser/__init__.py | 2 + src/viser/_gui_api.py | 8 +- src/viser/_gui_handles.py | 274 ++++++++++++------- src/viser/_messages.py | 12 +- src/viser/client/src/WebsocketMessages.ts | 6 +- src/viser/client/src/components/TabGroup.tsx | 7 +- 9 files changed, 206 insertions(+), 113 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 98a72fdf4..a88f2c7b3 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ -sphinx==7.2.6 -furo==2023.9.10 +sphinx==8.0.2 +furo==2024.8.6 docutils==0.20.1 m2r2==0.3.3.post2 toml==0.10.2 diff --git a/docs/source/conf.py b/docs/source/conf.py index 54393a36e..c251c2595 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -267,6 +267,6 @@ def setup(app): napoleon_use_ivar = False napoleon_use_param = True napoleon_use_rtype = True -napoleon_preprocess_types = False +napoleon_preprocess_types = True napoleon_type_aliases = None napoleon_attr_annotations = True diff --git a/docs/source/gui_handles.md b/docs/source/gui_handles.md index c65d46acc..9cd30e41e 100644 --- a/docs/source/gui_handles.md +++ b/docs/source/gui_handles.md @@ -34,6 +34,10 @@ .. autoclass:: viser.GuiTextHandle() +.. autoclass:: viser.GuiUploadButtonHandle() + +.. autoclass:: viser.UploadedFile() + .. autoclass:: viser.GuiVector2Handle() .. autoclass:: viser.GuiVector3Handle() diff --git a/src/viser/__init__.py b/src/viser/__init__.py index 168da8e99..014d7c764 100644 --- a/src/viser/__init__.py +++ b/src/viser/__init__.py @@ -16,8 +16,10 @@ from ._gui_handles import GuiTabGroupHandle as GuiTabGroupHandle from ._gui_handles import GuiTabHandle as GuiTabHandle from ._gui_handles import GuiTextHandle as GuiTextHandle +from ._gui_handles import GuiUploadButtonHandle as GuiUploadButtonHandle from ._gui_handles import GuiVector2Handle as GuiVector2Handle from ._gui_handles import GuiVector3Handle as GuiVector3Handle +from ._gui_handles import UploadedFile as UploadedFile from ._icons_enum import Icon as Icon from ._icons_enum import IconName as IconName from ._notification_handle import NotificationHandle as NotificationHandle diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index c2c00f845..6c3a36186 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -547,10 +547,10 @@ def add_tab_group( container_id=self._get_container_id(), props=_messages.GuiTabGroupProps( order=order, - tab_labels=(), + _tab_labels=(), visible=visible, - tab_icons_html=(), - tab_container_ids=(), + _tab_icons_html=(), + _tab_container_ids=(), ), ) ) @@ -1354,7 +1354,7 @@ def add_multi_slider( visible: bool = True, hint: str | None = None, order: float | None = None, - ) -> GuiMultiSliderHandle[tuple[IntOrFloat, ...]]: + ) -> GuiMultiSliderHandle[IntOrFloat]: """Add a multi slider to the GUI. Types of the min, max, step, and initial value should match. Args: diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 0bc96e82b..41904278a 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -44,6 +44,7 @@ GuiRgbaProps, GuiRgbProps, GuiSliderProps, + GuiTabGroupProps, GuiTextProps, GuiUpdateMessage, GuiVector2Props, @@ -172,7 +173,13 @@ class _GuiInputHandle( ): @property def value(self) -> T: - """Value of the GUI input. Synchronized automatically when assigned.""" + """Value of the GUI input. Synchronized automatically when assigned. + + :meta private: + """ + # ^Note: we mark this property as private for Sphinx because I haven't + # been able to get it to resolve the TypeVar in a readable way. + # For the documentation's sake, we'll be manually adding ::attribute directives below. return self._impl.value @value.setter @@ -236,34 +243,99 @@ def on_update( return func -class GuiCheckboxHandle(GuiInputHandle[bool], GuiCheckboxProps): ... +class GuiCheckboxHandle(GuiInputHandle[bool], GuiCheckboxProps): + """Handle for checkbox inputs. + + .. attribute:: value + :type: bool + + Value of the input. Synchronized automatically when assigned. + """ + +class GuiTextHandle(GuiInputHandle[str], GuiTextProps): + """Handle for text inputs. -class GuiTextHandle(GuiInputHandle[str], GuiTextProps): ... + .. attribute:: value + :type: str + + Value of the input. Synchronized automatically when assigned. + """ IntOrFloat = TypeVar("IntOrFloat", int, float) -class GuiNumberHandle(GuiInputHandle[T], Generic[T], GuiNumberProps): ... +class GuiNumberHandle(GuiInputHandle[IntOrFloat], Generic[IntOrFloat], GuiNumberProps): + """Handle for number inputs. + + .. attribute:: value + :type: IntOrFloat + + Value of the input. Synchronized automatically when assigned. + """ + + +class GuiSliderHandle(GuiInputHandle[IntOrFloat], Generic[IntOrFloat], GuiSliderProps): + """Handle for slider inputs. + + .. attribute:: value + :type: IntOrFloat + + Value of the input. Synchronized automatically when assigned. + """ + + +class GuiMultiSliderHandle( + GuiInputHandle[Tuple[IntOrFloat, ...]], Generic[IntOrFloat], GuiMultiSliderProps +): + """Handle for multi-slider inputs. + .. attribute:: value + :type: tuple[IntOrFloat, ...] + + Value of the input. Synchronized automatically when assigned. + """ -class GuiSliderHandle(GuiInputHandle[T], Generic[T], GuiSliderProps): ... +class GuiRgbHandle(GuiInputHandle[Tuple[int, int, int]], GuiRgbProps): + """Handle for RGB color inputs. -class GuiMultiSliderHandle(GuiInputHandle[T], Generic[T], GuiMultiSliderProps): ... + .. attribute:: value + :type: tuple[int, int, int] + Value of the input. Synchronized automatically when assigned. + """ -class GuiRgbHandle(GuiInputHandle[Tuple[int, int, int]], GuiRgbProps): ... +class GuiRgbaHandle(GuiInputHandle[Tuple[int, int, int, int]], GuiRgbaProps): + """Handle for RGBA color inputs. -class GuiRgbaHandle(GuiInputHandle[Tuple[int, int, int, int]], GuiRgbaProps): ... + .. attribute:: value + :type: tuple[int, int, int, int] + Value of the input. Synchronized automatically when assigned. + """ -class GuiVector2Handle(GuiInputHandle[Tuple[float, float]], GuiVector2Props): ... +class GuiVector2Handle(GuiInputHandle[Tuple[float, float]], GuiVector2Props): + """Handle for 2D vector inputs. -class GuiVector3Handle(GuiInputHandle[Tuple[float, float, float]], GuiVector3Props): ... + .. attribute:: value + :type: tuple[float, float] + + Value of the input. Synchronized automatically when assigned. + """ + + +class GuiVector3Handle(GuiInputHandle[Tuple[float, float, float]], GuiVector3Props): + """Handle for 3D vector inputs. + + .. attribute:: value + :type: tuple[float, float, float] + + Value of the input. Synchronized automatically when assigned. + """ @dataclasses.dataclass(frozen=True) @@ -283,7 +355,11 @@ class GuiEvent(Generic[TGuiHandle]): class GuiButtonHandle(_GuiInputHandle[bool]): """Handle for a button input in our visualizer. - Lets us detect clicks.""" + .. attribute:: value + :type: bool + + Value of the button. Set to `True` when the button is pressed. Can be manually set back to `False`. + """ def on_click( self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] @@ -307,6 +383,11 @@ class GuiUploadButtonHandle(_GuiInputHandle[UploadedFile]): """Handle for an upload file button in our visualizer. The `.value` attribute will be updated with the contents of uploaded files. + + .. attribute:: value + :type: UploadedFile + + Value of the input. Contains information about the uploaded file. """ def on_upload( @@ -320,7 +401,11 @@ def on_upload( class GuiButtonGroupHandle(_GuiInputHandle[str], GuiButtonGroupProps): """Handle for a button group input in our visualizer. - Lets us detect clicks.""" + .. attribute:: value + :type: str + + Value of the input. Represents the currently selected button in the group. + """ def on_click( self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] @@ -345,7 +430,11 @@ class GuiDropdownHandle( ): """Handle for a dropdown-style GUI input in our visualizer. - Lets us get values, set values, and detect updates.""" + .. attribute:: value + :type: StringType + + Value of the input. Represents the currently selected option in the dropdown. + """ @property def options(self) -> tuple[StringType, ...]: @@ -383,22 +472,12 @@ def options(self, options: Iterable[StringType]) -> None: # type: ignore ) -@dataclasses.dataclass(frozen=True) -class GuiTabGroupHandle: +class GuiTabGroupHandle(_GuiHandle[None], GuiTabGroupProps): """Handle for a tab group. Call :meth:`add_tab()` to add a tab.""" - _tab_group_id: str - _labels: list[str] - _icons_html: list[str | None] - _tabs: list[GuiTabHandle] - _gui_api: GuiApi - _parent_container_id: str - _order: float - - @property - def order(self) -> float: - """Read-only order value, which dictates the position of the GUI element.""" - return self._order + def __init__(self, _impl: _GuiHandleState[None]) -> None: + super().__init__(_impl=_impl) + self._tab_handles: list[GuiTabHandle] = [] def add_tab(self, label: str, icon: IconName | None = None) -> GuiTabHandle: """Add a tab. Returns a handle we can use to add GUI elements to it.""" @@ -408,38 +487,82 @@ def add_tab(self, label: str, icon: IconName | None = None) -> GuiTabHandle: # We may want to make this thread-safe in the future. out = GuiTabHandle(_parent=self, _id=id) - self._labels.append(label) - self._icons_html.append(None if icon is None else svg_from_icon(icon)) - self._tabs.append(out) - - self._sync_with_client() + self._tab_handles.append(out) + self._tab_labels = self._tab_labels + (label,) + self._icons_html = self._icons_html + ( + None if icon is None else svg_from_icon(icon), + ) + self._tab_container_ids = tuple(handle._id for handle in self._tab_handles) return out def __post_init__(self) -> None: - parent = self._gui_api._container_handle_from_id[self._parent_container_id] - parent._children[self._tab_group_id] = self + parent = self._impl.gui_api._container_handle_from_id[ + self._impl.parent_container_id + ] + parent._children[self._impl.id] = self def remove(self) -> None: """Remove this tab group and all contained GUI elements.""" - for tab in tuple(self._tabs): + for tab in tuple(self._tab_handles): tab.remove() - gui_api = self._gui_api - gui_api._websock_interface.queue_message(GuiRemoveMessage(self._tab_group_id)) - parent = gui_api._container_handle_from_id[self._parent_container_id] - parent._children.pop(self._tab_group_id) + gui_api = self._impl.gui_api + gui_api._websock_interface.queue_message(GuiRemoveMessage(self._impl.id)) + parent = gui_api._container_handle_from_id[self._impl.parent_container_id] + parent._children.pop(self._impl.id) - def _sync_with_client(self) -> None: - """Send messages for syncing tab state with the client.""" - self._gui_api._websock_interface.queue_message( - GuiUpdateMessage( - self._tab_group_id, - { - "tab_labels": tuple(self._labels), - "tab_icons_html": tuple(self._icons_html), - "tab_container_ids": tuple(tab._id for tab in self._tabs), - }, - ) + +@dataclasses.dataclass +class GuiTabHandle: + """Use as a context to place GUI elements into a tab.""" + + _parent: GuiTabGroupHandle + _id: str # Used as container ID of children. + _container_id_restore: str | None = None + _children: dict[str, SupportsRemoveProtocol] = dataclasses.field( + default_factory=dict + ) + + def __enter__(self) -> GuiTabHandle: + self._container_id_restore = self._parent._impl.gui_api._get_container_id() + self._parent._impl.gui_api._set_container_id(self._id) + return self + + def __exit__(self, *args) -> None: + del args + assert self._container_id_restore is not None + self._parent._impl.gui_api._set_container_id(self._container_id_restore) + self._container_id_restore = None + + def __post_init__(self) -> None: + self._parent._impl.gui_api._container_handle_from_id[self._id] = self + + def remove(self) -> None: + """Permanently remove this tab and all contained GUI elements from the + visualizer.""" + # We may want to make this thread-safe in the future. + found_index = -1 + for i, tab in enumerate(self._parent._tab_handles): + if tab is self: + found_index = i + break + assert found_index != -1, "Tab already removed!" + + self._parent._tab_labels = ( + self._parent._tab_labels[:found_index] + + self._parent._tab_labels[found_index + 1 :] + ) + self._parent._icons_html = ( + self._parent._icons_html[:found_index] + + self._parent._icons_html[found_index + 1 :] ) + self._parent._tab_handles = ( + self._parent._tab_handles[:found_index] + + self._parent._tab_handles[found_index + 1 :] + ) + + for child in tuple(self._children.values()): + child.remove() + self._parent._impl.gui_api._container_handle_from_id.pop(self._id) class GuiFolderHandle(_GuiHandle, GuiFolderProps): @@ -515,51 +638,6 @@ def close(self) -> None: self._gui_api._container_handle_from_id.pop(self._id) -@dataclasses.dataclass -class GuiTabHandle: - """Use as a context to place GUI elements into a tab.""" - - _parent: GuiTabGroupHandle - _id: str # Used as container ID of children. - _container_id_restore: str | None = None - _children: dict[str, SupportsRemoveProtocol] = dataclasses.field( - default_factory=dict - ) - - def __enter__(self) -> GuiTabHandle: - self._container_id_restore = self._parent._gui_api._get_container_id() - self._parent._gui_api._set_container_id(self._id) - return self - - def __exit__(self, *args) -> None: - del args - assert self._container_id_restore is not None - self._parent._gui_api._set_container_id(self._container_id_restore) - self._container_id_restore = None - - def __post_init__(self) -> None: - self._parent._gui_api._container_handle_from_id[self._id] = self - - def remove(self) -> None: - """Permanently remove this tab and all contained GUI elements from the - visualizer.""" - # We may want to make this thread-safe in the future. - container_index = -1 - for i, tab in enumerate(self._parent._tabs): - if tab is self: - container_index = i - break - assert container_index != -1, "Tab already removed!" - - self._parent._labels.pop(container_index) - self._parent._icons_html.pop(container_index) - self._parent._tabs.pop(container_index) - self._parent._sync_with_client() - for child in tuple(self._children.values()): - child.remove() - self._parent._gui_api._container_handle_from_id.pop(self._id) - - def _get_data_url(url: str, image_root: Path | None) -> str: if not url.startswith("http") and not image_root: warnings.warn( @@ -598,11 +676,11 @@ def _parse_markdown(markdown: str, image_root: Path | None) -> str: class GuiProgressBarHandle(_GuiInputHandle[float], GuiProgressBarProps): - """Use to remove markdown.""" + """Handle for updating and removing progress bars.""" class GuiMarkdownHandle(_GuiHandle[None], GuiMarkdownProps): - """Use to remove markdown.""" + """Handling for updating and removing markdown elements.""" def __init__(self, _impl: _GuiHandleState, _content: str, _image_root: Path | None): super().__init__(_impl=_impl) @@ -622,7 +700,7 @@ def content(self, content: str) -> None: class GuiPlotlyHandle(_GuiHandle[None], GuiPlotlyProps): - """Use to update or remove markdown elements.""" + """Handle for updating and removing Plotly figures.""" def __init__(self, _impl: _GuiHandleState, _figure: go.Figure): super().__init__(_impl=_impl) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 2c9b1fbc1..bd3eb951c 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -870,12 +870,16 @@ class GuiPlotlyMessage(Message, tag="GuiComponentMessage"): @dataclasses.dataclass class GuiTabGroupProps: - # Note: for tab groups we currently don't expose properties automatically. - tab_labels: Tuple[str, ...] - tab_icons_html: Tuple[Union[str, None], ...] - tab_container_ids: Tuple[str, ...] + _tab_labels: Tuple[str, ...] + """(Private) Tuple of labels for each tab. Synchronized automatically when assigned.""" + _tab_icons_html: Tuple[Union[str, None], ...] + """(Private) Tuple of HTML strings for icons of each tab, or None if no icon. Synchronized automatically when assigned.""" + _tab_container_ids: Tuple[str, ...] + """(Private) Tuple of container IDs for each tab. Synchronized automatically when assigned.""" order: float + """Order value for arranging GUI elements. Synchronized automatically when assigned.""" visible: bool + """Visibility state of the tab group. Synchronized automatically when assigned.""" @dataclasses.dataclass diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index 742ad688a..1518a40c2 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -623,9 +623,9 @@ export interface GuiTabGroupMessage { id: string; container_id: string; props: { - tab_labels: string[]; - tab_icons_html: (string | null)[]; - tab_container_ids: string[]; + _tab_labels: string[]; + _tab_icons_html: (string | null)[]; + _tab_container_ids: string[]; order: number; visible: boolean; }; diff --git a/src/viser/client/src/components/TabGroup.tsx b/src/viser/client/src/components/TabGroup.tsx index 863546883..22d2f7129 100644 --- a/src/viser/client/src/components/TabGroup.tsx +++ b/src/viser/client/src/components/TabGroup.tsx @@ -5,7 +5,12 @@ import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { htmlIconWrapper } from "./ComponentStyles.css"; export default function TabGroupComponent({ - props: { tab_labels, tab_icons_html, tab_container_ids, visible }, + props: { + _tab_labels: tab_labels, + _tab_icons_html: tab_icons_html, + _tab_container_ids: tab_container_ids, + visible, + }, }: GuiTabGroupMessage) { const { GuiContainer } = React.useContext(GuiComponentContext)!; if (!visible) return <>; From 05308dd20f1f05b1e7b8fabbfd5a3595ea5b5806 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 03:42:19 -0700 Subject: [PATCH 13/15] onp => np --- docs/source/examples/01_image.rst | 9 +- docs/source/examples/02_gui.rst | 16 +- docs/source/examples/03_gui_callbacks.rst | 6 +- docs/source/examples/05_camera_commands.rst | 8 +- docs/source/examples/06_mesh.rst | 6 +- .../examples/07_record3d_visualizer.rst | 6 +- docs/source/examples/08_smpl_visualizer.rst | 9 +- docs/source/examples/09_urdf_visualizer.rst | 10 +- docs/source/examples/11_colmap_visualizer.rst | 14 +- docs/source/examples/15_gui_in_scene.rst | 10 +- .../examples/17_background_composite.rst | 6 +- docs/source/examples/18_splines.rst | 10 +- docs/source/examples/19_get_renders.rst | 6 +- docs/source/examples/20_scene_pointer.rst | 24 +- docs/source/examples/22_games.rst | 16 +- docs/source/examples/23_plotly.rst | 6 +- .../examples/25_smpl_visualizer_skinned.rst | 14 +- docs/source/examples/26_lighting.rst | 6 +- examples/01_image.py | 9 +- examples/02_gui.py | 16 +- examples/03_gui_callbacks.py | 6 +- examples/05_camera_commands.py | 8 +- examples/06_mesh.py | 6 +- examples/07_record3d_visualizer.py | 6 +- examples/08_smpl_visualizer.py | 9 +- examples/09_urdf_visualizer.py | 10 +- examples/11_colmap_visualizer.py | 14 +- examples/15_gui_in_scene.py | 10 +- examples/17_background_composite.py | 6 +- examples/18_splines.py | 10 +- examples/19_get_renders.py | 6 +- examples/20_scene_pointer.py | 24 +- examples/22_games.py | 16 +- examples/23_plotly.py | 6 +- examples/25_smpl_visualizer_skinned.py | 11 +- examples/26_lighting.py | 6 +- examples/experimental/gaussian_splats.py | 42 +-- src/viser/_gui_api.py | 28 +- src/viser/_gui_handles.py | 6 +- src/viser/_messages.py | 30 +- src/viser/_scene_api.py | 256 +++++++++--------- src/viser/_scene_handles.py | 60 ++-- src/viser/_viser.py | 76 +++--- src/viser/extras/_record3d.py | 25 +- src/viser/extras/_urdf.py | 6 +- src/viser/infra/_messages.py | 12 +- src/viser/infra/_typescript_interface_gen.py | 4 +- src/viser/transforms/_base.py | 46 ++-- src/viser/transforms/_se2.py | 106 ++++---- src/viser/transforms/_se3.py | 110 ++++---- src/viser/transforms/_so2.py | 64 ++--- src/viser/transforms/hints/__init__.py | 6 +- tests/test_transforms_axioms.py | 18 +- tests/test_transforms_bijective.py | 54 ++-- tests/test_transforms_ops.py | 22 +- tests/utils.py | 44 ++- 56 files changed, 678 insertions(+), 698 deletions(-) diff --git a/docs/source/examples/01_image.rst b/docs/source/examples/01_image.rst index 217affdc6..8ae4d34f0 100644 --- a/docs/source/examples/01_image.rst +++ b/docs/source/examples/01_image.rst @@ -20,7 +20,7 @@ NeRFs), or images to render as 3D textures. from pathlib import Path import imageio.v3 as iio - import numpy as onp + import numpy as np import viser @@ -47,12 +47,7 @@ NeRFs), or images to render as 3D textures. while True: server.scene.add_image( "/noise", - onp.random.randint( - 0, - 256, - size=(400, 400, 3), - dtype=onp.uint8, - ), + np.random.randint(0, 256, size=(400, 400, 3), dtype=np.uint8), 4.0, 4.0, format="jpeg", diff --git a/docs/source/examples/02_gui.rst b/docs/source/examples/02_gui.rst index 819850547..754affd62 100644 --- a/docs/source/examples/02_gui.rst +++ b/docs/source/examples/02_gui.rst @@ -15,7 +15,7 @@ Examples of basic GUI elements that we can create, read from, and write to. import time - import numpy as onp + import numpy as np import viser @@ -96,8 +96,8 @@ Examples of basic GUI elements that we can create, read from, and write to. print(file.name, len(file.content), "bytes") # Pre-generate a point cloud to send. - point_positions = onp.random.uniform(low=-1.0, high=1.0, size=(5000, 3)) - color_coeffs = onp.random.uniform(0.4, 1.0, size=(point_positions.shape[0])) + point_positions = np.random.uniform(low=-1.0, high=1.0, size=(5000, 3)) + color_coeffs = np.random.uniform(0.4, 1.0, size=(point_positions.shape[0])) counter = 0 while True: @@ -111,11 +111,11 @@ Examples of basic GUI elements that we can create, read from, and write to. # connected clients. server.scene.add_point_cloud( "/point_cloud", - points=point_positions * onp.array(gui_vector3.value, dtype=onp.float32), + points=point_positions * np.array(gui_vector3.value, dtype=np.float32), colors=( - onp.tile(gui_rgb.value, point_positions.shape[0]).reshape((-1, 3)) + np.tile(gui_rgb.value, point_positions.shape[0]).reshape((-1, 3)) * color_coeffs[:, None] - ).astype(onp.uint8), + ).astype(np.uint8), position=gui_vector2.value + (0,), point_shape="circle", ) @@ -131,8 +131,8 @@ Examples of basic GUI elements that we can create, read from, and write to. # Update the number of handles in the multi-slider. if gui_slider_positions.value != len(gui_multi_slider.value): - gui_multi_slider.value = onp.linspace( - 0, 100, gui_slider_positions.value, dtype=onp.int64 + gui_multi_slider.value = np.linspace( + 0, 100, gui_slider_positions.value, dtype=np.int64 ) counter += 1 diff --git a/docs/source/examples/03_gui_callbacks.rst b/docs/source/examples/03_gui_callbacks.rst index 8f31eb295..6fea9c47b 100644 --- a/docs/source/examples/03_gui_callbacks.rst +++ b/docs/source/examples/03_gui_callbacks.rst @@ -16,7 +16,7 @@ we get updates. import time - import numpy as onp + import numpy as np from typing_extensions import assert_never import viser @@ -86,8 +86,8 @@ we get updates. num_points = gui_num_points.value server.scene.add_point_cloud( "/frame/point_cloud", - points=onp.random.normal(size=(num_points, 3)), - colors=onp.random.randint(0, 256, size=(num_points, 3)), + points=np.random.normal(size=(num_points, 3)), + colors=np.random.randint(0, 256, size=(num_points, 3)), ) # We can (optionally) also attach callbacks! diff --git a/docs/source/examples/05_camera_commands.rst b/docs/source/examples/05_camera_commands.rst index 36245f245..0ab96f0cc 100644 --- a/docs/source/examples/05_camera_commands.rst +++ b/docs/source/examples/05_camera_commands.rst @@ -16,7 +16,7 @@ corresponding client automatically. import time - import numpy as onp + import numpy as np import viser import viser.transforms as tf @@ -32,12 +32,12 @@ corresponding client automatically. When a frame is clicked, we move the camera to the corresponding frame. """ - rng = onp.random.default_rng(0) + rng = np.random.default_rng(0) def make_frame(i: int) -> None: # Sample a random orientation + position. wxyz = rng.normal(size=4) - wxyz /= onp.linalg.norm(wxyz) + wxyz /= np.linalg.norm(wxyz) position = rng.uniform(-3.0, 3.0, size=(3,)) # Create a coordinate frame and label. @@ -52,7 +52,7 @@ corresponding client automatically. ) T_world_target = tf.SE3.from_rotation_and_translation( tf.SO3(frame.wxyz), frame.position - ) @ tf.SE3.from_translation(onp.array([0.0, 0.0, -0.5])) + ) @ tf.SE3.from_translation(np.array([0.0, 0.0, -0.5])) T_current_target = T_world_current.inverse() @ T_world_target diff --git a/docs/source/examples/06_mesh.rst b/docs/source/examples/06_mesh.rst index a8114e67b..d94f9d7c4 100644 --- a/docs/source/examples/06_mesh.rst +++ b/docs/source/examples/06_mesh.rst @@ -16,7 +16,7 @@ Visualize a mesh. To get the demo data, see ``./assets/download_dragon_mesh.sh`` import time from pathlib import Path - import numpy as onp + import numpy as np import trimesh import viser @@ -35,13 +35,13 @@ Visualize a mesh. To get the demo data, see ``./assets/download_dragon_mesh.sh`` name="/simple", vertices=vertices, faces=faces, - wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, + wxyz=tf.SO3.from_x_radians(np.pi / 2).wxyz, position=(0.0, 0.0, 0.0), ) server.scene.add_mesh_trimesh( name="/trimesh", mesh=mesh.smoothed(), - wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, + wxyz=tf.SO3.from_x_radians(np.pi / 2).wxyz, position=(0.0, 5.0, 0.0), ) diff --git a/docs/source/examples/07_record3d_visualizer.rst b/docs/source/examples/07_record3d_visualizer.rst index a075b6682..06bf6932c 100644 --- a/docs/source/examples/07_record3d_visualizer.rst +++ b/docs/source/examples/07_record3d_visualizer.rst @@ -16,7 +16,7 @@ Parse and stream record3d captures. To get the demo data, see ``./assets/downloa import time from pathlib import Path - import numpy as onp + import numpy as np import tyro from tqdm.auto import tqdm @@ -96,7 +96,7 @@ Parse and stream record3d captures. To get the demo data, see ``./assets/downloa # Load in frames. server.scene.add_frame( "/frames", - wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz, + wxyz=tf.SO3.exp(np.array([np.pi / 2.0, 0.0, 0.0])).wxyz, position=(0, 0, 0), show_axes=False, ) @@ -118,7 +118,7 @@ Parse and stream record3d captures. To get the demo data, see ``./assets/downloa ) # Place the frustum. - fov = 2 * onp.arctan2(frame.rgb.shape[0] / 2, frame.K[0, 0]) + fov = 2 * np.arctan2(frame.rgb.shape[0] / 2, frame.K[0, 0]) aspect = frame.rgb.shape[1] / frame.rgb.shape[0] server.scene.add_camera_frustum( f"/frames/t{i}/frustum", diff --git a/docs/source/examples/08_smpl_visualizer.rst b/docs/source/examples/08_smpl_visualizer.rst index e11fe1f7d..c390d42da 100644 --- a/docs/source/examples/08_smpl_visualizer.rst +++ b/docs/source/examples/08_smpl_visualizer.rst @@ -23,7 +23,6 @@ See here for download instructions: from pathlib import Path import numpy as np - import numpy as onp import tyro import viser @@ -43,7 +42,7 @@ See here for download instructions: def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" - body_dict = dict(**onp.load(model_path, allow_pickle=True)) + body_dict = dict(**np.load(model_path, allow_pickle=True)) self._J_regressor = body_dict["J_regressor"] self._weights = body_dict["weights"] @@ -180,7 +179,7 @@ See here for download instructions: @gui_random_shape.on_click def _(_): for beta in gui_betas: - beta.value = onp.random.normal(loc=0.0, scale=1.0) + beta.value = np.random.normal(loc=0.0, scale=1.0) gui_betas = [] for i in range(num_betas): @@ -205,8 +204,8 @@ See here for download instructions: for joint in gui_joints: # It's hard to uniformly sample orientations directly in so(3), so we # first sample on S^3 and then convert. - quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) - quat /= onp.linalg.norm(quat) + quat = np.random.normal(loc=0.0, scale=1.0, size=(4,)) + quat /= np.linalg.norm(quat) joint.value = tf.SO3(wxyz=quat).log() gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]] = [] diff --git a/docs/source/examples/09_urdf_visualizer.rst b/docs/source/examples/09_urdf_visualizer.rst index 413c66256..4a72a4e63 100644 --- a/docs/source/examples/09_urdf_visualizer.rst +++ b/docs/source/examples/09_urdf_visualizer.rst @@ -25,7 +25,7 @@ and viser. It can also take a path to a local URDF file as input. import time from typing import Literal - import numpy as onp + import numpy as np import tyro from robot_descriptions.loaders.yourdfpy import load_robot_description @@ -44,8 +44,8 @@ and viser. It can also take a path to a local URDF file as input. lower, upper, ) in viser_urdf.get_actuated_joint_limits().items(): - lower = lower if lower is not None else -onp.pi - upper = upper if upper is not None else onp.pi + lower = lower if lower is not None else -np.pi + upper = upper if upper is not None else np.pi initial_pos = 0.0 if lower < 0 and upper > 0 else (lower + upper) / 2.0 slider = server.gui.add_slider( label=joint_name, @@ -56,7 +56,7 @@ and viser. It can also take a path to a local URDF file as input. ) slider.on_update( # When sliders move, we update the URDF configuration. lambda _: viser_urdf.update_cfg( - onp.array([slider.value for slider in slider_handles]) + np.array([slider.value for slider in slider_handles]) ) ) slider_handles.append(slider) @@ -97,7 +97,7 @@ and viser. It can also take a path to a local URDF file as input. ) # Set initial robot configuration. - viser_urdf.update_cfg(onp.array(initial_config)) + viser_urdf.update_cfg(np.array(initial_config)) # Create joint reset button. reset_button = server.gui.add_button("Reset") diff --git a/docs/source/examples/11_colmap_visualizer.rst b/docs/source/examples/11_colmap_visualizer.rst index 943f343f1..f1ed575c9 100644 --- a/docs/source/examples/11_colmap_visualizer.rst +++ b/docs/source/examples/11_colmap_visualizer.rst @@ -19,7 +19,7 @@ Visualize COLMAP sparse reconstruction outputs. To get demo data, see ``./assets from typing import List import imageio.v3 as iio - import numpy as onp + import numpy as np import tyro from tqdm.auto import tqdm @@ -60,7 +60,7 @@ Visualize COLMAP sparse reconstruction outputs. To get demo data, see ``./assets def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None - client.camera.up_direction = tf.SO3(client.camera.wxyz) @ onp.array( + client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array( [0.0, -1.0, 0.0] ) @@ -82,10 +82,10 @@ Visualize COLMAP sparse reconstruction outputs. To get demo data, see ``./assets "Point size", min=0.01, max=0.1, step=0.001, initial_value=0.05 ) - points = onp.array([points3d[p_id].xyz for p_id in points3d]) - colors = onp.array([points3d[p_id].rgb for p_id in points3d]) + points = np.array([points3d[p_id].xyz for p_id in points3d]) + colors = np.array([points3d[p_id].rgb for p_id in points3d]) - point_mask = onp.random.choice(points.shape[0], gui_points.value, replace=False) + point_mask = np.random.choice(points.shape[0], gui_points.value, replace=False) point_cloud = server.scene.add_point_cloud( name="/colmap/pcd", points=points[point_mask], @@ -148,7 +148,7 @@ Visualize COLMAP sparse reconstruction outputs. To get demo data, see ``./assets image = image[::downsample_factor, ::downsample_factor] frustum = server.scene.add_camera_frustum( f"/colmap/frame_{img_id}/frustum", - fov=2 * onp.arctan2(H / 2, fy), + fov=2 * np.arctan2(H / 2, fy), aspect=W / H, scale=0.15, image=image, @@ -159,7 +159,7 @@ Visualize COLMAP sparse reconstruction outputs. To get demo data, see ``./assets @gui_points.on_update def _(_) -> None: - point_mask = onp.random.choice(points.shape[0], gui_points.value, replace=False) + point_mask = np.random.choice(points.shape[0], gui_points.value, replace=False) point_cloud.points = points[point_mask] point_cloud.colors = colors[point_mask] diff --git a/docs/source/examples/15_gui_in_scene.rst b/docs/source/examples/15_gui_in_scene.rst index fa9c1ff3f..c0a199c68 100644 --- a/docs/source/examples/15_gui_in_scene.rst +++ b/docs/source/examples/15_gui_in_scene.rst @@ -18,7 +18,7 @@ performed on them. import time from typing import Optional - import numpy as onp + import numpy as np import viser import viser.transforms as tf @@ -35,14 +35,14 @@ performed on them. When a frame is clicked, we display a 3D gui node. """ - rng = onp.random.default_rng(0) + rng = np.random.default_rng(0) displayed_3d_container: Optional[viser.Gui3dContainerHandle] = None def make_frame(i: int) -> None: # Sample a random orientation + position. wxyz = rng.normal(size=4) - wxyz /= onp.linalg.norm(wxyz) + wxyz /= np.linalg.norm(wxyz) position = rng.uniform(-3.0, 3.0, size=(3,)) # Create a coordinate frame and label. @@ -72,7 +72,7 @@ performed on them. ) T_world_target = tf.SE3.from_rotation_and_translation( tf.SO3(frame.wxyz), frame.position - ) @ tf.SE3.from_translation(onp.array([0.0, 0.0, -0.5])) + ) @ tf.SE3.from_translation(np.array([0.0, 0.0, -0.5])) T_current_target = T_world_current.inverse() @ T_world_target @@ -94,7 +94,7 @@ performed on them. @randomize_orientation.on_click def _(_) -> None: wxyz = rng.normal(size=4) - wxyz /= onp.linalg.norm(wxyz) + wxyz /= np.linalg.norm(wxyz) frame.wxyz = wxyz @close.on_click diff --git a/docs/source/examples/17_background_composite.rst b/docs/source/examples/17_background_composite.rst index 2ba59d3e4..2e9fcd36b 100644 --- a/docs/source/examples/17_background_composite.rst +++ b/docs/source/examples/17_background_composite.rst @@ -16,7 +16,7 @@ be useful when we want a 2D image to occlude 3D geometry, such as for NeRF rende import time - import numpy as onp + import numpy as np import trimesh import trimesh.creation @@ -25,8 +25,8 @@ be useful when we want a 2D image to occlude 3D geometry, such as for NeRF rende server = viser.ViserServer() - img = onp.random.randint(0, 255, size=(1000, 1000, 3), dtype=onp.uint8) - depth = onp.ones((1000, 1000, 1), dtype=onp.float32) + img = np.random.randint(0, 255, size=(1000, 1000, 3), dtype=np.uint8) + depth = np.ones((1000, 1000, 1), dtype=np.float32) # Make a square middle portal. depth[250:750, 250:750, :] = 10.0 diff --git a/docs/source/examples/18_splines.rst b/docs/source/examples/18_splines.rst index 2cb804c2c..161862490 100644 --- a/docs/source/examples/18_splines.rst +++ b/docs/source/examples/18_splines.rst @@ -15,7 +15,7 @@ Make a ball with some random splines. import time - import numpy as onp + import numpy as np import viser @@ -23,23 +23,23 @@ Make a ball with some random splines. def main() -> None: server = viser.ViserServer() for i in range(10): - positions = onp.random.normal(size=(30, 3)) * 3.0 + positions = np.random.normal(size=(30, 3)) * 3.0 server.scene.add_spline_catmull_rom( f"/catmull_{i}", positions, tension=0.5, line_width=3.0, - color=onp.random.uniform(size=3), + color=np.random.uniform(size=3), segments=100, ) - control_points = onp.random.normal(size=(30 * 2 - 2, 3)) * 3.0 + control_points = np.random.normal(size=(30 * 2 - 2, 3)) * 3.0 server.scene.add_spline_cubic_bezier( f"/cubic_bezier_{i}", positions, control_points, line_width=3.0, - color=onp.random.uniform(size=3), + color=np.random.uniform(size=3), segments=100, ) diff --git a/docs/source/examples/19_get_renders.rst b/docs/source/examples/19_get_renders.rst index c820b7e1f..b00541454 100644 --- a/docs/source/examples/19_get_renders.rst +++ b/docs/source/examples/19_get_renders.rst @@ -16,7 +16,7 @@ Example for getting renders from a client's viewport to the Python API. import time import imageio.v3 as iio - import numpy as onp + import numpy as np import viser @@ -36,13 +36,13 @@ Example for getting renders from a client's viewport to the Python API. images = [] for i in range(20): - positions = onp.random.normal(size=(30, 3)) * 3.0 + positions = np.random.normal(size=(30, 3)) * 3.0 client.scene.add_spline_catmull_rom( f"/catmull_{i}", positions, tension=0.5, line_width=3.0, - color=onp.random.uniform(size=3), + color=np.random.uniform(size=3), ) images.append(client.camera.get_render(height=720, width=1280)) diff --git a/docs/source/examples/20_scene_pointer.rst b/docs/source/examples/20_scene_pointer.rst index 2a85b1f9c..19d7ba9a2 100644 --- a/docs/source/examples/20_scene_pointer.rst +++ b/docs/source/examples/20_scene_pointer.rst @@ -22,7 +22,7 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``. from pathlib import Path from typing import cast - import numpy as onp + import numpy as np import trimesh import trimesh.creation import trimesh.ray @@ -56,8 +56,8 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``. @server.on_client_connect def _(client: viser.ClientHandle) -> None: # Set up the camera -- this gives a nice view of the full mesh. - client.camera.position = onp.array([0.0, 0.0, -10.0]) - client.camera.wxyz = onp.array([0.0, 0.0, 0.0, 1.0]) + client.camera.position = np.array([0.0, 0.0, -10.0]) + client.camera.wxyz = np.array([0.0, 0.0, 0.0, 1.0]) # Tests "click" scenepointerevent. click_button_handle = client.gui.add_button("Add sphere", icon=viser.Icon.POINTER) @@ -72,8 +72,8 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``. # Note that mesh is in the mesh frame, so we need to transform the ray. R_world_mesh = tf.SO3(mesh_handle.wxyz) R_mesh_world = R_world_mesh.inverse() - origin = (R_mesh_world @ onp.array(event.ray_origin)).reshape(1, 3) - direction = (R_mesh_world @ onp.array(event.ray_direction)).reshape(1, 3) + origin = (R_mesh_world @ np.array(event.ray_origin)).reshape(1, 3) + direction = (R_mesh_world @ np.array(event.ray_direction)).reshape(1, 3) intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh) hit_pos, _, _ = intersector.intersects_location(origin, direction) @@ -82,7 +82,7 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``. client.scene.remove_pointer_callback() # Get the first hit position (based on distance from the ray origin). - hit_pos = hit_pos[onp.argmin(onp.sum((hit_pos - origin) ** 2, axis=-1))] + hit_pos = hit_pos[np.argmin(np.sum((hit_pos - origin) ** 2, axis=-1))] # Create a sphere at the hit location. hit_pos_mesh = trimesh.creation.icosphere(radius=0.1) @@ -117,17 +117,17 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``. R_camera_world = tf.SE3.from_rotation_and_translation( tf.SO3(camera.wxyz), camera.position ).inverse() - vertices = cast(onp.ndarray, mesh.vertices) + vertices = cast(np.ndarray, mesh.vertices) vertices = (R_mesh_world.as_matrix() @ vertices.T).T vertices = ( R_camera_world.as_matrix() - @ onp.hstack([vertices, onp.ones((vertices.shape[0], 1))]).T + @ np.hstack([vertices, np.ones((vertices.shape[0], 1))]).T ).T[:, :3] # Get the camera intrinsics, and project the vertices onto the image plane. fov, aspect = camera.fov, camera.aspect vertices_proj = vertices[:, :2] / vertices[:, 2].reshape(-1, 1) - vertices_proj /= onp.tan(fov / 2) + vertices_proj /= np.tan(fov / 2) vertices_proj[:, 0] /= aspect # Move the origin to the upper-left corner, and scale to [0, 1]. @@ -136,12 +136,12 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``. # Select the vertices that lie inside the 2D selected box, once projected. mask = ( - (vertices_proj > onp.array(message.screen_pos[0])) - & (vertices_proj < onp.array(message.screen_pos[1])) + (vertices_proj > np.array(message.screen_pos[0])) + & (vertices_proj < np.array(message.screen_pos[1])) ).all(axis=1)[..., None] # Update the mesh color based on whether the vertices are inside the box - mesh.visual.vertex_colors = onp.where( # type: ignore + mesh.visual.vertex_colors = np.where( # type: ignore mask, (0.5, 0.0, 0.7, 1.0), (0.9, 0.9, 0.9, 1.0) ) mesh_handle = server.scene.add_mesh_trimesh( diff --git a/docs/source/examples/22_games.rst b/docs/source/examples/22_games.rst index c06831bac..4064a7f85 100644 --- a/docs/source/examples/22_games.rst +++ b/docs/source/examples/22_games.rst @@ -16,7 +16,7 @@ Some two-player games implemented using scene click events. import time from typing import Literal - import numpy as onp + import numpy as np import trimesh.creation from typing_extensions import assert_never @@ -53,7 +53,7 @@ Some two-player games implemented using scene click events. f"/structure/{row}_{col}", trimesh.creation.annulus(0.45, 0.55, 0.125), position=(0.0, col, row), - wxyz=tf.SO3.from_y_radians(onp.pi / 2.0).wxyz, + wxyz=tf.SO3.from_y_radians(np.pi / 2.0).wxyz, ) # Create a sphere to click on for each column. @@ -81,10 +81,10 @@ Some two-player games implemented using scene click events. f"/game_pieces/{row}_{col}", cylinder.vertices, cylinder.faces, - wxyz=tf.SO3.from_y_radians(onp.pi / 2.0).wxyz, + wxyz=tf.SO3.from_y_radians(np.pi / 2.0).wxyz, color={"red": (255, 0, 0), "yellow": (255, 255, 0)}[whose_turn], ) - for row_anim in onp.linspace(num_rows - 1, row, num_rows - row + 1): + for row_anim in np.linspace(num_rows - 1, row, num_rows - row + 1): piece.position = ( 0, col, @@ -108,12 +108,12 @@ Some two-player games implemented using scene click events. ((-0.5, -1.5, 0), (-0.5, 1.5, 0)), color=(127, 127, 127), position=(1, 1, 0), - wxyz=tf.SO3.from_z_radians(onp.pi / 2 * i).wxyz, + wxyz=tf.SO3.from_z_radians(np.pi / 2 * i).wxyz, ) def draw_symbol(symbol: Literal["x", "o"], i: int, j: int) -> None: """Draw an X or O in the given cell.""" - for scale in onp.linspace(0.01, 1.0, 5): + for scale in np.linspace(0.01, 1.0, 5): if symbol == "x": for k in range(2): server.scene.add_box( @@ -121,9 +121,7 @@ Some two-player games implemented using scene click events. dimensions=(0.7 * scale, 0.125 * scale, 0.125), position=(i, j, 0), color=(0, 0, 255), - wxyz=tf.SO3.from_z_radians( - onp.pi / 2.0 * k + onp.pi / 4.0 - ).wxyz, + wxyz=tf.SO3.from_z_radians(np.pi / 2.0 * k + np.pi / 4.0).wxyz, ) elif symbol == "o": mesh = trimesh.creation.annulus(0.25 * scale, 0.35 * scale, 0.125) diff --git a/docs/source/examples/23_plotly.rst b/docs/source/examples/23_plotly.rst index 87338f05b..013206c48 100644 --- a/docs/source/examples/23_plotly.rst +++ b/docs/source/examples/23_plotly.rst @@ -15,7 +15,7 @@ Examples of visualizing plotly plots in Viser. import time - import numpy as onp + import numpy as np import plotly.express as px import plotly.graph_objects as go from PIL import Image @@ -25,8 +25,8 @@ Examples of visualizing plotly plots in Viser. def create_sinusoidal_wave(t: float) -> go.Figure: """Create a sinusoidal wave plot, starting at time t.""" - x_data = onp.linspace(t, t + 6 * onp.pi, 50) - y_data = onp.sin(x_data) * 10 + x_data = np.linspace(t, t + 6 * np.pi, 50) + y_data = np.sin(x_data) * 10 fig = px.line( x=list(x_data), diff --git a/docs/source/examples/25_smpl_visualizer_skinned.rst b/docs/source/examples/25_smpl_visualizer_skinned.rst index b5294cb1f..caa8a5224 100644 --- a/docs/source/examples/25_smpl_visualizer_skinned.rst +++ b/docs/source/examples/25_smpl_visualizer_skinned.rst @@ -24,7 +24,6 @@ See here for download instructions: from typing import List, Tuple import numpy as np - import numpy as onp import tyro import viser @@ -44,7 +43,7 @@ See here for download instructions: def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" - body_dict = dict(**onp.load(model_path, allow_pickle=True)) + body_dict = dict(**np.load(model_path, allow_pickle=True)) self._J_regressor = body_dict["J_regressor"] self._weights = body_dict["weights"] @@ -100,7 +99,7 @@ See here for download instructions: ) smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), - joint_rotmats=onp.zeros((model.num_joints, 3, 3)) + onp.eye(3), + joint_rotmats=np.zeros((model.num_joints, 3, 3)) + np.eye(3), ) bone_wxyzs = np.array( @@ -127,6 +126,9 @@ See here for download instructions: gui_elements.changed = False + # Render as wireframe? + skinned_handle.wireframe = gui_elements.gui_wireframe.value + # Compute SMPL outputs. smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), @@ -215,7 +217,7 @@ See here for download instructions: @gui_random_shape.on_click def _(_): for beta in gui_betas: - beta.value = onp.random.normal(loc=0.0, scale=1.0) + beta.value = np.random.normal(loc=0.0, scale=1.0) gui_betas = [] for i in range(num_betas): @@ -240,8 +242,8 @@ See here for download instructions: for joint in gui_joints: # It's hard to uniformly sample orientations directly in so(3), so we # first sample on S^3 and then convert. - quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) - quat /= onp.linalg.norm(quat) + quat = np.random.normal(loc=0.0, scale=1.0, size=(4,)) + quat /= np.linalg.norm(quat) joint.value = tf.SO3(wxyz=quat).log() gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] diff --git a/docs/source/examples/26_lighting.rst b/docs/source/examples/26_lighting.rst index 6eee7c27c..c64a0dbb1 100644 --- a/docs/source/examples/26_lighting.rst +++ b/docs/source/examples/26_lighting.rst @@ -16,7 +16,7 @@ Visualize a mesh under different lighting conditions. To get the demo data, see import time from pathlib import Path - import numpy as onp + import numpy as np import trimesh import viser @@ -39,13 +39,13 @@ Visualize a mesh under different lighting conditions. To get the demo data, see name="/simple", vertices=vertices, faces=faces, - wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, + wxyz=tf.SO3.from_x_radians(np.pi / 2).wxyz, position=(0.0, 0.0, 0.0), ) server.scene.add_mesh_trimesh( name="/trimesh", mesh=mesh, - wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, + wxyz=tf.SO3.from_x_radians(np.pi / 2).wxyz, position=(0.0, 5.0, 0.0), ) diff --git a/examples/01_image.py b/examples/01_image.py index 86a2a6860..1ad83da57 100644 --- a/examples/01_image.py +++ b/examples/01_image.py @@ -10,7 +10,7 @@ from pathlib import Path import imageio.v3 as iio -import numpy as onp +import numpy as np import viser @@ -37,12 +37,7 @@ def main() -> None: while True: server.scene.add_image( "/noise", - onp.random.randint( - 0, - 256, - size=(400, 400, 3), - dtype=onp.uint8, - ), + np.random.randint(0, 256, size=(400, 400, 3), dtype=np.uint8), 4.0, 4.0, format="jpeg", diff --git a/examples/02_gui.py b/examples/02_gui.py index e48e63fd7..e97719c1f 100644 --- a/examples/02_gui.py +++ b/examples/02_gui.py @@ -4,7 +4,7 @@ import time -import numpy as onp +import numpy as np import viser @@ -85,8 +85,8 @@ def _(_) -> None: print(file.name, len(file.content), "bytes") # Pre-generate a point cloud to send. - point_positions = onp.random.uniform(low=-1.0, high=1.0, size=(5000, 3)) - color_coeffs = onp.random.uniform(0.4, 1.0, size=(point_positions.shape[0])) + point_positions = np.random.uniform(low=-1.0, high=1.0, size=(5000, 3)) + color_coeffs = np.random.uniform(0.4, 1.0, size=(point_positions.shape[0])) counter = 0 while True: @@ -100,11 +100,11 @@ def _(_) -> None: # connected clients. server.scene.add_point_cloud( "/point_cloud", - points=point_positions * onp.array(gui_vector3.value, dtype=onp.float32), + points=point_positions * np.array(gui_vector3.value, dtype=np.float32), colors=( - onp.tile(gui_rgb.value, point_positions.shape[0]).reshape((-1, 3)) + np.tile(gui_rgb.value, point_positions.shape[0]).reshape((-1, 3)) * color_coeffs[:, None] - ).astype(onp.uint8), + ).astype(np.uint8), position=gui_vector2.value + (0,), point_shape="circle", ) @@ -120,8 +120,8 @@ def _(_) -> None: # Update the number of handles in the multi-slider. if gui_slider_positions.value != len(gui_multi_slider.value): - gui_multi_slider.value = onp.linspace( - 0, 100, gui_slider_positions.value, dtype=onp.int64 + gui_multi_slider.value = np.linspace( + 0, 100, gui_slider_positions.value, dtype=np.int64 ) counter += 1 diff --git a/examples/03_gui_callbacks.py b/examples/03_gui_callbacks.py index a1df139e2..ca1f2e308 100644 --- a/examples/03_gui_callbacks.py +++ b/examples/03_gui_callbacks.py @@ -5,7 +5,7 @@ import time -import numpy as onp +import numpy as np from typing_extensions import assert_never import viser @@ -75,8 +75,8 @@ def draw_points() -> None: num_points = gui_num_points.value server.scene.add_point_cloud( "/frame/point_cloud", - points=onp.random.normal(size=(num_points, 3)), - colors=onp.random.randint(0, 256, size=(num_points, 3)), + points=np.random.normal(size=(num_points, 3)), + colors=np.random.randint(0, 256, size=(num_points, 3)), ) # We can (optionally) also attach callbacks! diff --git a/examples/05_camera_commands.py b/examples/05_camera_commands.py index 4a51aeaa5..744dde64e 100644 --- a/examples/05_camera_commands.py +++ b/examples/05_camera_commands.py @@ -6,7 +6,7 @@ import time -import numpy as onp +import numpy as np import viser import viser.transforms as tf @@ -22,12 +22,12 @@ def _(client: viser.ClientHandle) -> None: When a frame is clicked, we move the camera to the corresponding frame. """ - rng = onp.random.default_rng(0) + rng = np.random.default_rng(0) def make_frame(i: int) -> None: # Sample a random orientation + position. wxyz = rng.normal(size=4) - wxyz /= onp.linalg.norm(wxyz) + wxyz /= np.linalg.norm(wxyz) position = rng.uniform(-3.0, 3.0, size=(3,)) # Create a coordinate frame and label. @@ -42,7 +42,7 @@ def _(_): ) T_world_target = tf.SE3.from_rotation_and_translation( tf.SO3(frame.wxyz), frame.position - ) @ tf.SE3.from_translation(onp.array([0.0, 0.0, -0.5])) + ) @ tf.SE3.from_translation(np.array([0.0, 0.0, -0.5])) T_current_target = T_world_current.inverse() @ T_world_target diff --git a/examples/06_mesh.py b/examples/06_mesh.py index 3d30f09e5..7babe8894 100644 --- a/examples/06_mesh.py +++ b/examples/06_mesh.py @@ -6,7 +6,7 @@ import time from pathlib import Path -import numpy as onp +import numpy as np import trimesh import viser @@ -25,13 +25,13 @@ name="/simple", vertices=vertices, faces=faces, - wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, + wxyz=tf.SO3.from_x_radians(np.pi / 2).wxyz, position=(0.0, 0.0, 0.0), ) server.scene.add_mesh_trimesh( name="/trimesh", mesh=mesh.smoothed(), - wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, + wxyz=tf.SO3.from_x_radians(np.pi / 2).wxyz, position=(0.0, 5.0, 0.0), ) diff --git a/examples/07_record3d_visualizer.py b/examples/07_record3d_visualizer.py index e8b926bbb..aedd58fa1 100644 --- a/examples/07_record3d_visualizer.py +++ b/examples/07_record3d_visualizer.py @@ -6,7 +6,7 @@ import time from pathlib import Path -import numpy as onp +import numpy as np import tyro from tqdm.auto import tqdm @@ -86,7 +86,7 @@ def _(_) -> None: # Load in frames. server.scene.add_frame( "/frames", - wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz, + wxyz=tf.SO3.exp(np.array([np.pi / 2.0, 0.0, 0.0])).wxyz, position=(0, 0, 0), show_axes=False, ) @@ -108,7 +108,7 @@ def _(_) -> None: ) # Place the frustum. - fov = 2 * onp.arctan2(frame.rgb.shape[0] / 2, frame.K[0, 0]) + fov = 2 * np.arctan2(frame.rgb.shape[0] / 2, frame.K[0, 0]) aspect = frame.rgb.shape[1] / frame.rgb.shape[0] server.scene.add_camera_frustum( f"/frames/t{i}/frustum", diff --git a/examples/08_smpl_visualizer.py b/examples/08_smpl_visualizer.py index 7d89f6b75..8954f6d1a 100644 --- a/examples/08_smpl_visualizer.py +++ b/examples/08_smpl_visualizer.py @@ -13,7 +13,6 @@ from pathlib import Path import numpy as np -import numpy as onp import tyro import viser @@ -33,7 +32,7 @@ class SmplHelper: def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" - body_dict = dict(**onp.load(model_path, allow_pickle=True)) + body_dict = dict(**np.load(model_path, allow_pickle=True)) self._J_regressor = body_dict["J_regressor"] self._weights = body_dict["weights"] @@ -170,7 +169,7 @@ def _(_): @gui_random_shape.on_click def _(_): for beta in gui_betas: - beta.value = onp.random.normal(loc=0.0, scale=1.0) + beta.value = np.random.normal(loc=0.0, scale=1.0) gui_betas = [] for i in range(num_betas): @@ -195,8 +194,8 @@ def _(_): for joint in gui_joints: # It's hard to uniformly sample orientations directly in so(3), so we # first sample on S^3 and then convert. - quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) - quat /= onp.linalg.norm(quat) + quat = np.random.normal(loc=0.0, scale=1.0, size=(4,)) + quat /= np.linalg.norm(quat) joint.value = tf.SO3(wxyz=quat).log() gui_joints: list[viser.GuiInputHandle[tuple[float, float, float]]] = [] diff --git a/examples/09_urdf_visualizer.py b/examples/09_urdf_visualizer.py index e1ceb46d9..cf1b57f81 100644 --- a/examples/09_urdf_visualizer.py +++ b/examples/09_urdf_visualizer.py @@ -13,7 +13,7 @@ import time from typing import Literal -import numpy as onp +import numpy as np import tyro from robot_descriptions.loaders.yourdfpy import load_robot_description @@ -32,8 +32,8 @@ def create_robot_control_sliders( lower, upper, ) in viser_urdf.get_actuated_joint_limits().items(): - lower = lower if lower is not None else -onp.pi - upper = upper if upper is not None else onp.pi + lower = lower if lower is not None else -np.pi + upper = upper if upper is not None else np.pi initial_pos = 0.0 if lower < 0 and upper > 0 else (lower + upper) / 2.0 slider = server.gui.add_slider( label=joint_name, @@ -44,7 +44,7 @@ def create_robot_control_sliders( ) slider.on_update( # When sliders move, we update the URDF configuration. lambda _: viser_urdf.update_cfg( - onp.array([slider.value for slider in slider_handles]) + np.array([slider.value for slider in slider_handles]) ) ) slider_handles.append(slider) @@ -85,7 +85,7 @@ def main( ) # Set initial robot configuration. - viser_urdf.update_cfg(onp.array(initial_config)) + viser_urdf.update_cfg(np.array(initial_config)) # Create joint reset button. reset_button = server.gui.add_button("Reset") diff --git a/examples/11_colmap_visualizer.py b/examples/11_colmap_visualizer.py index b640d65a9..762ca4094 100644 --- a/examples/11_colmap_visualizer.py +++ b/examples/11_colmap_visualizer.py @@ -9,7 +9,7 @@ from typing import List import imageio.v3 as iio -import numpy as onp +import numpy as np import tyro from tqdm.auto import tqdm @@ -50,7 +50,7 @@ def main( def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None - client.camera.up_direction = tf.SO3(client.camera.wxyz) @ onp.array( + client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array( [0.0, -1.0, 0.0] ) @@ -72,10 +72,10 @@ def _(event: viser.GuiEvent) -> None: "Point size", min=0.01, max=0.1, step=0.001, initial_value=0.05 ) - points = onp.array([points3d[p_id].xyz for p_id in points3d]) - colors = onp.array([points3d[p_id].rgb for p_id in points3d]) + points = np.array([points3d[p_id].xyz for p_id in points3d]) + colors = np.array([points3d[p_id].rgb for p_id in points3d]) - point_mask = onp.random.choice(points.shape[0], gui_points.value, replace=False) + point_mask = np.random.choice(points.shape[0], gui_points.value, replace=False) point_cloud = server.scene.add_point_cloud( name="/colmap/pcd", points=points[point_mask], @@ -138,7 +138,7 @@ def _(_) -> None: image = image[::downsample_factor, ::downsample_factor] frustum = server.scene.add_camera_frustum( f"/colmap/frame_{img_id}/frustum", - fov=2 * onp.arctan2(H / 2, fy), + fov=2 * np.arctan2(H / 2, fy), aspect=W / H, scale=0.15, image=image, @@ -149,7 +149,7 @@ def _(_) -> None: @gui_points.on_update def _(_) -> None: - point_mask = onp.random.choice(points.shape[0], gui_points.value, replace=False) + point_mask = np.random.choice(points.shape[0], gui_points.value, replace=False) point_cloud.points = points[point_mask] point_cloud.colors = colors[point_mask] diff --git a/examples/15_gui_in_scene.py b/examples/15_gui_in_scene.py index 339028b44..233b83c94 100644 --- a/examples/15_gui_in_scene.py +++ b/examples/15_gui_in_scene.py @@ -8,7 +8,7 @@ import time from typing import Optional -import numpy as onp +import numpy as np import viser import viser.transforms as tf @@ -25,14 +25,14 @@ def _(client: viser.ClientHandle) -> None: When a frame is clicked, we display a 3D gui node. """ - rng = onp.random.default_rng(0) + rng = np.random.default_rng(0) displayed_3d_container: Optional[viser.Gui3dContainerHandle] = None def make_frame(i: int) -> None: # Sample a random orientation + position. wxyz = rng.normal(size=4) - wxyz /= onp.linalg.norm(wxyz) + wxyz /= np.linalg.norm(wxyz) position = rng.uniform(-3.0, 3.0, size=(3,)) # Create a coordinate frame and label. @@ -62,7 +62,7 @@ def _(_) -> None: ) T_world_target = tf.SE3.from_rotation_and_translation( tf.SO3(frame.wxyz), frame.position - ) @ tf.SE3.from_translation(onp.array([0.0, 0.0, -0.5])) + ) @ tf.SE3.from_translation(np.array([0.0, 0.0, -0.5])) T_current_target = T_world_current.inverse() @ T_world_target @@ -84,7 +84,7 @@ def _(_) -> None: @randomize_orientation.on_click def _(_) -> None: wxyz = rng.normal(size=4) - wxyz /= onp.linalg.norm(wxyz) + wxyz /= np.linalg.norm(wxyz) frame.wxyz = wxyz @close.on_click diff --git a/examples/17_background_composite.py b/examples/17_background_composite.py index 6098f02f3..1750bf80c 100644 --- a/examples/17_background_composite.py +++ b/examples/17_background_composite.py @@ -6,7 +6,7 @@ import time -import numpy as onp +import numpy as np import trimesh import trimesh.creation @@ -15,8 +15,8 @@ server = viser.ViserServer() -img = onp.random.randint(0, 255, size=(1000, 1000, 3), dtype=onp.uint8) -depth = onp.ones((1000, 1000, 1), dtype=onp.float32) +img = np.random.randint(0, 255, size=(1000, 1000, 3), dtype=np.uint8) +depth = np.ones((1000, 1000, 1), dtype=np.float32) # Make a square middle portal. depth[250:750, 250:750, :] = 10.0 diff --git a/examples/18_splines.py b/examples/18_splines.py index a6d42ebcf..8566dbb9c 100644 --- a/examples/18_splines.py +++ b/examples/18_splines.py @@ -5,7 +5,7 @@ import time -import numpy as onp +import numpy as np import viser @@ -13,23 +13,23 @@ def main() -> None: server = viser.ViserServer() for i in range(10): - positions = onp.random.normal(size=(30, 3)) * 3.0 + positions = np.random.normal(size=(30, 3)) * 3.0 server.scene.add_spline_catmull_rom( f"/catmull_{i}", positions, tension=0.5, line_width=3.0, - color=onp.random.uniform(size=3), + color=np.random.uniform(size=3), segments=100, ) - control_points = onp.random.normal(size=(30 * 2 - 2, 3)) * 3.0 + control_points = np.random.normal(size=(30 * 2 - 2, 3)) * 3.0 server.scene.add_spline_cubic_bezier( f"/cubic_bezier_{i}", positions, control_points, line_width=3.0, - color=onp.random.uniform(size=3), + color=np.random.uniform(size=3), segments=100, ) diff --git a/examples/19_get_renders.py b/examples/19_get_renders.py index 2b599c401..890bd03aa 100644 --- a/examples/19_get_renders.py +++ b/examples/19_get_renders.py @@ -5,7 +5,7 @@ import time import imageio.v3 as iio -import numpy as onp +import numpy as np import viser @@ -25,13 +25,13 @@ def _(event: viser.GuiEvent) -> None: images = [] for i in range(20): - positions = onp.random.normal(size=(30, 3)) * 3.0 + positions = np.random.normal(size=(30, 3)) * 3.0 client.scene.add_spline_catmull_rom( f"/catmull_{i}", positions, tension=0.5, line_width=3.0, - color=onp.random.uniform(size=3), + color=np.random.uniform(size=3), ) images.append(client.camera.get_render(height=720, width=1280)) diff --git a/examples/20_scene_pointer.py b/examples/20_scene_pointer.py index e18fac852..17f1b09ed 100644 --- a/examples/20_scene_pointer.py +++ b/examples/20_scene_pointer.py @@ -12,7 +12,7 @@ from pathlib import Path from typing import cast -import numpy as onp +import numpy as np import trimesh import trimesh.creation import trimesh.ray @@ -46,8 +46,8 @@ @server.on_client_connect def _(client: viser.ClientHandle) -> None: # Set up the camera -- this gives a nice view of the full mesh. - client.camera.position = onp.array([0.0, 0.0, -10.0]) - client.camera.wxyz = onp.array([0.0, 0.0, 0.0, 1.0]) + client.camera.position = np.array([0.0, 0.0, -10.0]) + client.camera.wxyz = np.array([0.0, 0.0, 0.0, 1.0]) # Tests "click" scenepointerevent. click_button_handle = client.gui.add_button("Add sphere", icon=viser.Icon.POINTER) @@ -62,8 +62,8 @@ def _(event: viser.ScenePointerEvent) -> None: # Note that mesh is in the mesh frame, so we need to transform the ray. R_world_mesh = tf.SO3(mesh_handle.wxyz) R_mesh_world = R_world_mesh.inverse() - origin = (R_mesh_world @ onp.array(event.ray_origin)).reshape(1, 3) - direction = (R_mesh_world @ onp.array(event.ray_direction)).reshape(1, 3) + origin = (R_mesh_world @ np.array(event.ray_origin)).reshape(1, 3) + direction = (R_mesh_world @ np.array(event.ray_direction)).reshape(1, 3) intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh) hit_pos, _, _ = intersector.intersects_location(origin, direction) @@ -72,7 +72,7 @@ def _(event: viser.ScenePointerEvent) -> None: client.scene.remove_pointer_callback() # Get the first hit position (based on distance from the ray origin). - hit_pos = hit_pos[onp.argmin(onp.sum((hit_pos - origin) ** 2, axis=-1))] + hit_pos = hit_pos[np.argmin(np.sum((hit_pos - origin) ** 2, axis=-1))] # Create a sphere at the hit location. hit_pos_mesh = trimesh.creation.icosphere(radius=0.1) @@ -107,17 +107,17 @@ def _(message: viser.ScenePointerEvent) -> None: R_camera_world = tf.SE3.from_rotation_and_translation( tf.SO3(camera.wxyz), camera.position ).inverse() - vertices = cast(onp.ndarray, mesh.vertices) + vertices = cast(np.ndarray, mesh.vertices) vertices = (R_mesh_world.as_matrix() @ vertices.T).T vertices = ( R_camera_world.as_matrix() - @ onp.hstack([vertices, onp.ones((vertices.shape[0], 1))]).T + @ np.hstack([vertices, np.ones((vertices.shape[0], 1))]).T ).T[:, :3] # Get the camera intrinsics, and project the vertices onto the image plane. fov, aspect = camera.fov, camera.aspect vertices_proj = vertices[:, :2] / vertices[:, 2].reshape(-1, 1) - vertices_proj /= onp.tan(fov / 2) + vertices_proj /= np.tan(fov / 2) vertices_proj[:, 0] /= aspect # Move the origin to the upper-left corner, and scale to [0, 1]. @@ -126,12 +126,12 @@ def _(message: viser.ScenePointerEvent) -> None: # Select the vertices that lie inside the 2D selected box, once projected. mask = ( - (vertices_proj > onp.array(message.screen_pos[0])) - & (vertices_proj < onp.array(message.screen_pos[1])) + (vertices_proj > np.array(message.screen_pos[0])) + & (vertices_proj < np.array(message.screen_pos[1])) ).all(axis=1)[..., None] # Update the mesh color based on whether the vertices are inside the box - mesh.visual.vertex_colors = onp.where( # type: ignore + mesh.visual.vertex_colors = np.where( # type: ignore mask, (0.5, 0.0, 0.7, 1.0), (0.9, 0.9, 0.9, 1.0) ) mesh_handle = server.scene.add_mesh_trimesh( diff --git a/examples/22_games.py b/examples/22_games.py index 1ff2f33ad..59f4556cd 100644 --- a/examples/22_games.py +++ b/examples/22_games.py @@ -5,7 +5,7 @@ import time from typing import Literal -import numpy as onp +import numpy as np import trimesh.creation from typing_extensions import assert_never @@ -42,7 +42,7 @@ def play_connect_4(server: viser.ViserServer) -> None: f"/structure/{row}_{col}", trimesh.creation.annulus(0.45, 0.55, 0.125), position=(0.0, col, row), - wxyz=tf.SO3.from_y_radians(onp.pi / 2.0).wxyz, + wxyz=tf.SO3.from_y_radians(np.pi / 2.0).wxyz, ) # Create a sphere to click on for each column. @@ -70,10 +70,10 @@ def _(_) -> None: f"/game_pieces/{row}_{col}", cylinder.vertices, cylinder.faces, - wxyz=tf.SO3.from_y_radians(onp.pi / 2.0).wxyz, + wxyz=tf.SO3.from_y_radians(np.pi / 2.0).wxyz, color={"red": (255, 0, 0), "yellow": (255, 255, 0)}[whose_turn], ) - for row_anim in onp.linspace(num_rows - 1, row, num_rows - row + 1): + for row_anim in np.linspace(num_rows - 1, row, num_rows - row + 1): piece.position = ( 0, col, @@ -97,12 +97,12 @@ def play_tic_tac_toe(server: viser.ViserServer) -> None: ((-0.5, -1.5, 0), (-0.5, 1.5, 0)), color=(127, 127, 127), position=(1, 1, 0), - wxyz=tf.SO3.from_z_radians(onp.pi / 2 * i).wxyz, + wxyz=tf.SO3.from_z_radians(np.pi / 2 * i).wxyz, ) def draw_symbol(symbol: Literal["x", "o"], i: int, j: int) -> None: """Draw an X or O in the given cell.""" - for scale in onp.linspace(0.01, 1.0, 5): + for scale in np.linspace(0.01, 1.0, 5): if symbol == "x": for k in range(2): server.scene.add_box( @@ -110,9 +110,7 @@ def draw_symbol(symbol: Literal["x", "o"], i: int, j: int) -> None: dimensions=(0.7 * scale, 0.125 * scale, 0.125), position=(i, j, 0), color=(0, 0, 255), - wxyz=tf.SO3.from_z_radians( - onp.pi / 2.0 * k + onp.pi / 4.0 - ).wxyz, + wxyz=tf.SO3.from_z_radians(np.pi / 2.0 * k + np.pi / 4.0).wxyz, ) elif symbol == "o": mesh = trimesh.creation.annulus(0.25 * scale, 0.35 * scale, 0.125) diff --git a/examples/23_plotly.py b/examples/23_plotly.py index 4ff284017..cff80d1a6 100644 --- a/examples/23_plotly.py +++ b/examples/23_plotly.py @@ -4,7 +4,7 @@ import time -import numpy as onp +import numpy as np import plotly.express as px import plotly.graph_objects as go from PIL import Image @@ -14,8 +14,8 @@ def create_sinusoidal_wave(t: float) -> go.Figure: """Create a sinusoidal wave plot, starting at time t.""" - x_data = onp.linspace(t, t + 6 * onp.pi, 50) - y_data = onp.sin(x_data) * 10 + x_data = np.linspace(t, t + 6 * np.pi, 50) + y_data = np.sin(x_data) * 10 fig = px.line( x=list(x_data), diff --git a/examples/25_smpl_visualizer_skinned.py b/examples/25_smpl_visualizer_skinned.py index d7b5b4e43..c6c8d6d19 100644 --- a/examples/25_smpl_visualizer_skinned.py +++ b/examples/25_smpl_visualizer_skinned.py @@ -19,7 +19,6 @@ from typing import List, Tuple import numpy as np -import numpy as onp import tyro import viser @@ -39,7 +38,7 @@ class SmplHelper: def __init__(self, model_path: Path) -> None: assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!" - body_dict = dict(**onp.load(model_path, allow_pickle=True)) + body_dict = dict(**np.load(model_path, allow_pickle=True)) self._J_regressor = body_dict["J_regressor"] self._weights = body_dict["weights"] @@ -95,7 +94,7 @@ def main(model_path: Path) -> None: ) smpl_outputs = model.get_outputs( betas=np.array([x.value for x in gui_elements.gui_betas]), - joint_rotmats=onp.zeros((model.num_joints, 3, 3)) + onp.eye(3), + joint_rotmats=np.zeros((model.num_joints, 3, 3)) + np.eye(3), ) bone_wxyzs = np.array( @@ -213,7 +212,7 @@ def _(_): @gui_random_shape.on_click def _(_): for beta in gui_betas: - beta.value = onp.random.normal(loc=0.0, scale=1.0) + beta.value = np.random.normal(loc=0.0, scale=1.0) gui_betas = [] for i in range(num_betas): @@ -238,8 +237,8 @@ def _(_): for joint in gui_joints: # It's hard to uniformly sample orientations directly in so(3), so we # first sample on S^3 and then convert. - quat = onp.random.normal(loc=0.0, scale=1.0, size=(4,)) - quat /= onp.linalg.norm(quat) + quat = np.random.normal(loc=0.0, scale=1.0, size=(4,)) + quat /= np.linalg.norm(quat) joint.value = tf.SO3(wxyz=quat).log() gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] diff --git a/examples/26_lighting.py b/examples/26_lighting.py index e2080b4ab..642b119a2 100644 --- a/examples/26_lighting.py +++ b/examples/26_lighting.py @@ -6,7 +6,7 @@ import time from pathlib import Path -import numpy as onp +import numpy as np import trimesh import viser @@ -29,13 +29,13 @@ def main() -> None: name="/simple", vertices=vertices, faces=faces, - wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, + wxyz=tf.SO3.from_x_radians(np.pi / 2).wxyz, position=(0.0, 0.0, 0.0), ) server.scene.add_mesh_trimesh( name="/trimesh", mesh=mesh, - wxyz=tf.SO3.from_x_radians(onp.pi / 2).wxyz, + wxyz=tf.SO3.from_x_radians(np.pi / 2).wxyz, position=(0.0, 5.0, 0.0), ) diff --git a/examples/experimental/gaussian_splats.py b/examples/experimental/gaussian_splats.py index 461e2fab7..cb14d9812 100644 --- a/examples/experimental/gaussian_splats.py +++ b/examples/experimental/gaussian_splats.py @@ -6,8 +6,8 @@ from pathlib import Path from typing import TypedDict -import numpy as onp -import numpy.typing as onpt +import numpy as np +import numpy.typing as npt import tyro from plyfile import PlyData @@ -18,13 +18,13 @@ class SplatFile(TypedDict): """Data loaded from an antimatter15-style splat file.""" - centers: onpt.NDArray[onp.floating] + centers: npt.NDArray[np.floating] """(N, 3).""" - rgbs: onpt.NDArray[onp.floating] + rgbs: npt.NDArray[np.floating] """(N, 3). Range [0, 1].""" - opacities: onpt.NDArray[onp.floating] + opacities: npt.NDArray[np.floating] """(N, 1). Range [0, 1].""" - covariances: onpt.NDArray[onp.floating] + covariances: npt.NDArray[np.floating] """(N, 3, 3).""" @@ -47,18 +47,18 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: num_gaussians = len(splat_buffer) // bytes_per_gaussian # Reinterpret cast to dtypes that we want to extract. - splat_uint8 = onp.frombuffer(splat_buffer, dtype=onp.uint8).reshape( + splat_uint8 = np.frombuffer(splat_buffer, dtype=np.uint8).reshape( (num_gaussians, bytes_per_gaussian) ) - scales = splat_uint8[:, 12:24].copy().view(onp.float32) + scales = splat_uint8[:, 12:24].copy().view(np.float32) wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0 Rs = tf.SO3(wxyzs).as_matrix() - covariances = onp.einsum( - "nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs + covariances = np.einsum( + "nij,njk,nlk->nil", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs ) - centers = splat_uint8[:, 0:12].copy().view(onp.float32) + centers = splat_uint8[:, 0:12].copy().view(np.float32) if center: - centers -= onp.mean(centers, axis=0, keepdims=True) + centers -= np.mean(centers, axis=0, keepdims=True) print( f"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds" ) @@ -80,18 +80,18 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: plydata = PlyData.read(ply_file_path) v = plydata["vertex"] - positions = onp.stack([v["x"], v["y"], v["z"]], axis=-1) - scales = onp.exp(onp.stack([v["scale_0"], v["scale_1"], v["scale_2"]], axis=-1)) - wxyzs = onp.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1) - colors = 0.5 + SH_C0 * onp.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) - opacities = 1.0 / (1.0 + onp.exp(-v["opacity"][:, None])) + positions = np.stack([v["x"], v["y"], v["z"]], axis=-1) + scales = np.exp(np.stack([v["scale_0"], v["scale_1"], v["scale_2"]], axis=-1)) + wxyzs = np.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1) + colors = 0.5 + SH_C0 * np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) + opacities = 1.0 / (1.0 + np.exp(-v["opacity"][:, None])) Rs = tf.SO3(wxyzs).as_matrix() - covariances = onp.einsum( - "nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs + covariances = np.einsum( + "nij,njk,nlk->nil", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs ) if center: - positions -= onp.mean(positions, axis=0, keepdims=True) + positions -= np.mean(positions, axis=0, keepdims=True) num_gaussians = len(v) print( @@ -117,7 +117,7 @@ def main(splat_paths: tuple[Path, ...]) -> None: def _(event: viser.GuiEvent) -> None: client = event.client assert client is not None - client.camera.up_direction = tf.SO3(client.camera.wxyz) @ onp.array( + client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array( [0.0, -1.0, 0.0] ) diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 6c3a36186..461b52f07 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -19,7 +19,7 @@ overload, ) -import numpy as onp +import numpy as np from typing_extensions import ( Literal, LiteralString, @@ -420,10 +420,10 @@ def configure_theme( primary_index = 8 ls = tuple( - onp.interp( - x=onp.arange(10), - xp=onp.array([0, primary_index, 9]), - fp=onp.array([max_l, l, min_l]), + np.interp( + x=np.arange(10), + xp=np.array([0, primary_index, 9]), + fp=np.array([max_l, l, min_l]), ) ) colors_cast = cast( @@ -967,7 +967,7 @@ def add_number( # It's ok that `step` is always a float, even if the value is an integer, # because things all become `number` types after serialization. step = float( # type: ignore - onp.min( + np.min( [ _compute_step(value), _compute_step(min), @@ -1006,9 +1006,9 @@ def add_number( def add_vector2( self, label: str, - initial_value: tuple[float, float] | onp.ndarray, - min: tuple[float, float] | onp.ndarray | None = None, - max: tuple[float, float] | onp.ndarray | None = None, + initial_value: tuple[float, float] | np.ndarray, + min: tuple[float, float] | np.ndarray | None = None, + max: tuple[float, float] | np.ndarray | None = None, step: float | None = None, disabled: bool = False, visible: bool = True, @@ -1045,7 +1045,7 @@ def add_vector2( possible_steps.extend([_compute_step(x) for x in min]) if max is not None: possible_steps.extend([_compute_step(x) for x in max]) - step = float(onp.min(possible_steps)) + step = float(np.min(possible_steps)) return GuiVector2Handle( self._create_gui_input( @@ -1072,9 +1072,9 @@ def add_vector2( def add_vector3( self, label: str, - initial_value: tuple[float, float, float] | onp.ndarray, - min: tuple[float, float, float] | onp.ndarray | None = None, - max: tuple[float, float, float] | onp.ndarray | None = None, + initial_value: tuple[float, float, float] | np.ndarray, + min: tuple[float, float, float] | np.ndarray | None = None, + max: tuple[float, float, float] | np.ndarray | None = None, step: float | None = None, disabled: bool = False, visible: bool = True, @@ -1111,7 +1111,7 @@ def add_vector3( possible_steps.extend([_compute_step(x) for x in min]) if max is not None: possible_steps.extend([_compute_step(x) for x in max]) - step = float(onp.min(possible_steps)) + step = float(np.min(possible_steps)) return GuiVector3Handle( self._create_gui_input( diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 41904278a..dac756918 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -22,7 +22,7 @@ ) import imageio.v3 as iio -import numpy as onp +import numpy as np from typing_extensions import Protocol from . import _messages @@ -183,8 +183,8 @@ def value(self) -> T: return self._impl.value @value.setter - def value(self, value: T | onp.ndarray) -> None: - if isinstance(value, onp.ndarray): + def value(self, value: T | np.ndarray) -> None: + if isinstance(value, np.ndarray): assert len(value.shape) <= 1, f"{value.shape} should be at most 1D!" value = tuple(map(float, value)) # type: ignore diff --git a/src/viser/_messages.py b/src/viser/_messages.py index bd3eb951c..0057c7f07 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -7,8 +7,8 @@ import uuid from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union -import numpy as onp -import numpy.typing as onpt +import numpy as np +import numpy.typing as npt from typing_extensions import Literal, override from . import infra, theme @@ -242,9 +242,9 @@ class BatchedAxesMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class BatchedAxesProps: - wxyzs_batched: onpt.NDArray[onp.float32] + wxyzs_batched: npt.NDArray[np.float32] """Float array of shape (N,4) representing quaternion rotations. Synchronized automatically when assigned.""" - positions_batched: onpt.NDArray[onp.float32] + positions_batched: npt.NDArray[np.float32] """Float array of shape (N,3) representing positions. Synchronized automatically when assigned.""" axes_length: float """Length of each axis. Synchronized automatically when assigned.""" @@ -331,9 +331,9 @@ class PointCloudMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class PointCloudProps: - points: onpt.NDArray[onp.float32] + points: npt.NDArray[np.float32] """Location of points. Should have shape (N, 3). Synchronized automatically when assigned.""" - colors: onpt.NDArray[onp.uint8] + colors: npt.NDArray[np.uint8] """Colors of points. Should have shape (N, 3) or (3,). Synchronized automatically when assigned.""" point_size: float """Size of each point. Synchronized automatically when assigned.""" @@ -346,8 +346,8 @@ def __post_init__(self): assert self.points.shape[-1] == 3 # Check dtypes. - assert self.points.dtype == onp.float32 - assert self.colors.dtype == onp.uint8 + assert self.points.dtype == np.float32 + assert self.colors.dtype == np.uint8 @dataclasses.dataclass @@ -465,7 +465,7 @@ class SpotLightProps: """Decay of the spot light. Synchronized automatically when assigned.""" def __post_init__(self): - assert self.angle <= onp.pi / 2 + assert self.angle <= np.pi / 2 assert self.angle >= 0 @@ -514,13 +514,13 @@ class MeshMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class MeshProps: - vertices: onpt.NDArray[onp.float32] + vertices: npt.NDArray[np.float32] """A numpy array of vertex positions. Should have shape (V, 3). Synchronized automatically when assigned.""" - faces: onpt.NDArray[onp.uint32] + faces: npt.NDArray[np.uint32] """A numpy array of faces, where each face is represented by indices of vertices. Should have shape (F, 3). Synchronized automatically when assigned.""" color: Optional[Tuple[int, int, int]] """Color of the mesh as RGB integers. Synchronized automatically when assigned.""" - vertex_colors: Optional[onpt.NDArray[onp.uint8]] + vertex_colors: Optional[npt.NDArray[np.uint8]] """Optional array of vertex colors. Synchronized automatically when assigned.""" wireframe: bool """Boolean indicating if the mesh should be rendered as a wireframe. Synchronized automatically when assigned.""" @@ -557,9 +557,9 @@ class SkinnedMeshProps(MeshProps): """Tuple of quaternions representing bone orientations. Synchronized automatically when assigned.""" bone_positions: Tuple[Tuple[float, float, float], ...] """Tuple of positions representing bone positions. Synchronized automatically when assigned.""" - skin_indices: onpt.NDArray[onp.uint16] + skin_indices: npt.NDArray[np.uint16] """Array of skin indices. Should have shape (V, 4). Synchronized automatically when assigned.""" - skin_weights: onpt.NDArray[onp.float32] + skin_weights: npt.NDArray[np.float32] """Array of skin weights. Should have shape (V, 4). Synchronized automatically when assigned.""" def __post_init__(self): @@ -1241,7 +1241,7 @@ class GaussianSplatsMessage(Message, tag="SceneNodeMessage"): class GaussianSplatsProps: # Memory layout is borrowed from: # https://github.com/antimatter15/splat - buffer: onpt.NDArray[onp.uint32] + buffer: npt.NDArray[np.uint32] """Our buffer will contain: - x as f32 - y as f32 diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 0cf6a78ac..f2acef204 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Callable, Tuple, TypeVar, Union, cast, get_args import imageio.v3 as iio -import numpy as onp +import numpy as np from typing_extensions import Literal, ParamSpec, TypeAlias, assert_never from . import _messages @@ -54,15 +54,15 @@ RgbTupleOrArray: TypeAlias = Union[ - Tuple[int, int, int], Tuple[float, float, float], onp.ndarray + Tuple[int, int, int], Tuple[float, float, float], np.ndarray ] def _encode_rgb(rgb: RgbTupleOrArray) -> tuple[int, int, int]: - if isinstance(rgb, onp.ndarray): + if isinstance(rgb, np.ndarray): assert rgb.shape == (3,) rgb_fixed = tuple( - int(value) if onp.issubdtype(type(value), onp.integer) else int(value * 255) + int(value) if np.issubdtype(type(value), np.integer) else int(value * 255) for value in rgb ) assert len(rgb_fixed) == 3 @@ -70,7 +70,7 @@ def _encode_rgb(rgb: RgbTupleOrArray) -> tuple[int, int, int]: def _encode_image_binary( - image: onp.ndarray, + image: np.ndarray, format: Literal["png", "jpeg"], jpeg_quality: int | None = None, ) -> tuple[Literal["image/png", "image/jpeg"], bytes]: @@ -97,9 +97,9 @@ def _encode_image_binary( TVector = TypeVar("TVector", bound=tuple) -def cast_vector(vector: TVector | onp.ndarray, length: int) -> TVector: +def cast_vector(vector: TVector | np.ndarray, length: int) -> TVector: if not isinstance(vector, tuple): - assert cast(onp.ndarray, vector).shape == ( + assert cast(np.ndarray, vector).shape == ( length, ), f"Expected vector of shape {(length,)}, but got {vector.shape} instead" return cast(TVector, tuple(map(float, vector))) @@ -165,7 +165,7 @@ def set_up_direction( self, direction: Literal["+x", "+y", "+z", "-x", "-y", "-z"] | tuple[float, float, float] - | onp.ndarray, + | np.ndarray, ) -> None: """Set the global up direction of the scene. By default we follow +Z-up (similar to Blender, 3DS Max, ROS, etc), the most common alternative is @@ -186,20 +186,20 @@ def set_up_direction( }[direction] assert not isinstance(direction, str) - default_three_up = onp.array([0.0, 1.0, 0.0]) - direction = onp.asarray(direction) + default_three_up = np.array([0.0, 1.0, 0.0]) + direction = np.asarray(direction) - def rotate_between(before: onp.ndarray, after: onp.ndarray) -> tf.SO3: + def rotate_between(before: np.ndarray, after: np.ndarray) -> tf.SO3: assert before.shape == after.shape == (3,) - before = before / onp.linalg.norm(before) - after = after / onp.linalg.norm(after) - - angle = onp.arccos(onp.clip(onp.dot(before, after), -1, 1)) - axis = onp.cross(before, after) - if onp.allclose(axis, onp.zeros(3), rtol=1e-3, atol=1e-5): - unit_vector = onp.arange(3) == onp.argmin(onp.abs(before)) - axis = onp.cross(before, unit_vector) - axis = axis / onp.linalg.norm(axis) + before = before / np.linalg.norm(before) + after = after / np.linalg.norm(after) + + angle = np.arccos(np.clip(np.dot(before, after), -1, 1)) + axis = np.cross(before, after) + if np.allclose(axis, np.zeros(3), rtol=1e-3, atol=1e-5): + unit_vector = np.arange(3) == np.argmin(np.abs(before)) + axis = np.cross(before, unit_vector) + axis = axis / np.linalg.norm(axis) return tf.SO3.exp(angle * axis) R_threeworld_world = rotate_between(direction, default_three_up) @@ -209,21 +209,21 @@ def rotate_between(before: onp.ndarray, after: onp.ndarray) -> tf.SO3: # If we set +Z to up, +X and +Y should face the camera. # In App.tsx, the camera is initialized at [-3, 3, -3] in the threejs # coordinate frame. - desired_fwd = onp.array([-1.0, 0.0, -1.0]) / onp.sqrt(2.0) - current_fwd = R_threeworld_world @ (onp.ones(3) / onp.sqrt(3.0)) - current_fwd = current_fwd * onp.array([1.0, 0.0, 1.0]) - current_fwd = current_fwd / onp.linalg.norm(current_fwd) + desired_fwd = np.array([-1.0, 0.0, -1.0]) / np.sqrt(2.0) + current_fwd = R_threeworld_world @ (np.ones(3) / np.sqrt(3.0)) + current_fwd = current_fwd * np.array([1.0, 0.0, 1.0]) + current_fwd = current_fwd / np.linalg.norm(current_fwd) R_threeworld_world = ( tf.SO3.from_y_radians( # Rotate around the null space / up direction. - onp.arctan2( - onp.cross(current_fwd, desired_fwd)[1], - onp.dot(current_fwd, desired_fwd), + np.arctan2( + np.cross(current_fwd, desired_fwd)[1], + np.dot(current_fwd, desired_fwd), ), ) @ R_threeworld_world ) - if not onp.any(onp.isnan(R_threeworld_world.wxyz)): + if not np.any(np.isnan(R_threeworld_world.wxyz)): # Set the orientation of the root node. self._websock_interface.queue_message( _messages.SetOrientationMessage( @@ -251,7 +251,7 @@ def add_light_directional( name: str, color: Tuple[int, int, int] = (255, 255, 255), intensity: float = 1.0, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] = (0.0, 0.0, 0.0), visible: bool = True, ) -> DirectionalLightHandle: @@ -283,8 +283,8 @@ def add_light_ambient( name: str, color: Tuple[int, int, int] = (255, 255, 255), intensity: float = 1.0, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> AmbientLightHandle: """ @@ -314,7 +314,7 @@ def add_light_hemisphere( sky_color: Tuple[int, int, int] = (255, 255, 255), ground_color: Tuple[int, int, int] = (255, 255, 255), intensity: float = 1.0, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] = (0.0, 0.0, 0.0), visible: bool = True, ) -> HemisphereLightHandle: @@ -347,7 +347,7 @@ def add_light_point( intensity: float = 1.0, distance: float = 0.0, decay: float = 2.0, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] = (0.0, 0.0, 0.0), visible: bool = True, ) -> PointLightHandle: @@ -387,7 +387,7 @@ def add_light_rectarea( intensity: float = 1.0, width: float = 10.0, height: float = 10.0, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] = (0.0, 0.0, 0.0), visible: bool = True, ) -> RectAreaLightHandle: @@ -425,11 +425,11 @@ def add_light_spot( name: str, color: Tuple[int, int, int] = (255, 255, 255), distance: float = 0.0, - angle: float = onp.pi / 3, + angle: float = np.pi / 3, penumbra: float = 0.0, decay: float = 2.0, intensity: float = 1.0, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), position: tuple[float, float, float] = (0.0, 0.0, 0.0), visible: bool = True, ) -> SpotLightHandle: @@ -526,8 +526,8 @@ def add_glb( name: str, glb_data: bytes, scale: float = 1.0, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> GlbHandle: """Add a general 3D asset via binary glTF (GLB). @@ -556,15 +556,15 @@ def add_glb( def add_spline_catmull_rom( self, name: str, - positions: tuple[tuple[float, float, float], ...] | onp.ndarray, + positions: tuple[tuple[float, float, float], ...] | np.ndarray, curve_type: Literal["centripetal", "chordal", "catmullrom"] = "centripetal", tension: float = 0.5, closed: bool = False, line_width: float = 1, color: RgbTupleOrArray = (20, 20, 20), segments: int | None = None, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> SplineCatmullRomHandle: """Add a spline to the scene using Catmull-Rom interpolation. @@ -589,7 +589,7 @@ def add_spline_catmull_rom( Returns: Handle for manipulating scene node. """ - if isinstance(positions, onp.ndarray): + if isinstance(positions, np.ndarray): assert len(positions.shape) == 2 and positions.shape[1] == 3 positions = tuple(map(tuple, positions)) # type: ignore assert len(positions[0]) == 3 @@ -613,13 +613,13 @@ def add_spline_catmull_rom( def add_spline_cubic_bezier( self, name: str, - positions: tuple[tuple[float, float, float], ...] | onp.ndarray, - control_points: tuple[tuple[float, float, float], ...] | onp.ndarray, + positions: tuple[tuple[float, float, float], ...] | np.ndarray, + control_points: tuple[tuple[float, float, float], ...] | np.ndarray, line_width: float = 1.0, color: RgbTupleOrArray = (20, 20, 20), segments: int | None = None, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> SplineCubicBezierHandle: """Add a spline to the scene using Cubic Bezier interpolation. @@ -644,10 +644,10 @@ def add_spline_cubic_bezier( Handle for manipulating scene node. """ - if isinstance(positions, onp.ndarray): + if isinstance(positions, np.ndarray): assert len(positions.shape) == 2 and positions.shape[1] == 3 positions = tuple(map(tuple, positions)) # type: ignore - if isinstance(control_points, onp.ndarray): + if isinstance(control_points, np.ndarray): assert len(control_points.shape) == 2 and control_points.shape[1] == 3 control_points = tuple(map(tuple, control_points)) # type: ignore @@ -675,11 +675,11 @@ def add_camera_frustum( aspect: float, scale: float = 0.3, color: RgbTupleOrArray = (20, 20, 20), - image: onp.ndarray | None = None, + image: np.ndarray | None = None, format: Literal["png", "jpeg"] = "jpeg", jpeg_quality: int | None = None, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> CameraFrustumHandle: """Add a camera frustum to the scene for visualization. @@ -737,8 +737,8 @@ def add_frame( axes_length: float = 0.5, axes_radius: float = 0.025, origin_radius: float | None = None, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> FrameHandle: """Add a coordinate frame to the scene. @@ -782,12 +782,12 @@ def add_frame( def add_batched_axes( self, name: str, - batched_wxyzs: tuple[tuple[float, float, float, float], ...] | onp.ndarray, - batched_positions: tuple[tuple[float, float, float], ...] | onp.ndarray, + batched_wxyzs: tuple[tuple[float, float, float, float], ...] | np.ndarray, + batched_positions: tuple[tuple[float, float, float], ...] | np.ndarray, axes_length: float = 0.5, axes_radius: float = 0.025, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> BatchedAxesHandle: """Visualize batched sets of coordinate frame axes. @@ -818,15 +818,15 @@ def add_batched_axes( Returns: Handle for manipulating scene node. """ - batched_wxyzs = onp.asarray(batched_wxyzs) - batched_positions = onp.asarray(batched_positions) + batched_wxyzs = np.asarray(batched_wxyzs) + batched_positions = np.asarray(batched_positions) num_axes = batched_wxyzs.shape[0] assert batched_wxyzs.shape == (num_axes, 4) assert batched_positions.shape == (num_axes, 3) props = _messages.BatchedAxesProps( - wxyzs_batched=batched_wxyzs.astype(onp.float32), - positions_batched=batched_positions.astype(onp.float32), + wxyzs_batched=batched_wxyzs.astype(np.float32), + positions_batched=batched_positions.astype(np.float32), axes_length=axes_length, axes_radius=axes_radius, ) @@ -850,8 +850,8 @@ def add_grid( section_color: RgbTupleOrArray = (140, 140, 140), section_thickness: float = 1.0, section_size: float = 1.0, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> GridHandle: """Add a 2D grid to the scene. @@ -900,8 +900,8 @@ def add_label( self, name: str, text: str, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> LabelHandle: """Add a 2D label to the scene. @@ -925,14 +925,14 @@ def add_label( def add_point_cloud( self, name: str, - points: onp.ndarray, - colors: onp.ndarray | tuple[float, float, float], + points: np.ndarray, + colors: np.ndarray | tuple[float, float, float], point_size: float = 0.1, point_shape: Literal[ "square", "diamond", "circle", "rounded", "sparkle" ] = "square", - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> PointCloudHandle: """Add a point cloud to the scene. @@ -950,7 +950,7 @@ def add_point_cloud( Returns: Handle for manipulating scene node. """ - colors_cast = colors_to_uint8(onp.asarray(colors)) + colors_cast = colors_to_uint8(np.asarray(colors)) assert ( len(points.shape) == 2 and points.shape[-1] == 3 ), "Shape of points should be (N, 3)." @@ -960,12 +960,12 @@ def add_point_cloud( }, "Shape of colors should be (N, 3) or (3,)." if colors_cast.shape == (3,): - colors_cast = onp.tile(colors_cast[None, :], reps=(points.shape[0], 1)) + colors_cast = np.tile(colors_cast[None, :], reps=(points.shape[0], 1)) message = _messages.PointCloudMessage( name=name, props=_messages.PointCloudProps( - points=points.astype(onp.float32), + points=points.astype(np.float32), colors=colors_cast, point_size=point_size, point_ball_norm={ @@ -982,19 +982,19 @@ def add_point_cloud( def add_mesh_skinned( self, name: str, - vertices: onp.ndarray, - faces: onp.ndarray, - bone_wxyzs: tuple[tuple[float, float, float, float], ...] | onp.ndarray, - bone_positions: tuple[tuple[float, float, float], ...] | onp.ndarray, - skin_weights: onp.ndarray, + vertices: np.ndarray, + faces: np.ndarray, + bone_wxyzs: tuple[tuple[float, float, float, float], ...] | np.ndarray, + bone_positions: tuple[tuple[float, float, float], ...] | np.ndarray, + skin_weights: np.ndarray, color: RgbTupleOrArray = (90, 200, 255), wireframe: bool = False, opacity: float | None = None, material: Literal["standard", "toon3", "toon5"] = "standard", flat_shading: bool = False, side: Literal["front", "back", "double"] = "front", - wxyz: Tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: Tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: Tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> MeshSkinnedHandle: """Add a skinned mesh to the scene, which we can deform using a set of @@ -1041,23 +1041,23 @@ def add_mesh_skinned( assert skin_weights.shape == (vertices.shape[0], num_bones) # Take the four biggest indices. - top4_skin_indices = onp.argsort(skin_weights, axis=-1)[:, -4:] + top4_skin_indices = np.argsort(skin_weights, axis=-1)[:, -4:] top4_skin_weights = skin_weights[ - onp.arange(vertices.shape[0])[:, None], top4_skin_indices + np.arange(vertices.shape[0])[:, None], top4_skin_indices ] assert ( top4_skin_weights.shape == top4_skin_indices.shape == (vertices.shape[0], 4) ) - bone_wxyzs = onp.asarray(bone_wxyzs) - bone_positions = onp.asarray(bone_positions) + bone_wxyzs = np.asarray(bone_wxyzs) + bone_positions = np.asarray(bone_positions) assert bone_wxyzs.shape == (num_bones, 4) assert bone_positions.shape == (num_bones, 3) message = _messages.SkinnedMeshMessage( name=name, props=_messages.SkinnedMeshProps( - vertices=vertices.astype(onp.float32), - faces=faces.astype(onp.uint32), + vertices=vertices.astype(np.float32), + faces=faces.astype(np.uint32), color=_encode_rgb(color), vertex_colors=None, wireframe=wireframe, @@ -1072,14 +1072,14 @@ def add_mesh_skinned( float(wxyz[2]), float(wxyz[3]), ) - for wxyz in bone_wxyzs.astype(onp.float32) + for wxyz in bone_wxyzs.astype(np.float32) ), bone_positions=tuple( (float(xyz[0]), float(xyz[1]), float(xyz[2])) - for xyz in bone_positions.astype(onp.float32) + for xyz in bone_positions.astype(np.float32) ), - skin_indices=top4_skin_indices.astype(onp.uint16), - skin_weights=top4_skin_weights.astype(onp.float32), + skin_indices=top4_skin_indices.astype(np.uint16), + skin_weights=top4_skin_weights.astype(np.float32), ), ) handle = MeshHandle._make(self, message, name, wxyz, position, visible) @@ -1102,16 +1102,16 @@ def add_mesh_skinned( def add_mesh_simple( self, name: str, - vertices: onp.ndarray, - faces: onp.ndarray, + vertices: np.ndarray, + faces: np.ndarray, color: RgbTupleOrArray = (90, 200, 255), wireframe: bool = False, opacity: float | None = None, material: Literal["standard", "toon3", "toon5"] = "standard", flat_shading: bool = False, side: Literal["front", "back", "double"] = "front", - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> MeshHandle: """Add a mesh to the scene. @@ -1150,8 +1150,8 @@ def add_mesh_simple( message = _messages.MeshMessage( name=name, props=_messages.MeshProps( - vertices=vertices.astype(onp.float32), - faces=faces.astype(onp.uint32), + vertices=vertices.astype(np.float32), + faces=faces.astype(np.uint32), color=_encode_rgb(color), vertex_colors=None, wireframe=wireframe, @@ -1168,8 +1168,8 @@ def add_mesh_trimesh( name: str, mesh: trimesh.Trimesh, scale: float = 1.0, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> GlbHandle: """Add a trimesh mesh to the scene. Internally calls `self.add_glb()`. @@ -1202,12 +1202,12 @@ def add_mesh_trimesh( def _add_gaussian_splats( self, name: str, - centers: onp.ndarray, - covariances: onp.ndarray, - rgbs: onp.ndarray, - opacities: onp.ndarray, - wxyz: Tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + centers: np.ndarray, + covariances: np.ndarray, + rgbs: np.ndarray, + opacities: np.ndarray, + wxyz: Tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: Tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> GaussianSplatHandle: """Add a model to render using Gaussian Splatting. @@ -1237,26 +1237,26 @@ def _add_gaussian_splats( # Get cholesky factor of covariance. This helps retain precision when # we convert to float16. cov_cholesky_triu = ( - onp.linalg.cholesky(covariances.astype(onp.float64) + onp.ones(3) * 1e-7) + np.linalg.cholesky(covariances.astype(np.float64) + np.ones(3) * 1e-7) .swapaxes(-1, -2) # tril => triu - .reshape((-1, 9))[:, onp.array([0, 1, 2, 4, 5, 8])] + .reshape((-1, 9))[:, np.array([0, 1, 2, 4, 5, 8])] ) - buffer = onp.concatenate( + buffer = np.concatenate( [ # First texelFetch. # - xyz (96 bits): centers. - centers.astype(onp.float32).view(onp.uint8), + centers.astype(np.float32).view(np.uint8), # - w (32 bits): this is reserved for use by the renderer. - onp.zeros((num_gaussians, 4), dtype=onp.uint8), + np.zeros((num_gaussians, 4), dtype=np.uint8), # Second texelFetch. # - xyz (96 bits): upper-triangular Cholesky factor of covariance. - cov_cholesky_triu.astype(onp.float16).copy().view(onp.uint8), + cov_cholesky_triu.astype(np.float16).copy().view(np.uint8), # - w (32 bits): rgba. colors_to_uint8(rgbs), colors_to_uint8(opacities), ], axis=-1, - ).view(onp.uint32) + ).view(np.uint32) assert buffer.shape == (num_gaussians, 8) message = _messages.GaussianSplatsMessage( @@ -1274,9 +1274,9 @@ def add_box( self, name: str, color: RgbTupleOrArray, - dimensions: tuple[float, float, float] | onp.ndarray = (1.0, 1.0, 1.0), - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + dimensions: tuple[float, float, float] | np.ndarray = (1.0, 1.0, 1.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> MeshHandle: """Add a box to the scene. @@ -1314,8 +1314,8 @@ def add_icosphere( radius: float, color: RgbTupleOrArray, subdivisions: int = 3, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> MeshHandle: """Add an icosphere to the scene. @@ -1352,10 +1352,10 @@ def add_icosphere( def set_background_image( self, - image: onp.ndarray, + image: np.ndarray, format: Literal["png", "jpeg"] = "jpeg", jpeg_quality: int | None = None, - depth: onp.ndarray | None = None, + depth: np.ndarray | None = None, ) -> None: """Set a background image for the scene, optionally with depth compositing. @@ -1380,9 +1380,9 @@ def set_background_image( assert len(depth.shape) == 2 or ( len(depth.shape) == 3 and depth.shape[2] == 1 ), "Depth should have shape (H,W) or (H,W,1)." - depth = onp.clip(depth * 100_000, 0, 2**24 - 1).astype(onp.uint32) + depth = np.clip(depth * 100_000, 0, 2**24 - 1).astype(np.uint32) assert depth is not None # Appease mypy. - intdepth: onp.ndarray = depth.reshape((*depth.shape[:2], 1)).view(onp.uint8) + intdepth: np.ndarray = depth.reshape((*depth.shape[:2], 1)).view(np.uint8) assert intdepth.shape == (*depth.shape[:2], 4) with io.BytesIO() as data_buffer: iio.imwrite(data_buffer, intdepth[:, :, :3], extension=".png") @@ -1399,13 +1399,13 @@ def set_background_image( def add_image( self, name: str, - image: onp.ndarray, + image: np.ndarray, render_width: float, render_height: float, format: Literal["png", "jpeg"] = "jpeg", jpeg_quality: int | None = None, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> ImageHandle: """Add a 2D image to the scene. @@ -1458,8 +1458,8 @@ def add_transform_controls( ] = ((-1000.0, 1000.0), (-1000.0, 1000.0), (-1000.0, 1000.0)), depth_test: bool = True, opacity: float = 1.0, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> TransformControlsHandle: """Add a transform gizmo for interacting with the scene. @@ -1565,8 +1565,8 @@ def _handle_transform_controls_updates( return # Update state. - wxyz = onp.array(message.wxyz) - position = onp.array(message.position) + wxyz = np.array(message.wxyz) + position = np.array(message.position) with self._owner.atomic(): handle._impl.wxyz = wxyz handle._impl.position = position @@ -1714,8 +1714,8 @@ def remove_pointer_callback( def add_3d_gui_container( self, name: str, - wxyz: tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), - position: tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), + wxyz: tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), + position: tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> Gui3dContainerHandle: """Add a 3D gui container to the scene. The returned container handle can be diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index f5ba05612..87c8b990b 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -15,7 +15,7 @@ cast, ) -import numpy as onp +import numpy as np import numpy.typing as onpt from typing_extensions import get_type_hints @@ -29,14 +29,14 @@ from .infra import ClientId -def colors_to_uint8(colors: onp.ndarray) -> onpt.NDArray[onp.uint8]: +def colors_to_uint8(colors: np.ndarray) -> onpt.NDArray[np.uint8]: """Convert intensity values to uint8. We assume the range [0,1] for floats, and [0,255] for integers. Accepts any shape.""" - if colors.dtype != onp.uint8: - if onp.issubdtype(colors.dtype, onp.floating): - colors = onp.clip(colors * 255.0, 0, 255).astype(onp.uint8) - if onp.issubdtype(colors.dtype, onp.integer): - colors = onp.clip(colors, 0, 255).astype(onp.uint8) + if colors.dtype != np.uint8: + if np.issubdtype(colors.dtype, np.floating): + colors = np.clip(colors * 255.0, 0, 255).astype(np.uint8) + if np.issubdtype(colors.dtype, np.integer): + colors = np.clip(colors, 0, 255).astype(np.uint8) return colors @@ -52,9 +52,9 @@ def __setattr__(self, name: str, value: Any) -> None: if name in self._prop_hints: # Help the user with some casting... hint = self._prop_hints[name] - if hint == onpt.NDArray[onp.float32]: - value = value.astype(onp.float32) - elif hint == onpt.NDArray[onp.uint8] and "color" in name: + if hint == onpt.NDArray[np.float32]: + value = value.astype(np.float32) + elif hint == onpt.NDArray[np.uint8] and "color" in name: value = colors_to_uint8(value) setattr(handle._impl.props, name, value) @@ -112,11 +112,11 @@ class _SceneNodeHandleState: """Message containing properties of this scene node that are sent to the client.""" api: SceneApi - wxyz: onp.ndarray = dataclasses.field( - default_factory=lambda: onp.array([1.0, 0.0, 0.0, 0.0]) + wxyz: np.ndarray = dataclasses.field( + default_factory=lambda: np.array([1.0, 0.0, 0.0, 0.0]) ) - position: onp.ndarray = dataclasses.field( - default_factory=lambda: onp.array([0.0, 0.0, 0.0]) + position: np.ndarray = dataclasses.field( + default_factory=lambda: np.array([0.0, 0.0, 0.0]) ) visible: bool = True # TODO: we should remove SceneNodeHandle as an argument here. @@ -147,8 +147,8 @@ def _make( api: SceneApi, message: _SceneNodeMessage, name: str, - wxyz: tuple[float, float, float, float] | onp.ndarray, - position: tuple[float, float, float] | onp.ndarray, + wxyz: tuple[float, float, float, float] | np.ndarray, + position: tuple[float, float, float] | np.ndarray, visible: bool, ) -> TSceneNodeHandle: """Create scene node: send state to client(s) and set up @@ -170,35 +170,35 @@ def _make( return out @property - def wxyz(self) -> onp.ndarray: + def wxyz(self) -> np.ndarray: """Orientation of the scene node. This is the quaternion representation of the R in `p_parent = [R | t] p_local`. Synchronized to clients automatically when assigned. """ return self._impl.wxyz @wxyz.setter - def wxyz(self, wxyz: tuple[float, float, float, float] | onp.ndarray) -> None: + def wxyz(self, wxyz: tuple[float, float, float, float] | np.ndarray) -> None: from ._scene_api import cast_vector wxyz_cast = cast_vector(wxyz, 4) - self._impl.wxyz = onp.asarray(wxyz) + self._impl.wxyz = np.asarray(wxyz) self._impl.api._websock_interface.queue_message( _messages.SetOrientationMessage(self._impl.name, wxyz_cast) ) @property - def position(self) -> onp.ndarray: + def position(self) -> np.ndarray: """Position of the scene node. This is equivalent to the t in `p_parent = [R | t] p_local`. Synchronized to clients automatically when assigned. """ return self._impl.position @position.setter - def position(self, position: tuple[float, float, float] | onp.ndarray) -> None: + def position(self, position: tuple[float, float, float] | np.ndarray) -> None: from ._scene_api import cast_vector position_cast = cast_vector(position, 3) - self._impl.position = onp.asarray(position) + self._impl.position = np.asarray(position) self._impl.api._websock_interface.queue_message( _messages.SetPositionMessage(self._impl.name, position_cast) ) @@ -380,8 +380,8 @@ class BoneState: name: str websock_interface: WebsockServer | WebsockClientConnection bone_index: int - wxyz: onp.ndarray - position: onp.ndarray + wxyz: np.ndarray + position: np.ndarray @dataclasses.dataclass @@ -391,18 +391,18 @@ class MeshSkinnedBoneHandle: _impl: BoneState @property - def wxyz(self) -> onp.ndarray: + def wxyz(self) -> np.ndarray: """Orientation of the bone. This is the quaternion representation of the R in `p_parent = [R | t] p_local`. Synchronized to clients automatically when assigned. """ return self._impl.wxyz @wxyz.setter - def wxyz(self, wxyz: tuple[float, float, float, float] | onp.ndarray) -> None: + def wxyz(self, wxyz: tuple[float, float, float, float] | np.ndarray) -> None: from ._scene_api import cast_vector wxyz_cast = cast_vector(wxyz, 4) - self._impl.wxyz = onp.asarray(wxyz) + self._impl.wxyz = np.asarray(wxyz) self._impl.websock_interface.queue_message( _messages.SetBoneOrientationMessage( self._impl.name, self._impl.bone_index, wxyz_cast @@ -410,18 +410,18 @@ def wxyz(self, wxyz: tuple[float, float, float, float] | onp.ndarray) -> None: ) @property - def position(self) -> onp.ndarray: + def position(self) -> np.ndarray: """Position of the bone. This is equivalent to the t in `p_parent = [R | t] p_local`. Synchronized to clients automatically when assigned. """ return self._impl.position @position.setter - def position(self, position: tuple[float, float, float] | onp.ndarray) -> None: + def position(self, position: tuple[float, float, float] | np.ndarray) -> None: from ._scene_api import cast_vector position_cast = cast_vector(position, 3) - self._impl.position = onp.asarray(position) + self._impl.position = np.asarray(position) self._impl.websock_interface.queue_message( _messages.SetBonePositionMessage( self._impl.name, self._impl.bone_index, position_cast diff --git a/src/viser/_viser.py b/src/viser/_viser.py index e56110fc4..2603d1e4f 100644 --- a/src/viser/_viser.py +++ b/src/viser/_viser.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Callable, ContextManager import imageio.v3 as iio -import numpy as onp +import numpy as np import numpy.typing as npt import rich from rich import box, style @@ -68,12 +68,12 @@ class _CameraHandleState: """Information about a client's camera state.""" client: ClientHandle - wxyz: npt.NDArray[onp.float64] - position: npt.NDArray[onp.float64] + wxyz: npt.NDArray[np.float64] + position: npt.NDArray[np.float64] fov: float aspect: float - look_at: npt.NDArray[onp.float64] - up_direction: npt.NDArray[onp.float64] + look_at: npt.NDArray[np.float64] + up_direction: npt.NDArray[np.float64] update_timestamp: float camera_cb: list[Callable[[CameraHandle], None]] @@ -85,12 +85,12 @@ class CameraHandle: def __init__(self, client: ClientHandle) -> None: self._state = _CameraHandleState( client, - wxyz=onp.zeros(4), - position=onp.zeros(3), + wxyz=np.zeros(4), + position=np.zeros(3), fov=0.0, aspect=0.0, - look_at=onp.zeros(3), - up_direction=onp.zeros(3), + look_at=np.zeros(3), + up_direction=np.zeros(3), update_timestamp=0.0, camera_cb=[], ) @@ -101,7 +101,7 @@ def client(self) -> ClientHandle: return self._state.client @property - def wxyz(self) -> npt.NDArray[onp.float64]: + def wxyz(self) -> npt.NDArray[np.float64]: """Corresponds to the R in `P_world = [R | t] p_camera`. Synchronized automatically when assigned.""" assert self._state.update_timestamp != 0.0 @@ -111,9 +111,9 @@ def wxyz(self) -> npt.NDArray[onp.float64]: # - https://github.com/python/mypy/issues/3004 # - https://github.com/python/mypy/pull/11643 @wxyz.setter - def wxyz(self, wxyz: tuple[float, float, float, float] | onp.ndarray) -> None: - R_world_camera = tf.SO3(onp.asarray(wxyz)).as_matrix() - look_distance = onp.linalg.norm(self.look_at - self.position) + def wxyz(self, wxyz: tuple[float, float, float, float] | np.ndarray) -> None: + R_world_camera = tf.SO3(np.asarray(wxyz)).as_matrix() + look_distance = np.linalg.norm(self.look_at - self.position) # We're following OpenCV conventions: look_direction is +Z, up_direction is -Y, # right_direction is +X. @@ -141,12 +141,12 @@ def wxyz(self, wxyz: tuple[float, float, float, float] | onp.ndarray) -> None: # The internal camera orientation should be set in the look_at / # up_direction setters. We can uncomment this assert to check this. - # assert onp.allclose(self._state.wxyz, wxyz) or onp.allclose( + # assert np.allclose(self._state.wxyz, wxyz) or np.allclose( # self._state.wxyz, -wxyz # ) @property - def position(self) -> npt.NDArray[onp.float64]: + def position(self) -> npt.NDArray[np.float64]: """Corresponds to the t in `P_world = [R | t] p_camera`. Synchronized automatically when assigned. @@ -157,10 +157,10 @@ def position(self) -> npt.NDArray[onp.float64]: return self._state.position @position.setter - def position(self, position: tuple[float, float, float] | onp.ndarray) -> None: - offset = onp.asarray(position) - onp.array(self.position) # type: ignore - self._state.position = onp.asarray(position) - self.look_at = onp.array(self.look_at) + offset + def position(self, position: tuple[float, float, float] | np.ndarray) -> None: + offset = np.asarray(position) - np.array(self.position) # type: ignore + self._state.position = np.asarray(position) + self.look_at = np.array(self.look_at) + offset self._state.update_timestamp = time.time() self._state.client._websock_connection.queue_message( _messages.SetCameraPositionMessage(cast_vector(position, 3)) @@ -169,12 +169,12 @@ def position(self, position: tuple[float, float, float] | onp.ndarray) -> None: def _update_wxyz(self) -> None: """Compute and update the camera orientation from the internal look_at, position, and up vectors.""" z = self._state.look_at - self._state.position - z /= onp.linalg.norm(z) - y = tf.SO3.exp(z * onp.pi) @ self._state.up_direction - y = y - onp.dot(z, y) * z - y /= onp.linalg.norm(y) - x = onp.cross(y, z) - self._state.wxyz = tf.SO3.from_matrix(onp.stack([x, y, z], axis=1)).wxyz + z /= np.linalg.norm(z) + y = tf.SO3.exp(z * np.pi) @ self._state.up_direction + y = y - np.dot(z, y) * z + y /= np.linalg.norm(y) + x = np.cross(y, z) + self._state.wxyz = tf.SO3.from_matrix(np.stack([x, y, z], axis=1)).wxyz @property def fov(self) -> float: @@ -203,14 +203,14 @@ def update_timestamp(self) -> float: return self._state.update_timestamp @property - def look_at(self) -> npt.NDArray[onp.float64]: + def look_at(self) -> npt.NDArray[np.float64]: """Look at point for the camera. Synchronized automatically when set.""" assert self._state.update_timestamp != 0.0 return self._state.look_at @look_at.setter - def look_at(self, look_at: tuple[float, float, float] | onp.ndarray) -> None: - self._state.look_at = onp.asarray(look_at) + def look_at(self, look_at: tuple[float, float, float] | np.ndarray) -> None: + self._state.look_at = np.asarray(look_at) self._state.update_timestamp = time.time() self._update_wxyz() self._state.client._websock_connection.queue_message( @@ -218,16 +218,16 @@ def look_at(self, look_at: tuple[float, float, float] | onp.ndarray) -> None: ) @property - def up_direction(self) -> npt.NDArray[onp.float64]: + def up_direction(self) -> npt.NDArray[np.float64]: """Up direction for the camera. Synchronized automatically when set.""" assert self._state.update_timestamp != 0.0 return self._state.up_direction @up_direction.setter def up_direction( - self, up_direction: tuple[float, float, float] | onp.ndarray + self, up_direction: tuple[float, float, float] | np.ndarray ) -> None: - self._state.up_direction = onp.asarray(up_direction) + self._state.up_direction = np.asarray(up_direction) self._update_wxyz() self._state.update_timestamp = time.time() self._state.client._websock_connection.queue_message( @@ -243,7 +243,7 @@ def on_update( def get_render( self, height: int, width: int, transport_format: Literal["png", "jpeg"] = "jpeg" - ) -> onp.ndarray: + ) -> np.ndarray: """Request a render from a client, block until it's done and received, then return it as a numpy array. @@ -258,7 +258,7 @@ def get_render( # Listen for a render reseponse message, which should contain the rendered # image. render_ready_event = threading.Event() - out: onp.ndarray | None = None + out: np.ndarray | None = None connection = self.client._websock_connection @@ -363,7 +363,7 @@ def send_file_download( parts = [ content[i * chunk_size : (i + 1) * chunk_size] - for i in range(int(onp.ceil(len(content) / chunk_size))) + for i in range(int(np.ceil(len(content) / chunk_size))) ] uuid = _make_unique_id() @@ -500,12 +500,12 @@ def handle_camera_message( with client.atomic(): client.camera._state = _CameraHandleState( client, - onp.array(message.wxyz), - onp.array(message.position), + np.array(message.wxyz), + np.array(message.position), message.fov, message.aspect, - onp.array(message.look_at), - onp.array(message.up_direction), + np.array(message.look_at), + np.array(message.up_direction), time.time(), camera_cb=client.camera._state.camera_cb, ) diff --git a/src/viser/extras/_record3d.py b/src/viser/extras/_record3d.py index 1907298f0..0efd94ef4 100644 --- a/src/viser/extras/_record3d.py +++ b/src/viser/extras/_record3d.py @@ -8,8 +8,7 @@ import imageio.v3 as iio import liblzfse import numpy as np -import numpy as onp -import numpy.typing as onpt +import numpy.typing as npt import skimage.transform from scipy.spatial.transform import Rotation @@ -26,10 +25,10 @@ def __init__(self, data_dir: Path): # Read metadata. metadata = json.loads(metadata_path.read_text()) - K: onp.ndarray = np.array(metadata["K"], np.float32).reshape(3, 3).T + K: np.ndarray = np.array(metadata["K"], np.float32).reshape(3, 3).T fps = metadata["fps"] - T_world_cameras: onp.ndarray = np.array(metadata["poses"], np.float32) + T_world_cameras: np.ndarray = np.array(metadata["poses"], np.float32) T_world_cameras = np.concatenate( [ Rotation.from_quat(T_world_cameras[:, :4]).as_matrix(), @@ -55,7 +54,7 @@ def num_frames(self) -> int: def get_frame(self, index: int) -> Record3dFrame: # Read conf. - conf: onp.ndarray = np.frombuffer( + conf: np.ndarray = np.frombuffer( liblzfse.decompress(self.conf_paths[index].read_bytes()), dtype=np.uint8 ) if conf.shape[0] == 640 * 480: @@ -66,7 +65,7 @@ def get_frame(self, index: int) -> Record3dFrame: assert False, f"Unexpected conf shape {conf.shape}" # Read depth. - depth: onp.ndarray = np.frombuffer( + depth: np.ndarray = np.frombuffer( liblzfse.decompress(self.depth_paths[index].read_bytes()), dtype=np.float32 ).copy() if depth.shape[0] == 640 * 480: @@ -91,19 +90,19 @@ def get_frame(self, index: int) -> Record3dFrame: class Record3dFrame: """A single frame from a Record3D capture.""" - K: onpt.NDArray[onp.float32] - rgb: onpt.NDArray[onp.uint8] - depth: onpt.NDArray[onp.float32] - mask: onpt.NDArray[onp.bool_] - T_world_camera: onpt.NDArray[onp.float32] + K: npt.NDArray[np.float32] + rgb: npt.NDArray[np.uint8] + depth: npt.NDArray[np.float32] + mask: npt.NDArray[np.bool_] + T_world_camera: npt.NDArray[np.float32] def get_point_cloud( self, downsample_factor: int = 1 - ) -> Tuple[onpt.NDArray[onp.float32], onpt.NDArray[onp.uint8]]: + ) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.uint8]]: rgb = self.rgb[::downsample_factor, ::downsample_factor] depth = skimage.transform.resize(self.depth, rgb.shape[:2], order=0) mask = cast( - onpt.NDArray[onp.bool_], + npt.NDArray[np.bool_], skimage.transform.resize(self.mask, rgb.shape[:2], order=0), ) assert depth.shape == rgb.shape[:2] diff --git a/src/viser/extras/_urdf.py b/src/viser/extras/_urdf.py index 4e7b2fbdc..7c475f840 100644 --- a/src/viser/extras/_urdf.py +++ b/src/viser/extras/_urdf.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List, Tuple -import numpy as onp +import numpy as np import trimesh import yourdfpy @@ -103,7 +103,7 @@ def remove(self) -> None: for mesh in self._meshes: mesh.remove() - def update_cfg(self, configuration: onp.ndarray) -> None: + def update_cfg(self, configuration: np.ndarray) -> None: """Update the joint angles of the visualized URDF.""" self._urdf.update_cfg(configuration) with self._target.atomic(): @@ -126,7 +126,7 @@ def get_actuated_joint_limits( assert isinstance(joint_name, str) assert isinstance(joint, yourdfpy.Joint) if joint.limit is None: - out[joint_name] = (-onp.pi, onp.pi) + out[joint_name] = (-np.pi, np.pi) else: out[joint_name] = (joint.limit.lower, joint.limit.upper) return out diff --git a/src/viser/infra/_messages.py b/src/viser/infra/_messages.py index 9cfb50ba9..bfbbf9ff0 100644 --- a/src/viser/infra/_messages.py +++ b/src/viser/infra/_messages.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, cast import msgspec -import numpy as onp +import numpy as np from typing_extensions import get_args, get_origin, get_type_hints if TYPE_CHECKING: @@ -51,10 +51,10 @@ def _prepare_for_serialization(value: Any, annotation: object) -> Any: annotation = type(value) # Coerce some scalar types: if we've annotated as float / int but we get an - # onp.float32 / onp.int64, for example, we should cast automatically. - if annotation is float or isinstance(value, onp.floating): + # np.float32 / np.int64, for example, we should cast automatically. + if annotation is float or isinstance(value, np.floating): return float(value) - if annotation is int or isinstance(value, onp.integer): + if annotation is int or isinstance(value, np.integer): return int(value) if dataclasses.is_dataclass(annotation): @@ -62,7 +62,7 @@ def _prepare_for_serialization(value: Any, annotation: object) -> Any: # Recursively handle tuples. if isinstance(value, tuple): - if isinstance(value, onp.ndarray): + if isinstance(value, np.ndarray): assert False, ( "Expected a tuple, but got an array... missing a cast somewhere?" f" {value}" @@ -89,7 +89,7 @@ def _prepare_for_serialization(value: Any, annotation: object) -> Any: # For arrays, we serialize underlying data directly. The client is responsible for # reading using the correct dtype. - if isinstance(value, onp.ndarray): + if isinstance(value, np.ndarray): return value.data if value.data.c_contiguous else value.copy().data if isinstance(value, dict): diff --git a/src/viser/infra/_typescript_interface_gen.py b/src/viser/infra/_typescript_interface_gen.py index b0ea5f634..86ebb48ee 100644 --- a/src/viser/infra/_typescript_interface_gen.py +++ b/src/viser/infra/_typescript_interface_gen.py @@ -2,7 +2,7 @@ from collections import defaultdict from typing import Any, Type, Union, cast -import numpy as onp +import numpy as np from typing_extensions import ( Annotated, Literal, @@ -26,7 +26,7 @@ int: "number", str: "string", # For numpy arrays, we directly serialize the underlying data buffer. - onp.ndarray: "Uint8Array", + np.ndarray: "Uint8Array", bytes: "Uint8Array", Any: "any", None: "null", diff --git a/src/viser/transforms/_base.py b/src/viser/transforms/_base.py index 523a926cb..6da62cd2d 100644 --- a/src/viser/transforms/_base.py +++ b/src/viser/transforms/_base.py @@ -1,8 +1,8 @@ import abc from typing import ClassVar, Generic, Tuple, TypeVar, Union, overload -import numpy as onp -import numpy.typing as onpt +import numpy as np +import numpy.typing as npt from typing_extensions import Self, final, get_args, override @@ -28,7 +28,7 @@ def __init__( # - This method is implicitly overriden by the dataclass decorator and # should _not_ be marked abstract. self, - parameters: onp.ndarray, + parameters: np.ndarray, ): """Construct a group object from its underlying parameters.""" raise NotImplementedError() @@ -53,18 +53,18 @@ def __matmul__(self, other: Self) -> Self: ... @overload def __matmul__( - self, other: onpt.NDArray[onp.floating] - ) -> onpt.NDArray[onp.floating]: ... + self, other: npt.NDArray[np.floating] + ) -> npt.NDArray[np.floating]: ... def __matmul__( - self, other: Union[Self, onpt.NDArray[onp.floating]] - ) -> Union[Self, onpt.NDArray[onp.floating]]: + self, other: Union[Self, npt.NDArray[np.floating]] + ) -> Union[Self, npt.NDArray[np.floating]]: """Overload for the `@` operator. Switches between the group action (`.apply()`) and multiplication (`.multiply()`) based on the type of `other`. """ - if isinstance(other, onp.ndarray): + if isinstance(other, np.ndarray): return self.apply(target=other) elif isinstance(other, MatrixLieGroup): assert self.space_dim == other.space_dim @@ -77,7 +77,7 @@ def __matmul__( @classmethod @abc.abstractmethod def identity( - cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + cls, batch_axes: Tuple[int, ...] = (), dtype: npt.DTypeLike = np.float64 ) -> Self: """Returns identity element. @@ -91,7 +91,7 @@ def identity( @classmethod @abc.abstractmethod - def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> Self: + def from_matrix(cls, matrix: npt.NDArray[np.floating]) -> Self: """Get group member from matrix representation. Args: @@ -104,17 +104,17 @@ def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> Self: # Accessors. @abc.abstractmethod - def as_matrix(self) -> onpt.NDArray[onp.floating]: + def as_matrix(self) -> npt.NDArray[np.floating]: """Get transformation as a matrix. Homogeneous for SE groups.""" @abc.abstractmethod - def parameters(self) -> onpt.NDArray[onp.floating]: + def parameters(self) -> npt.NDArray[np.floating]: """Get underlying representation.""" # Operations. @abc.abstractmethod - def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: + def apply(self, target: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]: """Applies group action to a point. Args: @@ -134,7 +134,7 @@ def multiply(self, other: Self) -> Self: @classmethod @abc.abstractmethod - def exp(cls, tangent: onpt.NDArray[onp.floating]) -> Self: + def exp(cls, tangent: npt.NDArray[np.floating]) -> Self: """Computes `expm(wedge(tangent))`. Args: @@ -145,7 +145,7 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> Self: """ @abc.abstractmethod - def log(self) -> onpt.NDArray[onp.floating]: + def log(self) -> npt.NDArray[np.floating]: """Computes `vee(logm(transformation matrix))`. Returns: @@ -153,7 +153,7 @@ def log(self) -> onpt.NDArray[onp.floating]: """ @abc.abstractmethod - def adjoint(self) -> onpt.NDArray[onp.floating]: + def adjoint(self) -> npt.NDArray[np.floating]: """Computes the adjoint, which transforms tangent vectors between tangent spaces. @@ -189,9 +189,9 @@ def normalize(self) -> Self: @abc.abstractmethod def sample_uniform( cls, - rng: onp.random.Generator, + rng: np.random.Generator, batch_axes: Tuple[int, ...] = (), - dtype: onpt.DTypeLike = onp.float64, + dtype: npt.DTypeLike = np.float64, ) -> Self: """Draw a uniform sample from the group. Translations (if applicable) are in the range [-1, 1]. @@ -234,7 +234,7 @@ class SEBase(Generic[ContainedSOType], MatrixLieGroup): def from_rotation_and_translation( cls, rotation: ContainedSOType, - translation: onpt.NDArray[onp.floating], + translation: npt.NDArray[np.floating], ) -> Self: """Construct a rigid transform from a rotation and a translation. @@ -251,7 +251,7 @@ def from_rotation_and_translation( def from_rotation(cls, rotation: ContainedSOType) -> Self: return cls.from_rotation_and_translation( rotation=rotation, - translation=onp.zeros( + translation=np.zeros( (*rotation.get_batch_axes(), cls.space_dim), dtype=rotation.parameters().dtype, ), @@ -259,7 +259,7 @@ def from_rotation(cls, rotation: ContainedSOType) -> Self: @final @classmethod - def from_translation(cls, translation: onpt.NDArray[onp.floating]) -> Self: + def from_translation(cls, translation: npt.NDArray[np.floating]) -> Self: # Extract rotation class from type parameter. assert len(cls.__orig_bases__) == 1 # type: ignore return cls.from_rotation_and_translation( @@ -272,14 +272,14 @@ def rotation(self) -> ContainedSOType: """Returns a transform's rotation term.""" @abc.abstractmethod - def translation(self) -> onpt.NDArray[onp.floating]: + def translation(self) -> npt.NDArray[np.floating]: """Returns a transform's translation term.""" # Overrides. @final @override - def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: + def apply(self, target: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]: return self.rotation() @ target + self.translation() # type: ignore @final diff --git a/src/viser/transforms/_se2.py b/src/viser/transforms/_se2.py index e68b4208e..011d47f22 100644 --- a/src/viser/transforms/_se2.py +++ b/src/viser/transforms/_se2.py @@ -3,8 +3,8 @@ import dataclasses from typing import Tuple, cast -import numpy as onp -import numpy.typing as onpt +import numpy as np +import numpy.typing as npt from typing_extensions import override from . import _base, hints @@ -31,13 +31,13 @@ class SE2( # SE2-specific. - unit_complex_xy: onpt.NDArray[onp.floating] + unit_complex_xy: npt.NDArray[np.floating] """Internal parameters. `(cos, sin, x, y)`. Shape should be `(*, 4)`.""" @override def __repr__(self) -> str: - unit_complex = onp.round(self.unit_complex_xy[..., :2], 5) - xy = onp.round(self.unit_complex_xy[..., 2:], 5) + unit_complex = np.round(self.unit_complex_xy[..., :2], 5) + xy = np.round(self.unit_complex_xy[..., 2:], 5) return f"{self.__class__.__name__}(unit_complex={unit_complex}, xy={xy})" @staticmethod @@ -46,9 +46,9 @@ def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> SE2: This is not the same as integrating over a length-3 twist. """ - cos = onp.cos(theta) - sin = onp.sin(theta) - return SE2(unit_complex_xy=onp.stack([cos, sin, x, y], axis=-1)) + cos = np.cos(theta) + sin = np.sin(theta) + return SE2(unit_complex_xy=np.stack([cos, sin, x, y], axis=-1)) # SE-specific. @@ -57,12 +57,12 @@ def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> SE2: def from_rotation_and_translation( cls, rotation: SO2, - translation: onpt.NDArray[onp.floating], + translation: npt.NDArray[np.floating], ) -> SE2: assert translation.shape[-1:] == (2,) rotation, translation = broadcast_leading_axes((rotation, translation)) return SE2( - unit_complex_xy=onp.concatenate( + unit_complex_xy=np.concatenate( [rotation.unit_complex, translation], axis=-1 ) ) @@ -72,7 +72,7 @@ def rotation(self) -> SO2: return SO2(unit_complex=self.unit_complex_xy[..., :2]) @override - def translation(self) -> onpt.NDArray[onp.floating]: + def translation(self) -> npt.NDArray[np.floating]: return self.unit_complex_xy[..., 2:] # Factory. @@ -80,17 +80,17 @@ def translation(self) -> onpt.NDArray[onp.floating]: @classmethod @override def identity( - cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + cls, batch_axes: Tuple[int, ...] = (), dtype: npt.DTypeLike = np.float64 ) -> SE2: return SE2( - unit_complex_xy=onp.broadcast_to( - onp.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 4) + unit_complex_xy=np.broadcast_to( + np.array([1.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 4) ) ) @classmethod @override - def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE2: + def from_matrix(cls, matrix: npt.NDArray[np.floating]) -> SE2: assert matrix.shape[-2:] == (3, 3) or matrix.shape[-2:] == (2, 3) # Currently assumes bottom row is [0, 0, 1]. return SE2.from_rotation_and_translation( @@ -101,13 +101,13 @@ def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE2: # Accessors. @override - def parameters(self) -> onpt.NDArray[onp.floating]: + def parameters(self) -> npt.NDArray[np.floating]: return self.unit_complex_xy @override - def as_matrix(self) -> onpt.NDArray[onp.floating]: - cos, sin, x, y = onp.moveaxis(self.unit_complex_xy, -1, 0) - out = onp.stack( + def as_matrix(self) -> npt.NDArray[np.floating]: + cos, sin, x, y = np.moveaxis(self.unit_complex_xy, -1, 0) + out = np.stack( [ cos, -sin, @@ -115,9 +115,9 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]: sin, cos, y, - onp.zeros_like(x), - onp.zeros_like(x), - onp.ones_like(x), + np.zeros_like(x), + np.zeros_like(x), + np.ones_like(x), ], axis=-1, ).reshape((*self.get_batch_axes(), 3, 3)) @@ -127,7 +127,7 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]: @classmethod @override - def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE2: + def exp(cls, tangent: npt.NDArray[np.floating]) -> SE2: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558 # Also see: @@ -136,32 +136,32 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE2: assert tangent.shape[-1:] == (3,) theta = tangent[..., 2] - use_taylor = onp.abs(theta) < get_epsilon(tangent.dtype) + use_taylor = np.abs(theta) < get_epsilon(tangent.dtype) - # Shim to avoid NaNs in onp.where branches, which cause failures for + # Shim to avoid NaNs in np.where branches, which cause failures for # reverse-mode AD in JAX. This isn't needed for vanilla numpy. safe_theta = cast( - onp.ndarray, - onp.where( + np.ndarray, + np.where( use_taylor, - onp.ones_like(theta), # Any non-zero value should do here. + np.ones_like(theta), # Any non-zero value should do here. theta, ), ) theta_sq = theta**2 - sin_over_theta = onp.where( + sin_over_theta = np.where( use_taylor, 1.0 - theta_sq / 6.0, - onp.sin(safe_theta) / safe_theta, + np.sin(safe_theta) / safe_theta, ) - one_minus_cos_over_theta = onp.where( + one_minus_cos_over_theta = np.where( use_taylor, 0.5 * theta - theta * theta_sq / 24.0, - (1.0 - onp.cos(safe_theta)) / safe_theta, + (1.0 - np.cos(safe_theta)) / safe_theta, ) - V = onp.stack( + V = np.stack( [ sin_over_theta, -one_minus_cos_over_theta, @@ -173,13 +173,13 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE2: return SE2.from_rotation_and_translation( rotation=SO2.from_radians(theta), - translation=onp.einsum("...ij,...j->...i", V, tangent[..., :2]).astype( + translation=np.einsum("...ij,...j->...i", V, tangent[..., :2]).astype( tangent.dtype ), ) @override - def log(self) -> onpt.NDArray[onp.floating]: + def log(self) -> npt.NDArray[np.floating]: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L160 # Also see: @@ -187,28 +187,28 @@ def log(self) -> onpt.NDArray[onp.floating]: theta = self.rotation().log()[..., 0] - cos = onp.cos(theta) + cos = np.cos(theta) cos_minus_one = cos - 1.0 half_theta = theta / 2.0 - use_taylor = onp.abs(cos_minus_one) < get_epsilon(theta.dtype) + use_taylor = np.abs(cos_minus_one) < get_epsilon(theta.dtype) - # Shim to avoid NaNs in onp.where branches, which cause failures for + # Shim to avoid NaNs in np.where branches, which cause failures for # reverse-mode AD in JAX. This isn't needed for vanilla numpy. - safe_cos_minus_one = onp.where( + safe_cos_minus_one = np.where( use_taylor, - onp.ones_like(cos_minus_one), # Any non-zero value should do here. + np.ones_like(cos_minus_one), # Any non-zero value should do here. cos_minus_one, ) - half_theta_over_tan_half_theta = onp.where( + half_theta_over_tan_half_theta = np.where( use_taylor, # Taylor approximation. 1.0 - theta**2 / 12.0, # Default. - -(half_theta * onp.sin(theta)) / safe_cos_minus_one, + -(half_theta * np.sin(theta)) / safe_cos_minus_one, ) - V_inv = onp.stack( + V_inv = np.stack( [ half_theta_over_tan_half_theta, half_theta, @@ -218,9 +218,9 @@ def log(self) -> onpt.NDArray[onp.floating]: axis=-1, ).reshape((*theta.shape, 2, 2)) - tangent = onp.concatenate( + tangent = np.concatenate( [ - onp.einsum("...ij,...j->...i", V_inv, self.translation()), + np.einsum("...ij,...j->...i", V_inv, self.translation()), theta[..., None], ], axis=-1, @@ -228,9 +228,9 @@ def log(self) -> onpt.NDArray[onp.floating]: return tangent.astype(self.unit_complex_xy.dtype) @override - def adjoint(self: SE2) -> onpt.NDArray[onp.floating]: - cos, sin, x, y = onp.moveaxis(self.unit_complex_xy, -1, 0) - return onp.stack( + def adjoint(self: SE2) -> npt.NDArray[np.floating]: + cos, sin, x, y = np.moveaxis(self.unit_complex_xy, -1, 0) + return np.stack( [ cos, -sin, @@ -238,9 +238,9 @@ def adjoint(self: SE2) -> onpt.NDArray[onp.floating]: sin, cos, -x, - onp.zeros_like(x), - onp.zeros_like(x), - onp.ones_like(x), + np.zeros_like(x), + np.zeros_like(x), + np.ones_like(x), ], axis=-1, ).reshape((*self.get_batch_axes(), 3, 3)) @@ -249,9 +249,9 @@ def adjoint(self: SE2) -> onpt.NDArray[onp.floating]: @override def sample_uniform( cls, - rng: onp.random.Generator, + rng: np.random.Generator, batch_axes: Tuple[int, ...] = (), - dtype: onpt.DTypeLike = onp.float64, + dtype: npt.DTypeLike = np.float64, ) -> SE2: return SE2.from_rotation_and_translation( SO2.sample_uniform(rng, batch_axes=batch_axes, dtype=dtype), diff --git a/src/viser/transforms/_se3.py b/src/viser/transforms/_se3.py index 2bc0b187b..d8c3cad9e 100644 --- a/src/viser/transforms/_se3.py +++ b/src/viser/transforms/_se3.py @@ -3,8 +3,8 @@ import dataclasses from typing import Tuple, cast -import numpy as onp -import numpy.typing as onpt +import numpy as np +import numpy.typing as npt from typing_extensions import override from . import _base @@ -12,12 +12,12 @@ from .utils import broadcast_leading_axes, get_epsilon -def _skew(omega: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: +def _skew(omega: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]: """Returns the skew-symmetric form of a length-3 vector.""" - wx, wy, wz = onp.moveaxis(omega, -1, 0) - zeros = onp.zeros_like(wx) - return onp.stack( + wx, wy, wz = np.moveaxis(omega, -1, 0) + zeros = np.zeros_like(wx) + return np.stack( [zeros, -wz, wy, wz, zeros, -wx, -wy, wx, zeros], axis=-1, ).reshape((*omega.shape[:-1], 3, 3)) @@ -42,13 +42,13 @@ class SE3( # SE3-specific. - wxyz_xyz: onpt.NDArray[onp.floating] + wxyz_xyz: npt.NDArray[np.floating] """Internal parameters. wxyz quaternion followed by xyz translation. Shape should be `(*, 7)`.""" @override def __repr__(self) -> str: - quat = onp.round(self.wxyz_xyz[..., :4], 5) - trans = onp.round(self.wxyz_xyz[..., 4:], 5) + quat = np.round(self.wxyz_xyz[..., :4], 5) + trans = np.round(self.wxyz_xyz[..., 4:], 5) return f"{self.__class__.__name__}(wxyz={quat}, xyz={trans})" # SE-specific. @@ -58,18 +58,18 @@ def __repr__(self) -> str: def from_rotation_and_translation( cls, rotation: SO3, - translation: onpt.NDArray[onp.floating], + translation: npt.NDArray[np.floating], ) -> SE3: assert translation.shape[-1:] == (3,) rotation, translation = broadcast_leading_axes((rotation, translation)) - return SE3(wxyz_xyz=onp.concatenate([rotation.wxyz, translation], axis=-1)) + return SE3(wxyz_xyz=np.concatenate([rotation.wxyz, translation], axis=-1)) @override def rotation(self) -> SO3: return SO3(wxyz=self.wxyz_xyz[..., :4]) @override - def translation(self) -> onpt.NDArray[onp.floating]: + def translation(self) -> npt.NDArray[np.floating]: return self.wxyz_xyz[..., 4:] # Factory. @@ -77,18 +77,18 @@ def translation(self) -> onpt.NDArray[onp.floating]: @classmethod @override def identity( - cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + cls, batch_axes: Tuple[int, ...] = (), dtype: npt.DTypeLike = np.float64 ) -> SE3: return SE3( - wxyz_xyz=onp.broadcast_to( - onp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=dtype), + wxyz_xyz=np.broadcast_to( + np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=dtype), (*batch_axes, 7), ) ) @classmethod @override - def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE3: + def from_matrix(cls, matrix: npt.NDArray[np.floating]) -> SE3: assert matrix.shape[-2:] == (4, 4) or matrix.shape[-2:] == (3, 4) # Currently assumes bottom row is [0, 0, 0, 1]. return SE3.from_rotation_and_translation( @@ -99,22 +99,22 @@ def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SE3: # Accessors. @override - def as_matrix(self) -> onpt.NDArray[onp.floating]: - out = onp.zeros((*self.get_batch_axes(), 4, 4), dtype=self.wxyz_xyz.dtype) + def as_matrix(self) -> npt.NDArray[np.floating]: + out = np.zeros((*self.get_batch_axes(), 4, 4), dtype=self.wxyz_xyz.dtype) out[..., :3, :3] = self.rotation().as_matrix() out[..., :3, 3] = self.translation() out[..., 3, 3] = 1.0 return out @override - def parameters(self) -> onpt.NDArray[onp.floating]: + def parameters(self) -> npt.NDArray[np.floating]: return self.wxyz_xyz # Operations. @classmethod @override - def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE3: + def exp(cls, tangent: npt.NDArray[np.floating]) -> SE3: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761 @@ -123,101 +123,101 @@ def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SE3: rotation = SO3.exp(tangent[..., 3:]) - theta_squared = onp.sum(onp.square(tangent[..., 3:]), axis=-1) + theta_squared = np.sum(np.square(tangent[..., 3:]), axis=-1) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) - # Shim to avoid NaNs in onp.where branches, which cause failures for + # Shim to avoid NaNs in np.where branches, which cause failures for # reverse-mode AD in JAX. This isn't needed for vanilla numpy. theta_squared_safe = cast( - onp.ndarray, - onp.where( + np.ndarray, + np.where( use_taylor, - onp.ones_like(theta_squared), # Any non-zero value should do here. + np.ones_like(theta_squared), # Any non-zero value should do here. theta_squared, ), ) del theta_squared - theta_safe = onp.sqrt(theta_squared_safe) + theta_safe = np.sqrt(theta_squared_safe) skew_omega = _skew(tangent[..., 3:]) - V = onp.where( + V = np.where( use_taylor[..., None, None], rotation.as_matrix(), ( - onp.eye(3) - + ((1.0 - onp.cos(theta_safe)) / (theta_squared_safe))[..., None, None] + np.eye(3) + + ((1.0 - np.cos(theta_safe)) / (theta_squared_safe))[..., None, None] * skew_omega + ( - (theta_safe - onp.sin(theta_safe)) + (theta_safe - np.sin(theta_safe)) / (theta_squared_safe * theta_safe) )[..., None, None] - * onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) + * np.einsum("...ij,...jk->...ik", skew_omega, skew_omega) ), ) return SE3.from_rotation_and_translation( rotation=rotation, - translation=onp.einsum("...ij,...j->...i", V, tangent[..., :3]).astype( + translation=np.einsum("...ij,...j->...i", V, tangent[..., :3]).astype( tangent.dtype ), ) @override - def log(self) -> onpt.NDArray[onp.floating]: + def log(self) -> npt.NDArray[np.floating]: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223 omega = self.rotation().log() - theta_squared = onp.sum(onp.square(omega), axis=-1) + theta_squared = np.sum(np.square(omega), axis=-1) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) skew_omega = _skew(omega) - # Shim to avoid NaNs in onp.where branches, which cause failures for + # Shim to avoid NaNs in np.where branches, which cause failures for # reverse-mode AD in JAX. This isn't needed for vanilla numpy. - theta_squared_safe = onp.where( + theta_squared_safe = np.where( use_taylor, - onp.ones_like(theta_squared), # Any non-zero value should do here. + np.ones_like(theta_squared), # Any non-zero value should do here. theta_squared, ) del theta_squared - theta_safe = onp.sqrt(theta_squared_safe) + theta_safe = np.sqrt(theta_squared_safe) half_theta_safe = theta_safe / 2.0 - V_inv = onp.where( + V_inv = np.where( use_taylor[..., None, None], - onp.eye(3) + np.eye(3) - 0.5 * skew_omega - + onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) / 12.0, + + np.einsum("...ij,...jk->...ik", skew_omega, skew_omega) / 12.0, ( - onp.eye(3) + np.eye(3) - 0.5 * skew_omega + ( ( 1.0 - theta_safe - * onp.cos(half_theta_safe) - / (2.0 * onp.sin(half_theta_safe)) + * np.cos(half_theta_safe) + / (2.0 * np.sin(half_theta_safe)) ) / theta_squared_safe )[..., None, None] - * onp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) + * np.einsum("...ij,...jk->...ik", skew_omega, skew_omega) ), ) - return onp.concatenate( - [onp.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1 + return np.concatenate( + [np.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1 ).astype(self.wxyz_xyz.dtype) @override - def adjoint(self) -> onpt.NDArray[onp.floating]: + def adjoint(self) -> npt.NDArray[np.floating]: R = self.rotation().as_matrix() - return onp.concatenate( + return np.concatenate( [ - onp.concatenate( - [R, onp.einsum("...ij,...jk->...ik", _skew(self.translation()), R)], + np.concatenate( + [R, np.einsum("...ij,...jk->...ik", _skew(self.translation()), R)], axis=-1, ), - onp.concatenate( - [onp.zeros((*self.get_batch_axes(), 3, 3), dtype=R.dtype), R], + np.concatenate( + [np.zeros((*self.get_batch_axes(), 3, 3), dtype=R.dtype), R], axis=-1, ), ], @@ -228,9 +228,9 @@ def adjoint(self) -> onpt.NDArray[onp.floating]: @override def sample_uniform( cls, - rng: onp.random.Generator, + rng: np.random.Generator, batch_axes: Tuple[int, ...] = (), - dtype: onpt.DTypeLike = onp.float64, + dtype: npt.DTypeLike = np.float64, ) -> SE3: return SE3.from_rotation_and_translation( rotation=SO3.sample_uniform(rng, batch_axes=batch_axes, dtype=dtype), diff --git a/src/viser/transforms/_so2.py b/src/viser/transforms/_so2.py index a6b9c5161..5a0d5aea3 100644 --- a/src/viser/transforms/_so2.py +++ b/src/viser/transforms/_so2.py @@ -3,8 +3,8 @@ import dataclasses from typing import Tuple -import numpy as onp -import numpy.typing as onpt +import numpy as np +import numpy.typing as npt from typing_extensions import override from . import _base, hints @@ -29,22 +29,22 @@ class SO2( # SO2-specific. - unit_complex: onpt.NDArray[onp.floating] + unit_complex: npt.NDArray[np.floating] """Internal parameters. `(cos, sin)`. Shape should be `(*, 2)`.""" @override def __repr__(self) -> str: - unit_complex = onp.round(self.unit_complex, 5) + unit_complex = np.round(self.unit_complex, 5) return f"{self.__class__.__name__}(unit_complex={unit_complex})" @staticmethod def from_radians(theta: hints.Scalar) -> SO2: """Construct a rotation object from a scalar angle.""" - cos = onp.cos(theta) - sin = onp.sin(theta) - return SO2(unit_complex=onp.stack([cos, sin], axis=-1)) + cos = np.cos(theta) + sin = np.sin(theta) + return SO2(unit_complex=np.stack([cos, sin], axis=-1)) - def as_radians(self) -> onpt.NDArray[onp.floating]: + def as_radians(self) -> npt.NDArray[np.floating]: """Compute a scalar angle from a rotation object.""" radians = self.log()[..., 0] return radians @@ -54,30 +54,30 @@ def as_radians(self) -> onpt.NDArray[onp.floating]: @classmethod @override def identity( - cls, batch_axes: Tuple[int, ...] = (), dtype: onpt.DTypeLike = onp.float64 + cls, batch_axes: Tuple[int, ...] = (), dtype: npt.DTypeLike = np.float64 ) -> SO2: return SO2( - unit_complex=onp.stack( - [onp.ones(batch_axes, dtype=dtype), onp.zeros(batch_axes, dtype=dtype)], + unit_complex=np.stack( + [np.ones(batch_axes, dtype=dtype), np.zeros(batch_axes, dtype=dtype)], axis=-1, ) ) @classmethod @override - def from_matrix(cls, matrix: onpt.NDArray[onp.floating]) -> SO2: + def from_matrix(cls, matrix: npt.NDArray[np.floating]) -> SO2: assert matrix.shape[-2:] == (2, 2) - return SO2(unit_complex=onp.array(matrix[..., :, 0])) + return SO2(unit_complex=np.array(matrix[..., :, 0])) # Accessors. @override - def as_matrix(self) -> onpt.NDArray[onp.floating]: + def as_matrix(self) -> npt.NDArray[np.floating]: cos_sin = self.unit_complex - out = onp.stack( + out = np.stack( [ # [cos, -sin], - cos_sin * onp.array([1, -1], dtype=cos_sin.dtype), + cos_sin * np.array([1, -1], dtype=cos_sin.dtype), # [sin, cos], cos_sin[..., ::-1], ], @@ -87,42 +87,42 @@ def as_matrix(self) -> onpt.NDArray[onp.floating]: return out # type: ignore @override - def parameters(self) -> onpt.NDArray[onp.floating]: + def parameters(self) -> npt.NDArray[np.floating]: return self.unit_complex # Operations. @override - def apply(self, target: onpt.NDArray[onp.floating]) -> onpt.NDArray[onp.floating]: + def apply(self, target: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]: assert target.shape[-1:] == (2,) self, target = broadcast_leading_axes((self, target)) - return onp.einsum("...ij,...j->...i", self.as_matrix(), target) + return np.einsum("...ij,...j->...i", self.as_matrix(), target) @override def multiply(self, other: SO2) -> SO2: return SO2( - unit_complex=onp.einsum( + unit_complex=np.einsum( "...ij,...j->...i", self.as_matrix(), other.unit_complex ) ) @classmethod @override - def exp(cls, tangent: onpt.NDArray[onp.floating]) -> SO2: + def exp(cls, tangent: npt.NDArray[np.floating]) -> SO2: assert tangent.shape[-1] == 1 - cos = onp.cos(tangent) - sin = onp.sin(tangent) - return SO2(unit_complex=onp.concatenate([cos, sin], axis=-1)) + cos = np.cos(tangent) + sin = np.sin(tangent) + return SO2(unit_complex=np.concatenate([cos, sin], axis=-1)) @override - def log(self) -> onpt.NDArray[onp.floating]: - return onp.arctan2( + def log(self) -> npt.NDArray[np.floating]: + return np.arctan2( self.unit_complex[..., 1, None], self.unit_complex[..., 0, None] ) @override - def adjoint(self) -> onpt.NDArray[onp.floating]: - return onp.ones((*self.get_batch_axes(), 1, 1), dtype=self.unit_complex.dtype) + def adjoint(self) -> npt.NDArray[np.floating]: + return np.ones((*self.get_batch_axes(), 1, 1), dtype=self.unit_complex.dtype) @override def inverse(self) -> SO2: @@ -134,19 +134,19 @@ def inverse(self) -> SO2: def normalize(self) -> SO2: return SO2( unit_complex=self.unit_complex - / onp.linalg.norm(self.unit_complex, axis=-1, keepdims=True) + / np.linalg.norm(self.unit_complex, axis=-1, keepdims=True) ) @classmethod @override def sample_uniform( cls, - rng: onp.random.Generator, + rng: np.random.Generator, batch_axes: Tuple[int, ...] = (), - dtype: onpt.DTypeLike = onp.float64, + dtype: npt.DTypeLike = np.float64, ) -> SO2: out = SO2.from_radians( - rng.uniform(0.0, 2.0 * onp.pi, size=batch_axes).astype(dtype=dtype) + rng.uniform(0.0, 2.0 * np.pi, size=batch_axes).astype(dtype=dtype) ) assert out.get_batch_axes() == batch_axes return out diff --git a/src/viser/transforms/hints/__init__.py b/src/viser/transforms/hints/__init__.py index dc131ab3a..fb454b2a4 100644 --- a/src/viser/transforms/hints/__init__.py +++ b/src/viser/transforms/hints/__init__.py @@ -1,11 +1,11 @@ from typing import Union -import numpy as onp -import numpy.typing as onpt +import numpy as np +import numpy.typing as npt # Type aliases Numpy arrays; primarily for function inputs. -Scalar = Union[float, onpt.NDArray[onp.floating]] +Scalar = Union[float, npt.NDArray[np.floating]] """Type alias for `Union[float, Array]`.""" diff --git a/tests/test_transforms_axioms.py b/tests/test_transforms_axioms.py index 8ec6d3398..3e644245a 100644 --- a/tests/test_transforms_axioms.py +++ b/tests/test_transforms_axioms.py @@ -5,7 +5,7 @@ from typing import Tuple, Type -import numpy as onp +import numpy as np import numpy.typing as onpt from utils import ( assert_arrays_close, @@ -46,11 +46,11 @@ def test_identity( assert_transforms_close(transform, transform @ identity) assert_arrays_close( transform.as_matrix(), - onp.einsum("...ij,...jk->...ik", identity.as_matrix(), transform.as_matrix()), + np.einsum("...ij,...jk->...ik", identity.as_matrix(), transform.as_matrix()), ) assert_arrays_close( transform.as_matrix(), - onp.einsum("...ij,...jk->...ik", transform.as_matrix(), identity.as_matrix()), + np.einsum("...ij,...jk->...ik", transform.as_matrix(), identity.as_matrix()), ) @@ -66,22 +66,22 @@ def test_inverse( assert_transforms_close(identity, Group.multiply(transform, transform.inverse())) assert_transforms_close(identity, Group.multiply(transform.inverse(), transform)) assert_arrays_close( - onp.broadcast_to( - onp.eye(Group.matrix_dim, dtype=dtype), + np.broadcast_to( + np.eye(Group.matrix_dim, dtype=dtype), (*batch_axes, Group.matrix_dim, Group.matrix_dim), ), - onp.einsum( + np.einsum( "...ij,...jk->...ik", transform.as_matrix(), transform.inverse().as_matrix(), ), ) assert_arrays_close( - onp.broadcast_to( - onp.eye(Group.matrix_dim, dtype=dtype), + np.broadcast_to( + np.eye(Group.matrix_dim, dtype=dtype), (*batch_axes, Group.matrix_dim, Group.matrix_dim), ), - onp.einsum( + np.einsum( "...ij,...jk->...ik", transform.inverse().as_matrix(), transform.as_matrix(), diff --git a/tests/test_transforms_bijective.py b/tests/test_transforms_bijective.py index a9f0c835c..4878bd045 100644 --- a/tests/test_transforms_bijective.py +++ b/tests/test_transforms_bijective.py @@ -2,8 +2,8 @@ from typing import Tuple, Type -import numpy as onp -import numpy.typing as onpt +import numpy as np +import numpy.typing as npt from hypothesis import given, settings from hypothesis import strategies as st from utils import ( @@ -18,7 +18,7 @@ @general_group_test def test_sample_uniform_valid( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: npt.DTypeLike ): """Check that sample_uniform() returns valid group members.""" T = sample_transform( @@ -31,7 +31,7 @@ def test_sample_uniform_valid( @given(_random_module=st.random_module()) def test_so2_from_to_radians_bijective(_random_module): """Check that we can convert from and to radians.""" - radians = onp.random.uniform(low=-onp.pi, high=onp.pi) + radians = np.random.uniform(low=-np.pi, high=np.pi) assert_arrays_close(vtf.SO2.from_radians(radians).as_radians(), radians) @@ -39,7 +39,7 @@ def test_so2_from_to_radians_bijective(_random_module): @given(_random_module=st.random_module()) def test_so3_xyzw_bijective(_random_module): """Check that we can convert between xyzw and wxyz quaternions.""" - T = sample_transform(vtf.SO3, (), dtype=onp.float64) + T = sample_transform(vtf.SO3, (), dtype=np.float64) assert_transforms_close(T, vtf.SO3.from_quaternion_xyzw(T.as_quaternion_xyzw())) @@ -47,14 +47,14 @@ def test_so3_xyzw_bijective(_random_module): @given(_random_module=st.random_module()) def test_so3_rpy_bijective(_random_module): """Check that we can convert between quaternions and Euler angles.""" - T = sample_transform(vtf.SO3, (), dtype=onp.float64) + T = sample_transform(vtf.SO3, (), dtype=np.float64) rpy = T.as_rpy_radians() assert_transforms_close(T, vtf.SO3.from_rpy_radians(rpy.roll, rpy.pitch, rpy.yaw)) @general_group_test def test_log_exp_bijective( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: npt.DTypeLike ): """Check 1-to-1 mapping for log <=> exp operations.""" transform = sample_transform(Group, batch_axes, dtype) @@ -71,7 +71,7 @@ def test_log_exp_bijective( @general_group_test def test_inverse_bijective( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: npt.DTypeLike ): """Check inverse of inverse.""" transform = sample_transform(Group, batch_axes, dtype) @@ -80,7 +80,7 @@ def test_inverse_bijective( @general_group_test def test_matrix_bijective( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: npt.DTypeLike ): """Check that we can convert to and from matrices.""" transform = sample_transform(Group, batch_axes, dtype) @@ -89,22 +89,22 @@ def test_matrix_bijective( @general_group_test def test_adjoint( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: npt.DTypeLike ): """Check adjoint definition.""" transform = sample_transform(Group, batch_axes, dtype) - omega = onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype) + omega = np.random.randn(*batch_axes, Group.tangent_dim).astype(dtype) assert (transform @ Group.exp(omega)).parameters().dtype == dtype assert_transforms_close( transform @ Group.exp(omega), - Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) + Group.exp(np.einsum("...ij,...j->...i", transform.adjoint(), omega)) @ transform, ) @general_group_test def test_repr( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: npt.DTypeLike ): """Smoke test for __repr__ implementations.""" transform = sample_transform(Group, batch_axes, dtype) @@ -113,17 +113,17 @@ def test_repr( @general_group_test def test_apply( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: npt.DTypeLike ): """Check group action interfaces.""" T_w_b = sample_transform(Group, batch_axes, dtype) - p_b = onp.random.randn(*batch_axes, Group.space_dim).astype(dtype) + p_b = np.random.randn(*batch_axes, Group.space_dim).astype(dtype) if Group.matrix_dim == Group.space_dim: assert_arrays_close( T_w_b @ p_b, T_w_b.apply(p_b), - onp.einsum("...ij,...j->...i", T_w_b.as_matrix(), p_b), + np.einsum("...ij,...j->...i", T_w_b.as_matrix(), p_b), ) else: # Homogeneous coordinates. @@ -131,36 +131,34 @@ def test_apply( assert_arrays_close( T_w_b @ p_b, T_w_b.apply(p_b), - onp.einsum( + np.einsum( "...ij,...j->...i", T_w_b.as_matrix(), - onp.concatenate([p_b, onp.ones_like(p_b[..., :1])], axis=-1), + np.concatenate([p_b, np.ones_like(p_b[..., :1])], axis=-1), )[..., :-1], ) @general_group_test def test_multiply( - Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike + Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], dtype: npt.DTypeLike ): """Check multiply interfaces.""" T_w_b = sample_transform(Group, batch_axes, dtype) T_b_a = sample_transform(Group, batch_axes, dtype) assert_arrays_close( - onp.einsum( - "...ij,...jk->...ik", T_w_b.as_matrix(), T_w_b.inverse().as_matrix() - ), - onp.broadcast_to( - onp.eye(Group.matrix_dim, dtype=dtype), + np.einsum("...ij,...jk->...ik", T_w_b.as_matrix(), T_w_b.inverse().as_matrix()), + np.broadcast_to( + np.eye(Group.matrix_dim, dtype=dtype), (*batch_axes, Group.matrix_dim, Group.matrix_dim), ), ) assert_arrays_close( - onp.einsum( - "...ij,...jk->...ik", T_w_b.as_matrix(), onp.linalg.inv(T_w_b.as_matrix()) + np.einsum( + "...ij,...jk->...ik", T_w_b.as_matrix(), np.linalg.inv(T_w_b.as_matrix()) ), - onp.broadcast_to( - onp.eye(Group.matrix_dim, dtype=dtype), + np.broadcast_to( + np.eye(Group.matrix_dim, dtype=dtype), (*batch_axes, Group.matrix_dim, Group.matrix_dim), ), ) diff --git a/tests/test_transforms_ops.py b/tests/test_transforms_ops.py index 6321b8896..e65adf9c5 100644 --- a/tests/test_transforms_ops.py +++ b/tests/test_transforms_ops.py @@ -2,7 +2,7 @@ from typing import Tuple, Type -import numpy as onp +import numpy as np import numpy.typing as onpt from utils import ( assert_arrays_close, @@ -64,10 +64,10 @@ def test_adjoint( ): """Check adjoint definition.""" transform = sample_transform(Group, batch_axes, dtype) - omega = onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) + omega = np.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) assert_transforms_close( transform @ Group.exp(omega), - Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) + Group.exp(np.einsum("...ij,...j->...i", transform.adjoint(), omega)) @ transform, ) @@ -87,13 +87,13 @@ def test_apply( ): """Check group action interfaces.""" T_w_b = sample_transform(Group, batch_axes, dtype) - p_b = onp.random.randn(*batch_axes, Group.space_dim).astype(dtype) + p_b = np.random.randn(*batch_axes, Group.space_dim).astype(dtype) if Group.matrix_dim == Group.space_dim: assert_arrays_close( T_w_b @ p_b, T_w_b.apply(p_b), - onp.einsum("...ij,...j->...i", T_w_b.as_matrix(), p_b), + np.einsum("...ij,...j->...i", T_w_b.as_matrix(), p_b), ) else: # Homogeneous coordinates. @@ -101,10 +101,10 @@ def test_apply( assert_arrays_close( T_w_b @ p_b, T_w_b.apply(p_b), - onp.einsum( + np.einsum( "...ij,...j->...i", T_w_b.as_matrix(), - onp.concatenate([p_b, onp.ones_like(p_b[..., :1])], axis=-1), + np.concatenate([p_b, np.ones_like(p_b[..., :1])], axis=-1), )[..., :-1], ) @@ -117,11 +117,11 @@ def test_multiply( T_w_b = sample_transform(Group, batch_axes, dtype) T_b_a = sample_transform(Group, batch_axes, dtype) assert_arrays_close( - onp.einsum( - "...ij,...jk->...ik", T_w_b.as_matrix(), onp.linalg.inv(T_w_b.as_matrix()) + np.einsum( + "...ij,...jk->...ik", T_w_b.as_matrix(), np.linalg.inv(T_w_b.as_matrix()) ), - onp.broadcast_to( - onp.eye(Group.matrix_dim, dtype=dtype), + np.broadcast_to( + np.eye(Group.matrix_dim, dtype=dtype), (*batch_axes, Group.matrix_dim, Group.matrix_dim), ), ) diff --git a/tests/utils.py b/tests/utils.py index ed7346e1e..d06f0b07d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,19 +2,18 @@ import random from typing import Any, Callable, Tuple, Type, TypeVar, Union, cast -import numpy as onp -import numpy.typing as onpt +import numpy as np +import numpy.typing as npt import pytest +import viser.transforms as vtf from hypothesis import given, settings from hypothesis import strategies as st -import viser.transforms as vtf - T = TypeVar("T", bound=vtf.MatrixLieGroup) def sample_transform( - Group: Type[T], batch_axes: Tuple[int, ...], dtype: onpt.DTypeLike + Group: Type[T], batch_axes: Tuple[int, ...], dtype: npt.DTypeLike ) -> T: """Sample a random transform from a group.""" seed = random.getrandbits(32) @@ -25,7 +24,7 @@ def sample_transform( return cast( T, Group.sample_uniform( - onp.random.default_rng(seed), batch_axes=batch_axes, dtype=dtype + np.random.default_rng(seed), batch_axes=batch_axes, dtype=dtype ), ) elif strategy == 1: @@ -33,7 +32,7 @@ def sample_transform( return cast( T, Group.exp( - onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) + np.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) ), ) elif strategy == 2: @@ -41,7 +40,7 @@ def sample_transform( return cast( T, Group.exp( - onp.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) + np.random.randn(*batch_axes, Group.tangent_dim).astype(dtype=dtype) * 1e-7 ), ) @@ -50,16 +49,16 @@ def sample_transform( def general_group_test( - f: Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], onpt.DTypeLike], None], + f: Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], npt.DTypeLike], None], max_examples: int = 15, -) -> Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], onpt.DTypeLike, Any], None]: +) -> Callable[[Type[vtf.MatrixLieGroup], Tuple[int, ...], npt.DTypeLike, Any], None]: """Decorator for defining tests that run on all group types.""" # Disregard unused argument. def f_wrapped( Group: Type[vtf.MatrixLieGroup], batch_axes: Tuple[int, ...], - dtype: onpt.DTypeLike, + dtype: npt.DTypeLike, _random_module, ) -> None: f(Group, batch_axes, dtype) @@ -94,7 +93,7 @@ def f_wrapped( # Parametrize tests with each group type. f_wrapped = pytest.mark.parametrize( "dtype", - [onp.float32, onp.float64], + [np.float32, np.float64], )(f_wrapped) return f_wrapped @@ -108,32 +107,31 @@ def assert_transforms_close(a: vtf.MatrixLieGroup, b: vtf.MatrixLieGroup): assert_arrays_close(a.as_matrix(), b.as_matrix()) # Flip signs for quaternions. - # We use `jnp.asarray` here in case inputs are onp arrays and don't support `.at()`. p1 = a.parameters().copy() p2 = b.parameters().copy() if isinstance(a, vtf.SO3): - p1 = p1 * onp.sign(onp.sum(p1, axis=-1, keepdims=True)) - p2 = p2 * onp.sign(onp.sum(p2, axis=-1, keepdims=True)) + p1 = p1 * np.sign(np.sum(p1, axis=-1, keepdims=True)) + p2 = p2 * np.sign(np.sum(p2, axis=-1, keepdims=True)) elif isinstance(a, vtf.SE3): - p1[..., :4] *= onp.sign(onp.sum(p1[..., :4], axis=-1, keepdims=True)) - p2[..., :4] *= onp.sign(onp.sum(p2[..., :4], axis=-1, keepdims=True)) + p1[..., :4] *= np.sign(np.sum(p1[..., :4], axis=-1, keepdims=True)) + p2[..., :4] *= np.sign(np.sum(p2[..., :4], axis=-1, keepdims=True)) # Make sure parameters are equal. assert_arrays_close(p1, p2) -def assert_arrays_close(*arrays: Union[onpt.NDArray[onp.float64], float]): +def assert_arrays_close(*arrays: Union[npt.NDArray[np.float64], float]): """Make sure two arrays are close. (and not NaN)""" for array1, array2 in zip(arrays[:-1], arrays[1:]): - assert onp.asarray(array1).dtype == onp.asarray(array2).dtype + assert np.asarray(array1).dtype == np.asarray(array2).dtype - if isinstance(array1, (float, int)) or array1.dtype == onp.float64: + if isinstance(array1, (float, int)) or array1.dtype == np.float64: rtol = 1e-7 atol = 1e-7 else: rtol = 1e-3 atol = 1e-3 - onp.testing.assert_allclose(array1, array2, rtol=rtol, atol=atol) - assert not onp.any(onp.isnan(array1)) - assert not onp.any(onp.isnan(array2)) + np.testing.assert_allclose(array1, array2, rtol=rtol, atol=atol) + assert not np.any(np.isnan(array1)) + assert not np.any(np.isnan(array2)) From 6aaf7a9568a6a31dd592d18cb2fd7c2fcb80a958 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 03:43:29 -0700 Subject: [PATCH 14/15] ruff --- tests/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index d06f0b07d..9177d2943 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,10 +5,11 @@ import numpy as np import numpy.typing as npt import pytest -import viser.transforms as vtf from hypothesis import given, settings from hypothesis import strategies as st +import viser.transforms as vtf + T = TypeVar("T", bound=vtf.MatrixLieGroup) From b2cf2e48df993503e94ee7c0a0ee13c8cf9a8051 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 24 Sep 2024 03:49:00 -0700 Subject: [PATCH 15/15] pyright --- src/viser/_gui_api.py | 65 +++++++++++++++------------------------ src/viser/_gui_handles.py | 14 ++++----- 2 files changed, 31 insertions(+), 48 deletions(-) diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 461b52f07..a459efa83 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -483,14 +483,10 @@ def add_folder( ) return GuiFolderHandle( _GuiHandleState( + folder_container_id, self, None, props=props, - update_timestamp=0.0, - update_cb=[], - is_button=False, - sync_cb=None, - id=folder_container_id, parent_container_id=self._get_container_id(), ) ) @@ -541,27 +537,26 @@ def add_tab_group( tab_group_id = _make_unique_id() order = _apply_default_order(order) - self._websock_interface.queue_message( - _messages.GuiTabGroupMessage( - id=tab_group_id, - container_id=self._get_container_id(), - props=_messages.GuiTabGroupProps( - order=order, - _tab_labels=(), - visible=visible, - _tab_icons_html=(), - _tab_container_ids=(), - ), - ) + message = _messages.GuiTabGroupMessage( + id=tab_group_id, + container_id=self._get_container_id(), + props=_messages.GuiTabGroupProps( + order=order, + _tab_labels=(), + visible=visible, + _tab_icons_html=(), + _tab_container_ids=(), + ), ) + self._websock_interface.queue_message(message) return GuiTabGroupHandle( - _tab_group_id=tab_group_id, - _labels=[], - _icons_html=[], - _tabs=[], - _gui_api=self, - _parent_container_id=self._get_container_id(), - _order=order, + _GuiHandleState( + message.id, + self, + value=None, + props=message.props, + parent_container_id=message.container_id, + ) ) def add_markdown( @@ -595,14 +590,10 @@ def add_markdown( handle = GuiMarkdownHandle( _GuiHandleState( + message.id, self, None, props=message.props, - update_timestamp=0.0, - update_cb=[], - is_button=False, - sync_cb=None, - id=message.id, parent_container_id=message.container_id, ), _content=content, @@ -678,14 +669,10 @@ def add_plotly( handle = GuiPlotlyHandle( _GuiHandleState( + message.id, self, - None, + value=None, props=message.props, - update_timestamp=0.0, - update_cb=[], - is_button=False, - sync_cb=None, - id=message.id, parent_container_id=message.container_id, ), _figure=figure, @@ -1243,14 +1230,10 @@ def add_progress_bar( self._websock_interface.queue_message(message) handle = GuiProgressBarHandle( _GuiHandleState( + message.id, self, - value, + value=value, props=message.props, - update_timestamp=0.0, - update_cb=[], - is_button=False, - sync_cb=None, - id=message.id, parent_container_id=message.container_id, ), ) diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index dac756918..5138b48fc 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -87,25 +87,25 @@ class GuiPropsProtocol(Protocol): class _GuiHandleState(Generic[T]): """Internal API for GUI elements.""" + id: str gui_api: GuiApi value: T props: GuiPropsProtocol - update_timestamp: float - parent_container_id: str """Container that this GUI input was placed into.""" - update_cb: list[Callable[[GuiEvent], None]] + update_timestamp: float = 0.0 + update_cb: list[Callable[[GuiEvent], None]] = dataclasses.field( + default_factory=list + ) """Registered functions to call when this input is updated.""" - is_button: bool + is_button: bool = False """Indicates a button element, which requires special handling.""" - sync_cb: Callable[[ClientId, dict[str, Any]], None] | None + sync_cb: Callable[[ClientId, dict[str, Any]], None] | None = None """Callback for synchronizing inputs across clients.""" - id: str - class _OverridableGuiPropApi: """Mixin that allows reading/assigning properties defined in each scene node message."""