Skip to content

Commit

Permalink
Rename pt 2
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 10, 2023
1 parent 2f88df7 commit 1fe3fb5
Show file tree
Hide file tree
Showing 7 changed files with 528 additions and 0 deletions.
182 changes: 182 additions & 0 deletions src/viser/client/src/Splatting/GaussianSplats.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import React, { useEffect } from "react";
import * as THREE from "three";
import SplatSortWorker from "./SplatSortWorker?worker";
import { fragmentShaderSource, vertexShaderSource } from "./Shaders";
import { useFrame } from "@react-three/fiber";
import { GaussianBuffersSplitCov } from "./SplatSortWorker";

export type GaussianBuffers = {
// (N, 3)
centers: Float32Array;
// (N, 3)
rgbs: Float32Array;
// (N, 1)
opacities: Float32Array;
// (N, 6)
covariancesTriu: Float32Array;
};

export default function GaussianSplats({
buffers,
}: {
buffers: GaussianBuffers;
}) {
// Create buffer geometry + setter function.
const [geometry, setSortedBuffers] = React.useMemo(() => {
const geometry = new THREE.InstancedBufferGeometry();
const numGaussians = buffers.centers.length / 3;
geometry.instanceCount = numGaussians;

// Each Gaussian will be drawn as a quadrilateral.
geometry.setIndex(
new THREE.BufferAttribute(new Uint32Array([0, 2, 1, 0, 3, 2]), 1),
);
geometry.setAttribute(
"position",
new THREE.BufferAttribute(
new Float32Array([-2, -2, 2, -2, 2, 2, -2, 2]),
2,
),
);

// Create attributes.
const centerAttribute = new THREE.InstancedBufferAttribute(
new Float32Array(numGaussians * 3),
3,
);
const rgbAttribute = new THREE.InstancedBufferAttribute(
new Float32Array(numGaussians * 3),
3,
);
const opacityAttribute = new THREE.InstancedBufferAttribute(
new Float32Array(numGaussians),
1,
);
const covAAttribute = new THREE.InstancedBufferAttribute(
new Float32Array(numGaussians * 3),
3,
);
const covBAttribute = new THREE.InstancedBufferAttribute(
new Float32Array(numGaussians * 3),
3,
);

geometry.setAttribute("center", centerAttribute);
geometry.setAttribute("rgb", rgbAttribute);
geometry.setAttribute("opacity", opacityAttribute);
geometry.setAttribute("covA", covAAttribute);
geometry.setAttribute("covB", covBAttribute);

return [
geometry,
(sortedBuffers: GaussianBuffersSplitCov) => {
centerAttribute.set(sortedBuffers.centers);
rgbAttribute.set(sortedBuffers.rgbs);
opacityAttribute.set(sortedBuffers.opacities);
covAAttribute.set(sortedBuffers.covA);
covBAttribute.set(sortedBuffers.covB);

centerAttribute.needsUpdate = true;
rgbAttribute.needsUpdate = true;
opacityAttribute.needsUpdate = true;
covAAttribute.needsUpdate = true;
covBAttribute.needsUpdate = true;
},
];
}, []);

// Update shader uniforms.
const shaderMaterial = React.useMemo(() => {
console.log("making material");
return new THREE.RawShaderMaterial({
fragmentShader: fragmentShaderSource,
vertexShader: vertexShaderSource,
uniforms: {
viewport: { value: null },
focal: { value: null },
},
depthTest: true,
depthWrite: false,
transparent: true,
});
}, []);
React.useEffect(() => {
return () => shaderMaterial.dispose();
}, []);

useFrame((state) => {
const dpr = state.viewport.dpr;
const fovY =
((state.camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0;
const fovX = 2 * Math.atan(Math.tan(fovY / 2) * state.viewport.aspect);
const fy = (dpr * state.size.height) / (2 * Math.tan(fovY / 2));
const fx = (dpr * state.size.width) / (2 * Math.tan(fovX / 2));

shaderMaterial.uniforms.focal.value = [fx, fy];
shaderMaterial.uniforms.viewport.value = [
state.size.width * dpr,
state.size.height * dpr,
];
});

// Create worker for sorting Gaussians.
const splatSortWorkerRef = React.useRef<Worker | null>(null);
useEffect(() => {
const sortWorker = new SplatSortWorker();
sortWorker.postMessage({
setBuffers: splitCovariances(buffers),
});
splatSortWorkerRef.current = sortWorker;

sortWorker.onmessage = (e) => {
setSortedBuffers(e.data as GaussianBuffersSplitCov);
};

// Close the worker when done.
return () => sortWorker.postMessage({ close: true });
}, [buffers]);

// Synchronize view projection matrix with sort worker.
const meshRef = React.useRef<THREE.Mesh>(null);
const prevViewProj = React.useRef<THREE.Matrix4>();
useFrame((state) => {
const mesh = meshRef.current;
const sortWorker = splatSortWorkerRef.current;
if (mesh === null || sortWorker === null) return;

// Compute view projection matrix.
const viewProj = new THREE.Matrix4()
.multiply(state.camera.projectionMatrix)
.multiply(state.camera.matrixWorldInverse)
.multiply(mesh.matrixWorld);

// If changed, use projection matrix to sort Gaussians.
if (
prevViewProj.current === undefined ||
!viewProj.equals(prevViewProj.current)
) {
sortWorker.postMessage({ setViewProj: viewProj.elements });
prevViewProj.current = viewProj;
}
});

return <mesh ref={meshRef} geometry={geometry} material={shaderMaterial} />;
}

/** Split upper-triangular terms (6D) of covariance into pair of 3D terms. This
* lets us pass vec3 arrays into our shader. */
function splitCovariances(buffers: GaussianBuffers) {
const covA = new Float32Array(buffers.covariancesTriu.length / 2);
const covB = new Float32Array(buffers.covariancesTriu.length / 2);
for (let i = 0; i < covA.length; i++) {
covA[i] = buffers.covariancesTriu[Math.floor(i / 3) * 6 + (i % 3)];
covB[i] = buffers.covariancesTriu[Math.floor(i / 3) * 6 + (i % 3) + 3];
}
return {
centers: buffers.centers,
rgbs: buffers.rgbs,
opacities: buffers.opacities,
covA: covA,
covB: covB,
};
}
94 changes: 94 additions & 0 deletions src/viser/client/src/Splatting/Shaders.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/** Shaders for Gaussian splatting.
*
* These are adapted from Kevin Kwok, with only minor modifications.
*
* https://github.com/antimatter15/splat/
*/

export const vertexShaderSource = `
precision mediump float;
attribute vec2 position;
attribute vec3 rgb;
attribute float opacity;
attribute vec3 center;
attribute vec3 covA;
attribute vec3 covB;
uniform mat4 projectionMatrix, modelViewMatrix;
uniform vec2 focal;
uniform vec2 viewport;
varying vec3 vRgb;
varying float vOpacity;
varying vec2 vPosition;
mat3 transpose(mat3 m) {
return mat3(
m[0][0], m[1][0], m[2][0],
m[0][1], m[1][1], m[2][1],
m[0][2], m[1][2], m[2][2]
);
}
void main () {
// Get center wrt camera. modelViewMatrix is T_cam_world.
vec4 c_cam = modelViewMatrix * vec4(center, 1);
vec4 pos2d = projectionMatrix * c_cam;
// Splat covariance.
mat3 cov3d = mat3(
covA.x, covA.y, covA.z,
covA.y, covB.x, covB.y,
covA.z, covB.y, covB.z
);
mat3 J = mat3(
// Note that matrices are column-major.
focal.x / c_cam.z, 0., 0.0,
0., focal.y / c_cam.z, 0.0,
-(focal.x * c_cam.x) / (c_cam.z * c_cam.z), -(focal.y * c_cam.y) / (c_cam.z * c_cam.z), 0.
);
mat3 A = J * mat3(modelViewMatrix);
mat3 cov_proj = A * cov3d * transpose(A);
float diag1 = cov_proj[0][0] + 0.3;
float offDiag = cov_proj[0][1];
float diag2 = cov_proj[1][1] + 0.3;
// Eigendecomposition. This can mostly be derived from characteristic equation, etc.
float mid = 0.5 * (diag1 + diag2);
float radius = length(vec2((diag1 - diag2) / 2.0, offDiag));
float lambda1 = mid + radius;
float lambda2 = max(mid - radius, 0.1);
vec2 diagonalVector = normalize(vec2(offDiag, lambda1 - diag1));
vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector;
vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x);
vRgb = rgb;
vOpacity = opacity;
vPosition = position;
gl_Position = vec4(
vec2(pos2d) / pos2d.w
+ position.x * v1 / viewport * 2.0
+ position.y * v2 / viewport * 2.0, pos2d.z / pos2d.w, 1.);
}
`;

export const fragmentShaderSource = `
precision mediump float;
varying vec3 vRgb;
varying float vOpacity;
varying vec2 vPosition;
uniform vec2 viewport;
uniform vec2 focal;
void main () {
float A = -dot(vPosition, vPosition);
if (A < -4.0) discard;
float B = exp(A) * vOpacity;
gl_FragColor = vec4(vRgb.rgb, B);
}
`;
77 changes: 77 additions & 0 deletions src/viser/client/src/Splatting/SplatSortWorker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/** Worker for sorting splats.
*/

import MakeSorterModulePromise from "./WasmSorter/Sorter.mjs";

export type GaussianBuffersSplitCov = {
// (N, 3)
centers: Float32Array;
// (N, 3)
rgbs: Float32Array;
// (N, 1)
opacities: Float32Array;
// (N, 3)
covA: Float32Array;
// (N, 3)
covB: Float32Array;
};

{
let sorter: any = null;
let viewProj: number[] | null = null;
let sortRunning = false;
const throttledSort = () => {
if (sorter === null || viewProj === null || sortRunning) return;

sortRunning = true;
const lastView = viewProj;
sorter.sort(viewProj[2], viewProj[6], viewProj[10]);
self.postMessage({
centers: sorter.getSortedCenters(),
rgbs: sorter.getSortedRgbs(),
opacities: sorter.getSortedOpacities(),
covA: sorter.getSortedCovA(),
covB: sorter.getSortedCovB(),
});

setTimeout(() => {
sortRunning = false;
if (lastView !== viewProj) {
throttledSort();
}
}, 0);
};

const SorterModulePromise = MakeSorterModulePromise();

self.onmessage = async (e) => {
const data = e.data as
| {
setBuffers: GaussianBuffersSplitCov;
}
| {
setViewProj: number[];
}
| { close: true };

if ("setBuffers" in data) {
// Instantiate sorter with buffers populated.
const buffers = data.setBuffers as GaussianBuffersSplitCov;
sorter = new (await SorterModulePromise).Sorter(
buffers.centers,
buffers.rgbs,
buffers.opacities,
buffers.covA,
buffers.covB,
);
throttledSort();
} else if ("setViewProj" in data) {
// Update view projection matrix.
viewProj = data.setViewProj;
throttledSort();
} else if ("close" in data) {
// Done!
self.close();
}
};
}
16 changes: 16 additions & 0 deletions src/viser/client/src/Splatting/WasmSorter/Sorter.mjs

Large diffs are not rendered by default.

Binary file not shown.
3 changes: 3 additions & 0 deletions src/viser/client/src/Splatting/WasmSorter/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/usr/bin/env bash

emcc --bind -Oz sorter.cpp -o Sorter.mjs -s WASM=1 -s NO_EXIT_RUNTIME=1 -s "EXPORTED_RUNTIME_METHODS=['addOnPostRun']" -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=1GB;
Loading

0 comments on commit 1fe3fb5

Please sign in to comment.