diff --git a/src/viser/_message_api.py b/src/viser/_message_api.py index 47b37848b..fa785bc79 100644 --- a/src/viser/_message_api.py +++ b/src/viser/_message_api.py @@ -74,7 +74,7 @@ def _hex_from_hls(h: float, l: float, s: float) -> str: def _colors_to_uint8(colors: onp.ndarray) -> onpt.NDArray[onp.uint8]: """Convert intensity values to uint8. We assume the range [0,1] for floats, and - [0,255] for integers.""" + [0,255] for integers. Accepts any shape.""" if colors.dtype != onp.uint8: if onp.issubdtype(colors.dtype, onp.floating): colors = onp.clip(colors * 255.0, 0, 255).astype(onp.uint8) @@ -429,18 +429,39 @@ def add_point_cloud( self, name: str, points: onp.ndarray, - colors: onp.ndarray, + colors: onp.ndarray | Tuple[float, float, float], point_size: float = 0.1, wxyz: Tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> PointCloudHandle: - """Add a point cloud to the scene.""" + """Add a point cloud to the scene. + + Args: + name: Name of scene node. Determines location in kinematic tree. + points: Location of points. Should have shape (N, 3). + colors: Colors of points. Should have shape (N, 3) or (3,). + point_size: Size of each point. + wxyz: Quaternion rotation to parent frame from local frame (R_pl). + position: Translation from parent frame to local frame (t_pl). + visible: Whether or not this point cloud is initially visible. + """ + colors_cast = _colors_to_uint8(onp.asarray(colors)) + assert ( + len(points.shape) == 2 and points.shape[-1] == 3 + ), "Shape of points should be (N, 3)." + assert colors_cast.shape == points.shape or colors_cast.shape == ( + 3, + ), "Shape of colors should be (N, 3) or (3,)." + + if colors_cast.shape == (3,): + colors_cast = onp.tile(colors_cast[None, :], reps=(points.shape[0], 1)) + self._queue( _messages.PointCloudMessage( name=name, points=points.astype(onp.float32), - colors=_colors_to_uint8(colors), + colors=colors_cast, point_size=point_size, ) )