From 6e0000bc2c8bc83249381e9a96fd372bb206d312 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Mon, 8 Jan 2024 16:48:59 -0800 Subject: [PATCH] Improved point cloud shader (#158) * Improved point cloud shader * Resolve eslint --- examples/07_record3d_visualizer.py | 1 + src/viser/_message_api.py | 11 +++ src/viser/_messages.py | 3 +- src/viser/client/src/ThreeAssets.tsx | 89 ++++++++++++++++++++- src/viser/client/src/WebsocketInterface.tsx | 64 +++++---------- src/viser/client/src/WebsocketMessages.tsx | 1 + 6 files changed, 121 insertions(+), 48 deletions(-) diff --git a/examples/07_record3d_visualizer.py b/examples/07_record3d_visualizer.py index 9dcb1b6a9..a4e83ec0f 100644 --- a/examples/07_record3d_visualizer.py +++ b/examples/07_record3d_visualizer.py @@ -105,6 +105,7 @@ def _(_) -> None: points=position, colors=color, point_size=0.01, + point_shape="rounded", ) # Place the frustum. diff --git a/src/viser/_message_api.py b/src/viser/_message_api.py index 2ee172ded..04884da91 100644 --- a/src/viser/_message_api.py +++ b/src/viser/_message_api.py @@ -661,6 +661,9 @@ def add_point_cloud( points: onp.ndarray, colors: onp.ndarray | Tuple[float, float, float], point_size: float = 0.1, + point_shape: Literal[ + "square", "diamond", "circle", "rounded", "sparkle" + ] = "square", wxyz: Tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, @@ -672,6 +675,7 @@ def add_point_cloud( points: Location of points. Should have shape (N, 3). colors: Colors of points. Should have shape (N, 3) or (3,). point_size: Size of each point. + point_shape: Shape to draw each point. wxyz: Quaternion rotation to parent frame from local frame (R_pl). position: Translation to parent frame from local frame (t_pl). visible: Whether or not this scene node is initially visible. @@ -697,6 +701,13 @@ def add_point_cloud( points=points.astype(onp.float32), colors=colors_cast, point_size=point_size, + point_ball_norm={ + "square": 0.0, + "diamond": 1.0, + "circle": 2.0, + "rounded": 3.0, + "sparkle": 0.6, + }[point_shape], ) ) return PointCloudHandle._make(self, name, wxyz, position, visible) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 24819e3cb..2a480be60 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -158,7 +158,8 @@ class PointCloudMessage(Message): name: str points: onpt.NDArray[onp.float32] colors: onpt.NDArray[onp.uint8] - point_size: float = 0.1 + point_size: float + point_ball_norm: float def __post_init__(self): # Check shapes. diff --git a/src/viser/client/src/ThreeAssets.tsx b/src/viser/client/src/ThreeAssets.tsx index 726645f79..efa5348a4 100644 --- a/src/viser/client/src/ThreeAssets.tsx +++ b/src/viser/client/src/ThreeAssets.tsx @@ -1,5 +1,5 @@ -import { Instance, Instances } from "@react-three/drei"; -import { createPortal, useFrame } from "@react-three/fiber"; +import { Instance, Instances, shaderMaterial } from "@react-three/drei"; +import { createPortal, useFrame, useThree } from "@react-three/fiber"; import { Outlines } from "./Outlines"; import React from "react"; import * as THREE from "three"; @@ -47,6 +47,91 @@ type AllPossibleThreeJSMaterials = const originGeom = new THREE.SphereGeometry(1.0); const originMaterial = new THREE.MeshBasicMaterial({ color: 0xecec00 }); +const PointCloudMaterial = /* @__PURE__ */ shaderMaterial( + { scale: 1.0, point_ball_norm: 0.0 }, + ` + varying vec3 vPosition; + varying vec3 vColor; // in the vertex shader + uniform float scale; + + void main() { + vPosition = position; + vColor = color; + vec4 world_pos = modelViewMatrix * vec4(position, 1.0); + gl_Position = projectionMatrix * world_pos; + gl_PointSize = (scale / -world_pos.z); + } + `, + `varying vec3 vPosition; + varying vec3 vColor; + uniform float point_ball_norm; + + void main() { + if (!isinf(point_ball_norm)) { + float r = pow( + pow(abs(gl_PointCoord.x - 0.5), point_ball_norm) + + pow(abs(gl_PointCoord.y - 0.5), point_ball_norm), + 1.0 / point_ball_norm); + if (r > 0.5) discard; + } + gl_FragColor = vec4(vColor, 1.0); + } + `, +); + +export const PointCloud = React.forwardRef< + THREE.Points, + { + pointSize: number; + /** We visualize each point as a 2D ball, which is defined by some norm. */ + pointBallNorm: number; + points: Float32Array; + colors: Float32Array; + } +>(function PointCloud(props, ref) { + const { gl, camera } = useThree(); + + const geometry = new THREE.BufferGeometry(); + geometry.setAttribute( + "position", + new THREE.Float32BufferAttribute(props.points, 3), + ); + geometry.computeBoundingSphere(); + geometry.setAttribute( + "color", + new THREE.Float32BufferAttribute(props.colors, 3), + ); + + const [material] = React.useState( + () => new PointCloudMaterial({ vertexColors: true }), + ); + material.uniforms.scale.value = 10.0; + material.uniforms.point_ball_norm.value = props.pointBallNorm; + + React.useEffect(() => { + return () => { + material.dispose(); + geometry.dispose(); + }; + }); + + const rendererSize = new THREE.Vector2(); + useFrame(() => { + // Match point scale to behavior of THREE.PointsMaterial(). + if (material === undefined) return; + // point px height / actual height = point meters height / frustum meters height + // frustum meters height = math.tan(fov / 2.0) * z + // point px height = (point meters height / math.tan(fov / 2.0) * actual height) / z + material.uniforms.scale.value = + (props.pointSize / + Math.tan( + (((camera as THREE.PerspectiveCamera).fov / 180.0) * Math.PI) / 2.0, + )) * + gl.getSize(rendererSize).height; + }); + return ; +}); + /** Component for rendering the contents of GLB files. */ export const GlbAsset = React.forwardRef< THREE.Group, diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index 835f489bf..bd58f80f4 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -15,6 +15,7 @@ import { CoordinateFrame, GlbAsset, OutlinesIfHovered, + PointCloud, } from "./ThreeAssets"; import { FileDownloadPart, @@ -195,52 +196,25 @@ function useMessageHandler() { // Add a point cloud. case "PointCloudMessage": { - const geometry = new THREE.BufferGeometry(); - const pointCloudMaterial = new THREE.PointsMaterial({ - size: message.point_size, - vertexColors: true, - toneMapped: false, - }); - - // Reinterpret cast: uint8 buffer => float32 for positions. - geometry.setAttribute( - "position", - new THREE.Float32BufferAttribute( - new Float32Array( - message.points.buffer.slice( - message.points.byteOffset, - message.points.byteOffset + message.points.byteLength, - ), - ), - 3, - ), - ); - geometry.computeBoundingSphere(); - - // Wrap uint8 buffer for colors. Note that we need to set normalized=true. - geometry.setAttribute( - "color", - threeColorBufferFromUint8Buffer(message.colors), - ); - addSceneNodeMakeParents( - new SceneNode( - message.name, - (ref) => ( - - ), - () => { - // TODO: we can switch to the react-three-fiber , - // , etc components to avoid manual - // disposal. - geometry.dispose(); - pointCloudMaterial.dispose(); - }, - ), + new SceneNode(message.name, (ref) => ( + val / 255.0, + )} + /> + )), ); return; } diff --git a/src/viser/client/src/WebsocketMessages.tsx b/src/viser/client/src/WebsocketMessages.tsx index 910daa3e5..f9d22326a 100644 --- a/src/viser/client/src/WebsocketMessages.tsx +++ b/src/viser/client/src/WebsocketMessages.tsx @@ -128,6 +128,7 @@ export interface PointCloudMessage { points: Uint8Array; colors: Uint8Array; point_size: number; + point_ball_norm: number; } /** Mesh message. *