From 5f5977beeb8905867f9984f258054002c24edc35 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 13 Oct 2024 23:00:30 -0700 Subject: [PATCH] Use float16 buffers for point clouds (#296) --- src/viser/_messages.py | 4 ++-- src/viser/_scene_api.py | 2 +- src/viser/_scene_handles.py | 2 ++ src/viser/client/src/SceneTree.tsx | 2 +- src/viser/client/src/ThreeAssets.tsx | 9 +++++++-- 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 849f09bf..9946ab51 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -332,7 +332,7 @@ class PointCloudMessage(Message, tag="SceneNodeMessage"): @dataclasses.dataclass class PointCloudProps: - points: npt.NDArray[np.float32] + points: npt.NDArray[np.float16] """Location of points. Should have shape (N, 3). Synchronized automatically when assigned.""" colors: npt.NDArray[np.uint8] """Colors of points. Should have shape (N, 3) or (3,). Synchronized automatically when assigned.""" @@ -347,7 +347,7 @@ def __post_init__(self): assert self.points.shape[-1] == 3 # Check dtypes. - assert self.points.dtype == np.float32 + assert self.points.dtype == np.float16 assert self.colors.dtype == np.uint8 diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 88fb858b..e6fe8b26 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -965,7 +965,7 @@ def add_point_cloud( message = _messages.PointCloudMessage( name=name, props=_messages.PointCloudProps( - points=points.astype(np.float32), + points=points.astype(np.float16), colors=colors_cast, point_size=point_size, point_ball_norm={ diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index 5a0432a9..acc22df4 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -55,6 +55,8 @@ def __setattr__(self, name: str, value: Any) -> None: hint = self._prop_hints[name] if hint == onpt.NDArray[np.float32]: value = value.astype(np.float32) + elif hint == onpt.NDArray[np.float16]: + value = value.astype(np.float16) elif hint == onpt.NDArray[np.uint8] and "color" in name: value = colors_to_uint8(value) diff --git a/src/viser/client/src/SceneTree.tsx b/src/viser/client/src/SceneTree.tsx index 51adcf53..bb5d9fec 100644 --- a/src/viser/client/src/SceneTree.tsx +++ b/src/viser/client/src/SceneTree.tsx @@ -252,7 +252,7 @@ function useObjectFactory(message: SceneNodeMessage | undefined): { pointSize={message.props.point_size} pointBallNorm={message.props.point_ball_norm} points={ - new Float32Array( + new Uint16Array( // (contains float16) message.props.points.buffer.slice( message.props.points.byteOffset, message.props.points.byteOffset + diff --git a/src/viser/client/src/ThreeAssets.tsx b/src/viser/client/src/ThreeAssets.tsx index 973c7d61..36b49bae 100644 --- a/src/viser/client/src/ThreeAssets.tsx +++ b/src/viser/client/src/ThreeAssets.tsx @@ -59,6 +59,8 @@ const originMaterial = new THREE.MeshBasicMaterial({ color: 0xecec00 }); const PointCloudMaterial = /* @__PURE__ */ shaderMaterial( { scale: 1.0, point_ball_norm: 0.0 }, ` + precision mediump float; + varying vec3 vPosition; varying vec3 vColor; // in the vertex shader uniform float scale; @@ -94,14 +96,17 @@ export const PointCloud = React.forwardRef< pointSize: number; /** We visualize each point as a 2D ball, which is defined by some norm. */ pointBallNorm: number; - points: Float32Array; + points: Uint16Array; // Contains float16. colors: Uint8Array; } >(function PointCloud(props, ref) { const getThreeState = useThree((state) => state.get); const geometry = new THREE.BufferGeometry(); - geometry.setAttribute("position", new THREE.BufferAttribute(props.points, 3)); + geometry.setAttribute( + "position", + new THREE.Float16BufferAttribute(props.points, 3), + ); geometry.computeBoundingSphere(); geometry.setAttribute( "color",