Skip to content

Commit

Permalink
Use float16 buffers for point clouds (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi authored Oct 14, 2024
1 parent bf31c28 commit 5f5977b
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ class PointCloudMessage(Message, tag="SceneNodeMessage"):

@dataclasses.dataclass
class PointCloudProps:
points: npt.NDArray[np.float32]
points: npt.NDArray[np.float16]
"""Location of points. Should have shape (N, 3). Synchronized automatically when assigned."""
colors: npt.NDArray[np.uint8]
"""Colors of points. Should have shape (N, 3) or (3,). Synchronized automatically when assigned."""
Expand All @@ -347,7 +347,7 @@ def __post_init__(self):
assert self.points.shape[-1] == 3

# Check dtypes.
assert self.points.dtype == np.float32
assert self.points.dtype == np.float16
assert self.colors.dtype == np.uint8


Expand Down
2 changes: 1 addition & 1 deletion src/viser/_scene_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def add_point_cloud(
message = _messages.PointCloudMessage(
name=name,
props=_messages.PointCloudProps(
points=points.astype(np.float32),
points=points.astype(np.float16),
colors=colors_cast,
point_size=point_size,
point_ball_norm={
Expand Down
2 changes: 2 additions & 0 deletions src/viser/_scene_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __setattr__(self, name: str, value: Any) -> None:
hint = self._prop_hints[name]
if hint == onpt.NDArray[np.float32]:
value = value.astype(np.float32)
elif hint == onpt.NDArray[np.float16]:
value = value.astype(np.float16)
elif hint == onpt.NDArray[np.uint8] and "color" in name:
value = colors_to_uint8(value)

Expand Down
2 changes: 1 addition & 1 deletion src/viser/client/src/SceneTree.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ function useObjectFactory(message: SceneNodeMessage | undefined): {
pointSize={message.props.point_size}
pointBallNorm={message.props.point_ball_norm}
points={
new Float32Array(
new Uint16Array( // (contains float16)
message.props.points.buffer.slice(
message.props.points.byteOffset,
message.props.points.byteOffset +
Expand Down
9 changes: 7 additions & 2 deletions src/viser/client/src/ThreeAssets.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ const originMaterial = new THREE.MeshBasicMaterial({ color: 0xecec00 });
const PointCloudMaterial = /* @__PURE__ */ shaderMaterial(
{ scale: 1.0, point_ball_norm: 0.0 },
`
precision mediump float;
varying vec3 vPosition;
varying vec3 vColor; // in the vertex shader
uniform float scale;
Expand Down Expand Up @@ -94,14 +96,17 @@ export const PointCloud = React.forwardRef<
pointSize: number;
/** We visualize each point as a 2D ball, which is defined by some norm. */
pointBallNorm: number;
points: Float32Array;
points: Uint16Array; // Contains float16.
colors: Uint8Array;
}
>(function PointCloud(props, ref) {
const getThreeState = useThree((state) => state.get);

const geometry = new THREE.BufferGeometry();
geometry.setAttribute("position", new THREE.BufferAttribute(props.points, 3));
geometry.setAttribute(
"position",
new THREE.Float16BufferAttribute(props.points, 3),
);
geometry.computeBoundingSphere();
geometry.setAttribute(
"color",
Expand Down

0 comments on commit 5f5977b

Please sign in to comment.