-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
528 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import React, { useEffect } from "react"; | ||
import * as THREE from "three"; | ||
import SplatSortWorker from "./SplatSortWorker?worker"; | ||
import { fragmentShaderSource, vertexShaderSource } from "./Shaders"; | ||
import { useFrame } from "@react-three/fiber"; | ||
import { GaussianBuffersSplitCov } from "./SplatSortWorker"; | ||
|
||
export type GaussianBuffers = { | ||
// (N, 3) | ||
centers: Float32Array; | ||
// (N, 3) | ||
rgbs: Float32Array; | ||
// (N, 1) | ||
opacities: Float32Array; | ||
// (N, 6) | ||
covariancesTriu: Float32Array; | ||
}; | ||
|
||
export default function GaussianSplats({ | ||
buffers, | ||
}: { | ||
buffers: GaussianBuffers; | ||
}) { | ||
// Create buffer geometry + setter function. | ||
const [geometry, setSortedBuffers] = React.useMemo(() => { | ||
const geometry = new THREE.InstancedBufferGeometry(); | ||
const numGaussians = buffers.centers.length / 3; | ||
geometry.instanceCount = numGaussians; | ||
|
||
// Each Gaussian will be drawn as a quadrilateral. | ||
geometry.setIndex( | ||
new THREE.BufferAttribute(new Uint32Array([0, 2, 1, 0, 3, 2]), 1), | ||
); | ||
geometry.setAttribute( | ||
"position", | ||
new THREE.BufferAttribute( | ||
new Float32Array([-2, -2, 2, -2, 2, 2, -2, 2]), | ||
2, | ||
), | ||
); | ||
|
||
// Create attributes. | ||
const centerAttribute = new THREE.InstancedBufferAttribute( | ||
new Float32Array(numGaussians * 3), | ||
3, | ||
); | ||
const rgbAttribute = new THREE.InstancedBufferAttribute( | ||
new Float32Array(numGaussians * 3), | ||
3, | ||
); | ||
const opacityAttribute = new THREE.InstancedBufferAttribute( | ||
new Float32Array(numGaussians), | ||
1, | ||
); | ||
const covAAttribute = new THREE.InstancedBufferAttribute( | ||
new Float32Array(numGaussians * 3), | ||
3, | ||
); | ||
const covBAttribute = new THREE.InstancedBufferAttribute( | ||
new Float32Array(numGaussians * 3), | ||
3, | ||
); | ||
|
||
geometry.setAttribute("center", centerAttribute); | ||
geometry.setAttribute("rgb", rgbAttribute); | ||
geometry.setAttribute("opacity", opacityAttribute); | ||
geometry.setAttribute("covA", covAAttribute); | ||
geometry.setAttribute("covB", covBAttribute); | ||
|
||
return [ | ||
geometry, | ||
(sortedBuffers: GaussianBuffersSplitCov) => { | ||
centerAttribute.set(sortedBuffers.centers); | ||
rgbAttribute.set(sortedBuffers.rgbs); | ||
opacityAttribute.set(sortedBuffers.opacities); | ||
covAAttribute.set(sortedBuffers.covA); | ||
covBAttribute.set(sortedBuffers.covB); | ||
|
||
centerAttribute.needsUpdate = true; | ||
rgbAttribute.needsUpdate = true; | ||
opacityAttribute.needsUpdate = true; | ||
covAAttribute.needsUpdate = true; | ||
covBAttribute.needsUpdate = true; | ||
}, | ||
]; | ||
}, []); | ||
|
||
// Update shader uniforms. | ||
const shaderMaterial = React.useMemo(() => { | ||
console.log("making material"); | ||
return new THREE.RawShaderMaterial({ | ||
fragmentShader: fragmentShaderSource, | ||
vertexShader: vertexShaderSource, | ||
uniforms: { | ||
viewport: { value: null }, | ||
focal: { value: null }, | ||
}, | ||
depthTest: true, | ||
depthWrite: false, | ||
transparent: true, | ||
}); | ||
}, []); | ||
React.useEffect(() => { | ||
return () => shaderMaterial.dispose(); | ||
}, []); | ||
|
||
useFrame((state) => { | ||
const dpr = state.viewport.dpr; | ||
const fovY = | ||
((state.camera as THREE.PerspectiveCamera).fov * Math.PI) / 180.0; | ||
const fovX = 2 * Math.atan(Math.tan(fovY / 2) * state.viewport.aspect); | ||
const fy = (dpr * state.size.height) / (2 * Math.tan(fovY / 2)); | ||
const fx = (dpr * state.size.width) / (2 * Math.tan(fovX / 2)); | ||
|
||
shaderMaterial.uniforms.focal.value = [fx, fy]; | ||
shaderMaterial.uniforms.viewport.value = [ | ||
state.size.width * dpr, | ||
state.size.height * dpr, | ||
]; | ||
}); | ||
|
||
// Create worker for sorting Gaussians. | ||
const splatSortWorkerRef = React.useRef<Worker | null>(null); | ||
useEffect(() => { | ||
const sortWorker = new SplatSortWorker(); | ||
sortWorker.postMessage({ | ||
setBuffers: splitCovariances(buffers), | ||
}); | ||
splatSortWorkerRef.current = sortWorker; | ||
|
||
sortWorker.onmessage = (e) => { | ||
setSortedBuffers(e.data as GaussianBuffersSplitCov); | ||
}; | ||
|
||
// Close the worker when done. | ||
return () => sortWorker.postMessage({ close: true }); | ||
}, [buffers]); | ||
|
||
// Synchronize view projection matrix with sort worker. | ||
const meshRef = React.useRef<THREE.Mesh>(null); | ||
const prevViewProj = React.useRef<THREE.Matrix4>(); | ||
useFrame((state) => { | ||
const mesh = meshRef.current; | ||
const sortWorker = splatSortWorkerRef.current; | ||
if (mesh === null || sortWorker === null) return; | ||
|
||
// Compute view projection matrix. | ||
const viewProj = new THREE.Matrix4() | ||
.multiply(state.camera.projectionMatrix) | ||
.multiply(state.camera.matrixWorldInverse) | ||
.multiply(mesh.matrixWorld); | ||
|
||
// If changed, use projection matrix to sort Gaussians. | ||
if ( | ||
prevViewProj.current === undefined || | ||
!viewProj.equals(prevViewProj.current) | ||
) { | ||
sortWorker.postMessage({ setViewProj: viewProj.elements }); | ||
prevViewProj.current = viewProj; | ||
} | ||
}); | ||
|
||
return <mesh ref={meshRef} geometry={geometry} material={shaderMaterial} />; | ||
} | ||
|
||
/** Split upper-triangular terms (6D) of covariance into pair of 3D terms. This | ||
* lets us pass vec3 arrays into our shader. */ | ||
function splitCovariances(buffers: GaussianBuffers) { | ||
const covA = new Float32Array(buffers.covariancesTriu.length / 2); | ||
const covB = new Float32Array(buffers.covariancesTriu.length / 2); | ||
for (let i = 0; i < covA.length; i++) { | ||
covA[i] = buffers.covariancesTriu[Math.floor(i / 3) * 6 + (i % 3)]; | ||
covB[i] = buffers.covariancesTriu[Math.floor(i / 3) * 6 + (i % 3) + 3]; | ||
} | ||
return { | ||
centers: buffers.centers, | ||
rgbs: buffers.rgbs, | ||
opacities: buffers.opacities, | ||
covA: covA, | ||
covB: covB, | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/** Shaders for Gaussian splatting. | ||
* | ||
* These are adapted from Kevin Kwok, with only minor modifications. | ||
* | ||
* https://github.com/antimatter15/splat/ | ||
*/ | ||
|
||
export const vertexShaderSource = ` | ||
precision mediump float; | ||
attribute vec2 position; | ||
attribute vec3 rgb; | ||
attribute float opacity; | ||
attribute vec3 center; | ||
attribute vec3 covA; | ||
attribute vec3 covB; | ||
uniform mat4 projectionMatrix, modelViewMatrix; | ||
uniform vec2 focal; | ||
uniform vec2 viewport; | ||
varying vec3 vRgb; | ||
varying float vOpacity; | ||
varying vec2 vPosition; | ||
mat3 transpose(mat3 m) { | ||
return mat3( | ||
m[0][0], m[1][0], m[2][0], | ||
m[0][1], m[1][1], m[2][1], | ||
m[0][2], m[1][2], m[2][2] | ||
); | ||
} | ||
void main () { | ||
// Get center wrt camera. modelViewMatrix is T_cam_world. | ||
vec4 c_cam = modelViewMatrix * vec4(center, 1); | ||
vec4 pos2d = projectionMatrix * c_cam; | ||
// Splat covariance. | ||
mat3 cov3d = mat3( | ||
covA.x, covA.y, covA.z, | ||
covA.y, covB.x, covB.y, | ||
covA.z, covB.y, covB.z | ||
); | ||
mat3 J = mat3( | ||
// Note that matrices are column-major. | ||
focal.x / c_cam.z, 0., 0.0, | ||
0., focal.y / c_cam.z, 0.0, | ||
-(focal.x * c_cam.x) / (c_cam.z * c_cam.z), -(focal.y * c_cam.y) / (c_cam.z * c_cam.z), 0. | ||
); | ||
mat3 A = J * mat3(modelViewMatrix); | ||
mat3 cov_proj = A * cov3d * transpose(A); | ||
float diag1 = cov_proj[0][0] + 0.3; | ||
float offDiag = cov_proj[0][1]; | ||
float diag2 = cov_proj[1][1] + 0.3; | ||
// Eigendecomposition. This can mostly be derived from characteristic equation, etc. | ||
float mid = 0.5 * (diag1 + diag2); | ||
float radius = length(vec2((diag1 - diag2) / 2.0, offDiag)); | ||
float lambda1 = mid + radius; | ||
float lambda2 = max(mid - radius, 0.1); | ||
vec2 diagonalVector = normalize(vec2(offDiag, lambda1 - diag1)); | ||
vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector; | ||
vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x); | ||
vRgb = rgb; | ||
vOpacity = opacity; | ||
vPosition = position; | ||
gl_Position = vec4( | ||
vec2(pos2d) / pos2d.w | ||
+ position.x * v1 / viewport * 2.0 | ||
+ position.y * v2 / viewport * 2.0, pos2d.z / pos2d.w, 1.); | ||
} | ||
`; | ||
|
||
export const fragmentShaderSource = ` | ||
precision mediump float; | ||
varying vec3 vRgb; | ||
varying float vOpacity; | ||
varying vec2 vPosition; | ||
uniform vec2 viewport; | ||
uniform vec2 focal; | ||
void main () { | ||
float A = -dot(vPosition, vPosition); | ||
if (A < -4.0) discard; | ||
float B = exp(A) * vOpacity; | ||
gl_FragColor = vec4(vRgb.rgb, B); | ||
} | ||
`; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/** Worker for sorting splats. | ||
*/ | ||
|
||
import MakeSorterModulePromise from "./WasmSorter/Sorter.mjs"; | ||
|
||
export type GaussianBuffersSplitCov = { | ||
// (N, 3) | ||
centers: Float32Array; | ||
// (N, 3) | ||
rgbs: Float32Array; | ||
// (N, 1) | ||
opacities: Float32Array; | ||
// (N, 3) | ||
covA: Float32Array; | ||
// (N, 3) | ||
covB: Float32Array; | ||
}; | ||
|
||
{ | ||
let sorter: any = null; | ||
let viewProj: number[] | null = null; | ||
let sortRunning = false; | ||
const throttledSort = () => { | ||
if (sorter === null || viewProj === null || sortRunning) return; | ||
|
||
sortRunning = true; | ||
const lastView = viewProj; | ||
sorter.sort(viewProj[2], viewProj[6], viewProj[10]); | ||
self.postMessage({ | ||
centers: sorter.getSortedCenters(), | ||
rgbs: sorter.getSortedRgbs(), | ||
opacities: sorter.getSortedOpacities(), | ||
covA: sorter.getSortedCovA(), | ||
covB: sorter.getSortedCovB(), | ||
}); | ||
|
||
setTimeout(() => { | ||
sortRunning = false; | ||
if (lastView !== viewProj) { | ||
throttledSort(); | ||
} | ||
}, 0); | ||
}; | ||
|
||
const SorterModulePromise = MakeSorterModulePromise(); | ||
|
||
self.onmessage = async (e) => { | ||
const data = e.data as | ||
| { | ||
setBuffers: GaussianBuffersSplitCov; | ||
} | ||
| { | ||
setViewProj: number[]; | ||
} | ||
| { close: true }; | ||
|
||
if ("setBuffers" in data) { | ||
// Instantiate sorter with buffers populated. | ||
const buffers = data.setBuffers as GaussianBuffersSplitCov; | ||
sorter = new (await SorterModulePromise).Sorter( | ||
buffers.centers, | ||
buffers.rgbs, | ||
buffers.opacities, | ||
buffers.covA, | ||
buffers.covB, | ||
); | ||
throttledSort(); | ||
} else if ("setViewProj" in data) { | ||
// Update view projection matrix. | ||
viewProj = data.setViewProj; | ||
throttledSort(); | ||
} else if ("close" in data) { | ||
// Done! | ||
self.close(); | ||
} | ||
}; | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/usr/bin/env bash | ||
|
||
emcc --bind -Oz sorter.cpp -o Sorter.mjs -s WASM=1 -s NO_EXIT_RUNTIME=1 -s "EXPORTED_RUNTIME_METHODS=['addOnPostRun']" -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=1GB; |
Oops, something went wrong.