From c449702caa067c740b3a8bdf0201d95292d14c42 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 29 Nov 2024 13:33:49 -0800 Subject: [PATCH] Fix Gaussian rendering from virtual cameras (#344) * Fix virtual camera rendering for Gaussians * Nits * Optimization --- src/viser/client/src/App.tsx | 2 +- src/viser/client/src/MessageHandler.tsx | 23 ++ .../client/src/Splatting/GaussianSplats.tsx | 265 ++++++++++++------ 3 files changed, 198 insertions(+), 92 deletions(-) diff --git a/src/viser/client/src/App.tsx b/src/viser/client/src/App.tsx index 717de5cb8..6e9806b09 100644 --- a/src/viser/client/src/App.tsx +++ b/src/viser/client/src/App.tsx @@ -466,12 +466,12 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) { }} > {inView ? null : } - {children} {memoizedCameraControls} + {children} diff --git a/src/viser/client/src/MessageHandler.tsx b/src/viser/client/src/MessageHandler.tsx index 8a26ca877..d4929677d 100644 --- a/src/viser/client/src/MessageHandler.tsx +++ b/src/viser/client/src/MessageHandler.tsx @@ -19,6 +19,7 @@ import { Progress } from "@mantine/core"; import { IconCheck } from "@tabler/icons-react"; import { computeT_threeworld_world } from "./WorldTransformUtils"; import { rootNodeTemplate } from "./SceneTreeState"; +import { GaussianSplatsContext } from "./Splatting/GaussianSplats"; /** Returns a handler for all incoming messages. */ function useMessageHandler() { @@ -521,6 +522,7 @@ export function FrameSynchronizedMessageHandler() { const handleMessage = useMessageHandler(); const viewer = useContext(ViewerContext)!; const messageQueueRef = viewer.messageQueueRef; + const splatContext = React.useContext(GaussianSplatsContext)!; useFrame( () => { @@ -567,6 +569,21 @@ export function FrameSynchronizedMessageHandler() { ), ); + // Update splatting camera if needed. + // We'll back up the current sorted indices, and restore them after rendering. + const splatMeshProps = splatContext.meshPropsRef.current; + const sortedIndicesOrig = + splatMeshProps !== null + ? splatMeshProps.sortedIndexAttribute.array.slice() + : null; + if (splatContext.updateCamera.current !== null) + splatContext.updateCamera.current!( + camera, + targetWidth, + targetHeight, + true, + ); + // Note: We don't need to add the camera to the scene for rendering // The renderer.render() function uses the camera directly // Create a new renderer @@ -583,6 +600,12 @@ export function FrameSynchronizedMessageHandler() { // Render the scene. renderer.render(viewer.sceneRef.current!, camera); + // Restore splatting indices. + if (sortedIndicesOrig !== null && splatMeshProps !== null) { + splatMeshProps.sortedIndexAttribute.array = sortedIndicesOrig; + splatMeshProps.sortedIndexAttribute.needsUpdate = true; + } + // Get the rendered image. viewer.getRenderRequestState.current = "in_progress"; renderer.domElement.toBlob(async (blob) => { diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index 2e77374eb..ceabf35d5 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -22,6 +22,8 @@ * between multiple splat objects. */ +import MakeSorterModulePromise from "./WasmSorter/Sorter.mjs"; + import React from "react"; import * as THREE from "three"; import SplatSortWorker from "./SplatSortWorker?worker"; @@ -64,9 +66,22 @@ function useGaussianSplatStore() { })), )[0]; } -const GaussianSplatsContext = React.createContext | null>(null); + +export const GaussianSplatsContext = React.createContext<{ + useGaussianSplatStore: ReturnType; + updateCamera: React.MutableRefObject< + | null + | (( + camera: THREE.PerspectiveCamera, + width: number, + height: number, + blockingSort: boolean, + ) => void) + >; + meshPropsRef: React.MutableRefObject | null>; +} | null>(null); /**Provider for creating splat rendering context.*/ export function SplatRenderContext({ @@ -75,9 +90,18 @@ export function SplatRenderContext({ children: React.ReactNode; }) { const store = useGaussianSplatStore(); + const numGroups = Object.keys( + store((state) => state.groupBufferFromId), + ).length; return ( - - + + {numGroups > 0 ? : null} {children} ); @@ -251,9 +275,15 @@ export const SplatObject = React.forwardRef< } >(function SplatObject({ buffer }, ref) { const splatContext = React.useContext(GaussianSplatsContext)!; - const setBuffer = splatContext((state) => state.setBuffer); - const removeBuffer = splatContext((state) => state.removeBuffer); - const nodeRefFromId = splatContext((state) => state.nodeRefFromId); + const setBuffer = splatContext.useGaussianSplatStore( + (state) => state.setBuffer, + ); + const removeBuffer = splatContext.useGaussianSplatStore( + (state) => state.removeBuffer, + ); + const nodeRefFromId = splatContext.useGaussianSplatStore( + (state) => state.nodeRefFromId, + ); const name = React.useMemo(() => uuidv4(), [buffer]); const [obj, setRef] = React.useState(null); @@ -281,8 +311,12 @@ export const SplatObject = React.forwardRef< /** External interface. Component should be added to the root of canvas. */ function SplatRenderer() { const splatContext = React.useContext(GaussianSplatsContext)!; - const groupBufferFromId = splatContext((state) => state.groupBufferFromId); - const nodeRefFromId = splatContext((state) => state.nodeRefFromId); + const groupBufferFromId = splatContext.useGaussianSplatStore( + (state) => state.groupBufferFromId, + ); + const nodeRefFromId = splatContext.useGaussianSplatStore( + (state) => state.nodeRefFromId, + ); // Consolidate Gaussian groups into a single buffer. const merged = mergeGaussianGroups(groupBufferFromId); @@ -290,6 +324,7 @@ function SplatRenderer() { merged.gaussianBuffer, merged.numGroups, ); + splatContext.meshPropsRef.current = meshProps; // Create sorting worker. const sortWorker = new SplatSortWorker(); @@ -310,7 +345,6 @@ function SplatRenderer() { function postToWorker(message: SorterWorkerIncoming) { sortWorker.postMessage(message); } - postToWorker({ setBuffer: merged.gaussianBuffer, setGroupIndices: merged.groupIndices, @@ -337,96 +371,145 @@ function SplatRenderer() { .slice() .fill(0); const prevVisibles: boolean[] = []; - useFrame((state, delta) => { - const mesh = meshRef.current; - if (mesh === null || sortWorker === null) return; - // Update camera parameter uniforms. - const dpr = state.viewport.dpr; - const fovY = - ((state.camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0; - const fovX = 2 * Math.atan(Math.tan(fovY / 2) * state.viewport.aspect); - const fy = (dpr * state.size.height) / (2 * Math.tan(fovY / 2)); - const fx = (dpr * state.size.width) / (2 * Math.tan(fovX / 2)); + // Make local sorter. This will be used for blocking sorts, eg for rendering + // from virtual cameras. + const SorterRef = React.useRef(null); + React.useEffect(() => { + (async () => { + SorterRef.current = new (await MakeSorterModulePromise()).Sorter( + merged.gaussianBuffer, + merged.groupIndices, + ); + })(); + }, [merged.gaussianBuffer, merged.groupIndices]); + + const updateCamera = React.useCallback( + function updateCamera( + camera: THREE.PerspectiveCamera, + width: number, + height: number, + blockingSort: boolean, + ) { + // Update camera parameter uniforms. + const fovY = ((camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0; + + const aspect = width / height; + const fovX = 2 * Math.atan(Math.tan(fovY / 2) * aspect); + const fy = height / (2 * Math.tan(fovY / 2)); + const fx = width / (2 * Math.tan(fovX / 2)); + + if (meshProps.material === undefined) return; + + const uniforms = meshProps.material.uniforms; + uniforms.focal.value = [fx, fy]; + uniforms.near.value = camera.near; + uniforms.far.value = camera.far; + uniforms.viewport.value = [width, height]; + + // Update group transforms. + camera.updateMatrixWorld(); + const T_camera_world = camera.matrixWorldInverse; + const groupVisibles: boolean[] = []; + let visibilitiesChanged = false; + for (const [groupIndex, name] of Object.keys( + groupBufferFromId, + ).entries()) { + const node = nodeRefFromId.current[name]; + if (node === undefined) continue; + tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld); + const colMajorElements = tmpT_camera_group.elements; + Tz_camera_groups.set( + [ + colMajorElements[2], + colMajorElements[6], + colMajorElements[10], + colMajorElements[14], + ], + groupIndex * 4, + ); + const rowMajorElements = tmpT_camera_group.transpose().elements; + meshProps.rowMajorT_camera_groups.set( + rowMajorElements.slice(0, 12), + groupIndex * 12, + ); + + // Determine visibility. If the parent has unmountWhenInvisible=true, the + // first frame after showing a hidden parent can have visible=true with + // an incorrect matrixWorld transform. There might be a better fix, but + // `prevVisible` is an easy workaround for this. + let visibleNow = node.visible && node.parent !== null; + if (visibleNow) { + node.traverseAncestors((ancestor) => { + visibleNow = visibleNow && ancestor.visible; + }); + } + groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true); + if (prevVisibles[groupIndex] !== visibleNow) { + prevVisibles[groupIndex] = visibleNow; + visibilitiesChanged = true; + } + } + + const groupsMovedWrtCam = !meshProps.rowMajorT_camera_groups.every( + (v, i) => v === prevRowMajorT_camera_groups[i], + ); + + if (groupsMovedWrtCam) { + // Gaussians need to be re-sorted. + if (blockingSort && SorterRef.current !== null) { + const sortedIndices = SorterRef.current.sort( + Tz_camera_groups, + ) as Uint32Array; + meshProps.sortedIndexAttribute.set(sortedIndices); + meshProps.sortedIndexAttribute.needsUpdate = true; + } else { + postToWorker({ + setTz_camera_groups: Tz_camera_groups, + }); + } + } + if (groupsMovedWrtCam || visibilitiesChanged) { + // If a group is not visible, we'll throw it off the screen with some Big + // Numbers. It's important that this only impacts the coordinates used + // for the shader and not for the sorter; that way when we "show" a group + // of Gaussians the correct rendering order is immediately available. + for (const [i, visible] of groupVisibles.entries()) { + if (!visible) { + meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10; + meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10; + meshProps.rowMajorT_camera_groups[i * 12 + 11] = 1e10; + } + } + prevRowMajorT_camera_groups.set(meshProps.rowMajorT_camera_groups); + meshProps.textureT_camera_groups.needsUpdate = true; + } + }, + [meshProps], + ); + splatContext.updateCamera.current = updateCamera; - if (meshProps.material === undefined) return; + useFrame((state, delta) => { + const mesh = meshRef.current; + if ( + mesh === null || + sortWorker === null || + meshProps.rowMajorT_camera_groups.length === 0 + ) + return; const uniforms = meshProps.material.uniforms; uniforms.transitionInState.value = Math.min( uniforms.transitionInState.value + delta * 2.0, 1.0, ); - uniforms.focal.value = [fx, fy]; - uniforms.near.value = state.camera.near; - uniforms.far.value = state.camera.far; - uniforms.viewport.value = [state.size.width * dpr, state.size.height * dpr]; - - // Update group transforms. - const T_camera_world = state.camera.matrixWorldInverse; - const groupVisibles: boolean[] = []; - let visibilitiesChanged = false; - for (const [groupIndex, name] of Object.keys(groupBufferFromId).entries()) { - const node = nodeRefFromId.current[name]; - if (node === undefined) continue; - tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld); - const colMajorElements = tmpT_camera_group.elements; - Tz_camera_groups.set( - [ - colMajorElements[2], - colMajorElements[6], - colMajorElements[10], - colMajorElements[14], - ], - groupIndex * 4, - ); - const rowMajorElements = tmpT_camera_group.transpose().elements; - meshProps.rowMajorT_camera_groups.set( - rowMajorElements.slice(0, 12), - groupIndex * 12, - ); - // Determine visibility. If the parent has unmountWhenInvisible=true, the - // first frame after showing a hidden parent can have visible=true with - // an incorrect matrixWorld transform. There might be a better fix, but - // `prevVisible` is an easy workaround for this. - let visibleNow = node.visible && node.parent !== null; - if (visibleNow) { - node.traverseAncestors((ancestor) => { - visibleNow = visibleNow && ancestor.visible; - }); - } - groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true); - if (prevVisibles[groupIndex] !== visibleNow) { - prevVisibles[groupIndex] = visibleNow; - visibilitiesChanged = true; - } - } - - const groupsMovedWrtCam = !meshProps.rowMajorT_camera_groups.every( - (v, i) => v === prevRowMajorT_camera_groups[i], + updateCamera( + state.camera as THREE.PerspectiveCamera, + state.viewport.dpr * state.size.width, + state.viewport.dpr * state.size.height, + false /* blockingSort */, ); - - if (groupsMovedWrtCam) { - // Gaussians need to be re-sorted. - postToWorker({ - setTz_camera_groups: Tz_camera_groups, - }); - } - if (groupsMovedWrtCam || visibilitiesChanged) { - // If a group is not visible, we'll throw it off the screen with some Big - // Numbers. It's important that this only impacts the coordinates used - // for the shader and not for the sorter; that way when we "show" a group - // of Gaussians the correct rendering order is immediately available. - for (const [i, visible] of groupVisibles.entries()) { - if (!visible) { - meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10; - meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10; - meshProps.rowMajorT_camera_groups[i * 12 + 11] = 1e10; - } - } - prevRowMajorT_camera_groups.set(meshProps.rowMajorT_camera_groups); - meshProps.textureT_camera_groups.needsUpdate = true; - } }, -100 /* This should be called early to reduce group transform artifacts. */); return (