Skip to content

Commit

Permalink
Improved point cloud shader (nerfstudio-project#158)
Browse files Browse the repository at this point in the history
* Improved point cloud shader

* Resolve eslint
  • Loading branch information
brentyi authored Jan 9, 2024
1 parent 128d395 commit 2de1289
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 48 deletions.
1 change: 1 addition & 0 deletions examples/07_record3d_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _(_) -> None:
points=position,
colors=color,
point_size=0.01,
point_shape="rounded",
)

# Place the frustum.
Expand Down
11 changes: 11 additions & 0 deletions src/viser/_message_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,9 @@ def add_point_cloud(
points: onp.ndarray,
colors: onp.ndarray | Tuple[float, float, float],
point_size: float = 0.1,
point_shape: Literal[
"square", "diamond", "circle", "rounded", "sparkle"
] = "square",
wxyz: Tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0),
position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0),
visible: bool = True,
Expand All @@ -672,6 +675,7 @@ def add_point_cloud(
points: Location of points. Should have shape (N, 3).
colors: Colors of points. Should have shape (N, 3) or (3,).
point_size: Size of each point.
point_shape: Shape to draw each point.
wxyz: Quaternion rotation to parent frame from local frame (R_pl).
position: Translation to parent frame from local frame (t_pl).
visible: Whether or not this scene node is initially visible.
Expand All @@ -697,6 +701,13 @@ def add_point_cloud(
points=points.astype(onp.float32),
colors=colors_cast,
point_size=point_size,
point_ball_norm={
"square": 0.0,
"diamond": 1.0,
"circle": 2.0,
"rounded": 3.0,
"sparkle": 0.6,
}[point_shape],
)
)
return PointCloudHandle._make(self, name, wxyz, position, visible)
Expand Down
3 changes: 2 additions & 1 deletion src/viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ class PointCloudMessage(Message):
name: str
points: onpt.NDArray[onp.float32]
colors: onpt.NDArray[onp.uint8]
point_size: float = 0.1
point_size: float
point_ball_norm: float

def __post_init__(self):
# Check shapes.
Expand Down
89 changes: 87 additions & 2 deletions src/viser/client/src/ThreeAssets.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Instance, Instances } from "@react-three/drei";
import { createPortal, useFrame } from "@react-three/fiber";
import { Instance, Instances, shaderMaterial } from "@react-three/drei";
import { createPortal, useFrame, useThree } from "@react-three/fiber";
import { Outlines } from "./Outlines";
import React from "react";
import * as THREE from "three";
Expand Down Expand Up @@ -47,6 +47,91 @@ type AllPossibleThreeJSMaterials =
const originGeom = new THREE.SphereGeometry(1.0);
const originMaterial = new THREE.MeshBasicMaterial({ color: 0xecec00 });

const PointCloudMaterial = /* @__PURE__ */ shaderMaterial(
{ scale: 1.0, point_ball_norm: 0.0 },
`
varying vec3 vPosition;
varying vec3 vColor; // in the vertex shader
uniform float scale;
void main() {
vPosition = position;
vColor = color;
vec4 world_pos = modelViewMatrix * vec4(position, 1.0);
gl_Position = projectionMatrix * world_pos;
gl_PointSize = (scale / -world_pos.z);
}
`,
`varying vec3 vPosition;
varying vec3 vColor;
uniform float point_ball_norm;
void main() {
if (!isinf(point_ball_norm)) {
float r = pow(
pow(abs(gl_PointCoord.x - 0.5), point_ball_norm)
+ pow(abs(gl_PointCoord.y - 0.5), point_ball_norm),
1.0 / point_ball_norm);
if (r > 0.5) discard;
}
gl_FragColor = vec4(vColor, 1.0);
}
`,
);

export const PointCloud = React.forwardRef<
THREE.Points,
{
pointSize: number;
/** We visualize each point as a 2D ball, which is defined by some norm. */
pointBallNorm: number;
points: Float32Array;
colors: Float32Array;
}
>(function PointCloud(props, ref) {
const { gl, camera } = useThree();

const geometry = new THREE.BufferGeometry();
geometry.setAttribute(
"position",
new THREE.Float32BufferAttribute(props.points, 3),
);
geometry.computeBoundingSphere();
geometry.setAttribute(
"color",
new THREE.Float32BufferAttribute(props.colors, 3),
);

const [material] = React.useState(
() => new PointCloudMaterial({ vertexColors: true }),
);
material.uniforms.scale.value = 10.0;
material.uniforms.point_ball_norm.value = props.pointBallNorm;

React.useEffect(() => {
return () => {
material.dispose();
geometry.dispose();
};
});

const rendererSize = new THREE.Vector2();
useFrame(() => {
// Match point scale to behavior of THREE.PointsMaterial().
if (material === undefined) return;
// point px height / actual height = point meters height / frustum meters height
// frustum meters height = math.tan(fov / 2.0) * z
// point px height = (point meters height / math.tan(fov / 2.0) * actual height) / z
material.uniforms.scale.value =
(props.pointSize /
Math.tan(
(((camera as THREE.PerspectiveCamera).fov / 180.0) * Math.PI) / 2.0,
)) *
gl.getSize(rendererSize).height;
});
return <points ref={ref} geometry={geometry} material={material} />;
});

/** Component for rendering the contents of GLB files. */
export const GlbAsset = React.forwardRef<
THREE.Group,
Expand Down
64 changes: 19 additions & 45 deletions src/viser/client/src/WebsocketInterface.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
CoordinateFrame,
GlbAsset,
OutlinesIfHovered,
PointCloud,
} from "./ThreeAssets";
import {
FileDownloadPart,
Expand Down Expand Up @@ -195,52 +196,25 @@ function useMessageHandler() {

// Add a point cloud.
case "PointCloudMessage": {
const geometry = new THREE.BufferGeometry();
const pointCloudMaterial = new THREE.PointsMaterial({
size: message.point_size,
vertexColors: true,
toneMapped: false,
});

// Reinterpret cast: uint8 buffer => float32 for positions.
geometry.setAttribute(
"position",
new THREE.Float32BufferAttribute(
new Float32Array(
message.points.buffer.slice(
message.points.byteOffset,
message.points.byteOffset + message.points.byteLength,
),
),
3,
),
);
geometry.computeBoundingSphere();

// Wrap uint8 buffer for colors. Note that we need to set normalized=true.
geometry.setAttribute(
"color",
threeColorBufferFromUint8Buffer(message.colors),
);

addSceneNodeMakeParents(
new SceneNode<THREE.Points>(
message.name,
(ref) => (
<points
ref={ref}
geometry={geometry}
material={pointCloudMaterial}
/>
),
() => {
// TODO: we can switch to the react-three-fiber <bufferGeometry />,
// <pointsMaterial />, etc components to avoid manual
// disposal.
geometry.dispose();
pointCloudMaterial.dispose();
},
),
new SceneNode<THREE.Points>(message.name, (ref) => (
<PointCloud
ref={ref}
pointSize={message.point_size}
pointBallNorm={message.point_ball_norm}
points={
new Float32Array(
message.points.buffer.slice(
message.points.byteOffset,
message.points.byteOffset + message.points.byteLength,
),
)
}
colors={new Float32Array(message.colors).map(
(val) => val / 255.0,
)}
/>
)),
);
return;
}
Expand Down
1 change: 1 addition & 0 deletions src/viser/client/src/WebsocketMessages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ export interface PointCloudMessage {
points: Uint8Array;
colors: Uint8Array;
point_size: number;
point_ball_norm: number;
}
/** Mesh message.
*
Expand Down

0 comments on commit 2de1289

Please sign in to comment.