Skip to content

Commit

Permalink
Merge branch 'main' into brent/gui_images
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi authored Dec 2, 2024
2 parents 7833f2f + c449702 commit 299ece1
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 92 deletions.
2 changes: 1 addition & 1 deletion src/viser/client/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -466,12 +466,12 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) {
}}
>
{inView ? null : <DisableRender />}
{children}
<BackgroundImage />
<AdaptiveDpr />
<SceneContextSetter />
{memoizedCameraControls}
<SplatRenderContext>
{children}
<SceneNodeThreeObject name="" parent={null} />
</SplatRenderContext>
<DefaultLights />
Expand Down
23 changes: 23 additions & 0 deletions src/viser/client/src/MessageHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { Progress } from "@mantine/core";
import { IconCheck } from "@tabler/icons-react";
import { computeT_threeworld_world } from "./WorldTransformUtils";
import { rootNodeTemplate } from "./SceneTreeState";
import { GaussianSplatsContext } from "./Splatting/GaussianSplats";

/** Returns a handler for all incoming messages. */
function useMessageHandler() {
Expand Down Expand Up @@ -521,6 +522,7 @@ export function FrameSynchronizedMessageHandler() {
const handleMessage = useMessageHandler();
const viewer = useContext(ViewerContext)!;
const messageQueueRef = viewer.messageQueueRef;
const splatContext = React.useContext(GaussianSplatsContext)!;

useFrame(
() => {
Expand Down Expand Up @@ -567,6 +569,21 @@ export function FrameSynchronizedMessageHandler() {
),
);

// Update splatting camera if needed.
// We'll back up the current sorted indices, and restore them after rendering.
const splatMeshProps = splatContext.meshPropsRef.current;
const sortedIndicesOrig =
splatMeshProps !== null
? splatMeshProps.sortedIndexAttribute.array.slice()
: null;
if (splatContext.updateCamera.current !== null)
splatContext.updateCamera.current!(
camera,
targetWidth,
targetHeight,
true,
);

// Note: We don't need to add the camera to the scene for rendering
// The renderer.render() function uses the camera directly
// Create a new renderer
Expand All @@ -583,6 +600,12 @@ export function FrameSynchronizedMessageHandler() {
// Render the scene.
renderer.render(viewer.sceneRef.current!, camera);

// Restore splatting indices.
if (sortedIndicesOrig !== null && splatMeshProps !== null) {
splatMeshProps.sortedIndexAttribute.array = sortedIndicesOrig;
splatMeshProps.sortedIndexAttribute.needsUpdate = true;
}

// Get the rendered image.
viewer.getRenderRequestState.current = "in_progress";
renderer.domElement.toBlob(async (blob) => {
Expand Down
265 changes: 174 additions & 91 deletions src/viser/client/src/Splatting/GaussianSplats.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
* between multiple splat objects.
*/

import MakeSorterModulePromise from "./WasmSorter/Sorter.mjs";

import React from "react";
import * as THREE from "three";
import SplatSortWorker from "./SplatSortWorker?worker";
Expand Down Expand Up @@ -64,9 +66,22 @@ function useGaussianSplatStore() {
})),
)[0];
}
const GaussianSplatsContext = React.createContext<ReturnType<
typeof useGaussianSplatStore
> | null>(null);

export const GaussianSplatsContext = React.createContext<{
useGaussianSplatStore: ReturnType<typeof useGaussianSplatStore>;
updateCamera: React.MutableRefObject<
| null
| ((
camera: THREE.PerspectiveCamera,
width: number,
height: number,
blockingSort: boolean,
) => void)
>;
meshPropsRef: React.MutableRefObject<ReturnType<
typeof useGaussianMeshProps
> | null>;
} | null>(null);

/**Provider for creating splat rendering context.*/
export function SplatRenderContext({
Expand All @@ -75,9 +90,18 @@ export function SplatRenderContext({
children: React.ReactNode;
}) {
const store = useGaussianSplatStore();
const numGroups = Object.keys(
store((state) => state.groupBufferFromId),
).length;
return (
<GaussianSplatsContext.Provider value={store}>
<SplatRenderer />
<GaussianSplatsContext.Provider
value={{
useGaussianSplatStore: store,
updateCamera: React.useRef(null),
meshPropsRef: React.useRef(null),
}}
>
{numGroups > 0 ? <SplatRenderer /> : null}
{children}
</GaussianSplatsContext.Provider>
);
Expand Down Expand Up @@ -251,9 +275,15 @@ export const SplatObject = React.forwardRef<
}
>(function SplatObject({ buffer }, ref) {
const splatContext = React.useContext(GaussianSplatsContext)!;
const setBuffer = splatContext((state) => state.setBuffer);
const removeBuffer = splatContext((state) => state.removeBuffer);
const nodeRefFromId = splatContext((state) => state.nodeRefFromId);
const setBuffer = splatContext.useGaussianSplatStore(
(state) => state.setBuffer,
);
const removeBuffer = splatContext.useGaussianSplatStore(
(state) => state.removeBuffer,
);
const nodeRefFromId = splatContext.useGaussianSplatStore(
(state) => state.nodeRefFromId,
);
const name = React.useMemo(() => uuidv4(), [buffer]);

const [obj, setRef] = React.useState<THREE.Group | null>(null);
Expand Down Expand Up @@ -281,15 +311,20 @@ export const SplatObject = React.forwardRef<
/** External interface. Component should be added to the root of canvas. */
function SplatRenderer() {
const splatContext = React.useContext(GaussianSplatsContext)!;
const groupBufferFromId = splatContext((state) => state.groupBufferFromId);
const nodeRefFromId = splatContext((state) => state.nodeRefFromId);
const groupBufferFromId = splatContext.useGaussianSplatStore(
(state) => state.groupBufferFromId,
);
const nodeRefFromId = splatContext.useGaussianSplatStore(
(state) => state.nodeRefFromId,
);

// Consolidate Gaussian groups into a single buffer.
const merged = mergeGaussianGroups(groupBufferFromId);
const meshProps = useGaussianMeshProps(
merged.gaussianBuffer,
merged.numGroups,
);
splatContext.meshPropsRef.current = meshProps;

// Create sorting worker.
const sortWorker = new SplatSortWorker();
Expand All @@ -310,7 +345,6 @@ function SplatRenderer() {
function postToWorker(message: SorterWorkerIncoming) {
sortWorker.postMessage(message);
}

postToWorker({
setBuffer: merged.gaussianBuffer,
setGroupIndices: merged.groupIndices,
Expand All @@ -337,96 +371,145 @@ function SplatRenderer() {
.slice()
.fill(0);
const prevVisibles: boolean[] = [];
useFrame((state, delta) => {
const mesh = meshRef.current;
if (mesh === null || sortWorker === null) return;

// Update camera parameter uniforms.
const dpr = state.viewport.dpr;
const fovY =
((state.camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0;
const fovX = 2 * Math.atan(Math.tan(fovY / 2) * state.viewport.aspect);
const fy = (dpr * state.size.height) / (2 * Math.tan(fovY / 2));
const fx = (dpr * state.size.width) / (2 * Math.tan(fovX / 2));
// Make local sorter. This will be used for blocking sorts, eg for rendering
// from virtual cameras.
const SorterRef = React.useRef<any>(null);
React.useEffect(() => {
(async () => {
SorterRef.current = new (await MakeSorterModulePromise()).Sorter(
merged.gaussianBuffer,
merged.groupIndices,
);
})();
}, [merged.gaussianBuffer, merged.groupIndices]);

const updateCamera = React.useCallback(
function updateCamera(
camera: THREE.PerspectiveCamera,
width: number,
height: number,
blockingSort: boolean,
) {
// Update camera parameter uniforms.
const fovY = ((camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0;

const aspect = width / height;
const fovX = 2 * Math.atan(Math.tan(fovY / 2) * aspect);
const fy = height / (2 * Math.tan(fovY / 2));
const fx = width / (2 * Math.tan(fovX / 2));

if (meshProps.material === undefined) return;

const uniforms = meshProps.material.uniforms;
uniforms.focal.value = [fx, fy];
uniforms.near.value = camera.near;
uniforms.far.value = camera.far;
uniforms.viewport.value = [width, height];

// Update group transforms.
camera.updateMatrixWorld();
const T_camera_world = camera.matrixWorldInverse;
const groupVisibles: boolean[] = [];
let visibilitiesChanged = false;
for (const [groupIndex, name] of Object.keys(
groupBufferFromId,
).entries()) {
const node = nodeRefFromId.current[name];
if (node === undefined) continue;
tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld);
const colMajorElements = tmpT_camera_group.elements;
Tz_camera_groups.set(
[
colMajorElements[2],
colMajorElements[6],
colMajorElements[10],
colMajorElements[14],
],
groupIndex * 4,
);
const rowMajorElements = tmpT_camera_group.transpose().elements;
meshProps.rowMajorT_camera_groups.set(
rowMajorElements.slice(0, 12),
groupIndex * 12,
);

// Determine visibility. If the parent has unmountWhenInvisible=true, the
// first frame after showing a hidden parent can have visible=true with
// an incorrect matrixWorld transform. There might be a better fix, but
// `prevVisible` is an easy workaround for this.
let visibleNow = node.visible && node.parent !== null;
if (visibleNow) {
node.traverseAncestors((ancestor) => {
visibleNow = visibleNow && ancestor.visible;
});
}
groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true);
if (prevVisibles[groupIndex] !== visibleNow) {
prevVisibles[groupIndex] = visibleNow;
visibilitiesChanged = true;
}
}

const groupsMovedWrtCam = !meshProps.rowMajorT_camera_groups.every(
(v, i) => v === prevRowMajorT_camera_groups[i],
);

if (groupsMovedWrtCam) {
// Gaussians need to be re-sorted.
if (blockingSort && SorterRef.current !== null) {
const sortedIndices = SorterRef.current.sort(
Tz_camera_groups,
) as Uint32Array;
meshProps.sortedIndexAttribute.set(sortedIndices);
meshProps.sortedIndexAttribute.needsUpdate = true;
} else {
postToWorker({
setTz_camera_groups: Tz_camera_groups,
});
}
}
if (groupsMovedWrtCam || visibilitiesChanged) {
// If a group is not visible, we'll throw it off the screen with some Big
// Numbers. It's important that this only impacts the coordinates used
// for the shader and not for the sorter; that way when we "show" a group
// of Gaussians the correct rendering order is immediately available.
for (const [i, visible] of groupVisibles.entries()) {
if (!visible) {
meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10;
meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10;
meshProps.rowMajorT_camera_groups[i * 12 + 11] = 1e10;
}
}
prevRowMajorT_camera_groups.set(meshProps.rowMajorT_camera_groups);
meshProps.textureT_camera_groups.needsUpdate = true;
}
},
[meshProps],
);
splatContext.updateCamera.current = updateCamera;

if (meshProps.material === undefined) return;
useFrame((state, delta) => {
const mesh = meshRef.current;
if (
mesh === null ||
sortWorker === null ||
meshProps.rowMajorT_camera_groups.length === 0
)
return;

const uniforms = meshProps.material.uniforms;
uniforms.transitionInState.value = Math.min(
uniforms.transitionInState.value + delta * 2.0,
1.0,
);
uniforms.focal.value = [fx, fy];
uniforms.near.value = state.camera.near;
uniforms.far.value = state.camera.far;
uniforms.viewport.value = [state.size.width * dpr, state.size.height * dpr];

// Update group transforms.
const T_camera_world = state.camera.matrixWorldInverse;
const groupVisibles: boolean[] = [];
let visibilitiesChanged = false;
for (const [groupIndex, name] of Object.keys(groupBufferFromId).entries()) {
const node = nodeRefFromId.current[name];
if (node === undefined) continue;
tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld);
const colMajorElements = tmpT_camera_group.elements;
Tz_camera_groups.set(
[
colMajorElements[2],
colMajorElements[6],
colMajorElements[10],
colMajorElements[14],
],
groupIndex * 4,
);
const rowMajorElements = tmpT_camera_group.transpose().elements;
meshProps.rowMajorT_camera_groups.set(
rowMajorElements.slice(0, 12),
groupIndex * 12,
);

// Determine visibility. If the parent has unmountWhenInvisible=true, the
// first frame after showing a hidden parent can have visible=true with
// an incorrect matrixWorld transform. There might be a better fix, but
// `prevVisible` is an easy workaround for this.
let visibleNow = node.visible && node.parent !== null;
if (visibleNow) {
node.traverseAncestors((ancestor) => {
visibleNow = visibleNow && ancestor.visible;
});
}
groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true);
if (prevVisibles[groupIndex] !== visibleNow) {
prevVisibles[groupIndex] = visibleNow;
visibilitiesChanged = true;
}
}

const groupsMovedWrtCam = !meshProps.rowMajorT_camera_groups.every(
(v, i) => v === prevRowMajorT_camera_groups[i],
updateCamera(
state.camera as THREE.PerspectiveCamera,
state.viewport.dpr * state.size.width,
state.viewport.dpr * state.size.height,
false /* blockingSort */,
);

if (groupsMovedWrtCam) {
// Gaussians need to be re-sorted.
postToWorker({
setTz_camera_groups: Tz_camera_groups,
});
}
if (groupsMovedWrtCam || visibilitiesChanged) {
// If a group is not visible, we'll throw it off the screen with some Big
// Numbers. It's important that this only impacts the coordinates used
// for the shader and not for the sorter; that way when we "show" a group
// of Gaussians the correct rendering order is immediately available.
for (const [i, visible] of groupVisibles.entries()) {
if (!visible) {
meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10;
meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10;
meshProps.rowMajorT_camera_groups[i * 12 + 11] = 1e10;
}
}
prevRowMajorT_camera_groups.set(meshProps.rowMajorT_camera_groups);
meshProps.textureT_camera_groups.needsUpdate = true;
}
}, -100 /* This should be called early to reduce group transform artifacts. */);

return (
Expand Down

0 comments on commit 299ece1

Please sign in to comment.