diff --git a/.prettierignore b/.prettierignore new file mode 100644 index 000000000..527e4bb61 --- /dev/null +++ b/.prettierignore @@ -0,0 +1,2 @@ +*.mjs +build/ diff --git a/docs/source/_templates/sidebar/brand.html b/docs/source/_templates/sidebar/brand.html index dc6158462..a5f764324 100644 --- a/docs/source/_templates/sidebar/brand.html +++ b/docs/source/_templates/sidebar/brand.html @@ -1,68 +1,95 @@ - - {%- endif %} {%- if theme_light_logo and theme_dark_logo %} - - {%- endif %} - + + {%- endif %} {%- if theme_light_logo and theme_dark_logo %} + + {%- endif %} + - {% endblock brand_content %} + {% endblock brand_content %} -
- -
- Version: {{ version }} -
-
- + + htmlString += ""; + document.getElementById("viser-version-dropdown").innerHTML = htmlString; + } + +
+ + Version: {{ version }} + +
+
+
- - - Github - + + + Github +
diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 4e1b25059..c602c3e48 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -843,9 +843,11 @@ def add_gaussian_splats( assert opacities.shape == (num_gaussians, 1) assert covariances.shape == (num_gaussians, 3, 3) - # Make our covariances more compact! - covariances_triu = ( - covariances.reshape((-1, 9))[:, onp.array([0, 1, 2, 4, 5, 8])] + # Get cholesky factor of covariance. + cov_cholesky_triu = ( + onp.linalg.cholesky(covariances) + .swapaxes(-1, -2) # tril => triu + .reshape((-1, 9))[:, onp.array([0, 1, 2, 4, 5, 8])] .astype(onp.float32) .copy() ) @@ -855,7 +857,7 @@ def add_gaussian_splats( centers.astype(onp.float32).view(onp.uint8), onp.zeros((num_gaussians, 4), dtype=onp.uint8), # Second texelFetch. - covariances_triu.astype(onp.float16).view(onp.uint8), + cov_cholesky_triu.astype(onp.float16).view(onp.uint8), _colors_to_uint8(rgbs), _colors_to_uint8(opacities), ], diff --git a/src/viser/client/src/App.tsx b/src/viser/client/src/App.tsx index 539dfb9a0..1706502e4 100644 --- a/src/viser/client/src/App.tsx +++ b/src/viser/client/src/App.tsx @@ -48,6 +48,7 @@ import { useDisclosure } from "@mantine/hooks"; import { rayToViserCoords } from "./WorldTransformUtils"; import { clickToNDC, clickToOpenCV, isClickValid } from "./ClickUtils"; import { theme } from "./AppTheme"; +import { GaussianSplatsContext } from "./Splatting/GaussianSplats"; export type ViewerContextContents = { // Zustand hooks. @@ -213,9 +214,13 @@ function ViewerContents() { })} > - - - + + + + + {viewer.useGui((state) => state.theme.show_logo) ? ( ) : null} diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index 655c18f35..03a60ede2 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -4,6 +4,12 @@ import SplatSortWorker from "./SplatSortWorker?worker"; import { useFrame, useThree } from "@react-three/fiber"; import { shaderMaterial } from "@react-three/drei"; +export const GaussianSplatsContext = + React.createContext void)[]; + }>>(null); + export type GaussianBuffers = { buffer: Uint32Array; }; @@ -23,6 +29,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( transitionInState: 0.0, }, `precision highp usampler2D; // Most important: ints must be 32-bit. + precision mediump float; // Index from the splat sorter. attribute uint sortedIndex; @@ -58,6 +65,9 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( } void main () { + // Any early return will discard the fragment. + gl_Position = vec4(0.0, 0.0, 2000.0, 1.0); + // Get position + scale from float buffer. ivec2 texSize = textureSize(bufferTexture, 0); ivec2 texPos0 = ivec2((sortedIndex * 2u) % uint(texSize.x), (sortedIndex * 2u) / uint(texSize.x)); @@ -66,16 +76,12 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // Get center wrt camera. modelViewMatrix is T_cam_world. vec4 c_cam = modelViewMatrix * vec4(center, 1); - if (-c_cam.z < near || -c_cam.z > far) { - gl_Position = vec4(1000.0, 0.0, 0.0, 1.0); + if (-c_cam.z < near || -c_cam.z > far) return; - } vec4 pos2d = projectionMatrix * c_cam; float clip = 1.1 * pos2d.w; - if (pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip) { - gl_Position = vec4(1000.0, 0.0, 0.0, 1.0); + if (pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip) return; - } vec4 c_camstable = sortSynchronizedModelViewMatrix * vec4(center, 1); vec4 stablePos2d = projectionMatrix * c_camstable; @@ -87,16 +93,17 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( ivec2 texPos1 = ivec2((sortedIndex * 2u + 1u) % uint(texSize.x), (sortedIndex * 2u + 1u) / uint(texSize.x)); uvec4 intBufferData = texelFetch(bufferTexture, texPos1, 0); uint rgbaUint32 = intBufferData.w; - vec2 cov01 = unpackHalf2x16(intBufferData.x) * cov_scale; - vec2 cov23 = unpackHalf2x16(intBufferData.y) * cov_scale; - vec2 cov45 = unpackHalf2x16(intBufferData.z) * cov_scale; + vec2 chol01 = unpackHalf2x16(intBufferData.x); + vec2 chol23 = unpackHalf2x16(intBufferData.y); + vec2 chol45 = unpackHalf2x16(intBufferData.z); // Do the actual splatting. - mat3 cov3d = mat3( - cov01.x, cov01.y, cov23.x, - cov01.y, cov23.y, cov45.x, - cov23.x, cov45.x, cov45.y + mat3 chol = mat3( + chol01.x, chol01.y, chol23.x, + 0., chol23.y, chol45.x, + 0., 0., chol45.y ); + mat3 cov3d = chol * transpose(chol) * cov_scale; mat3 J = mat3( // Matrices are column-major. focal.x / c_cam.z, 0., 0.0, @@ -114,10 +121,8 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( float radius = length(vec2((diag1 - diag2) / 2.0, offDiag)); float lambda1 = mid + radius; float lambda2 = mid - radius; - if (lambda2 < 0.0) { - gl_Position = vec4(1000.0, 0.0, 0.0, 1.0); + if (lambda2 < 0.0) return; - } vec2 diagonalVector = normalize(vec2(offDiag, lambda1 - diag1)); vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector; vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x); @@ -131,14 +136,10 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // Throw the Gaussian off the screen if it's too close, too far, or too small. float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); - if (weightedDeterminant < 0.1) { - gl_Position = vec4(1000.0, 0.0, 0.0, 1.0); + if (weightedDeterminant < 0.1) return; - } - if (weightedDeterminant < 1.0 && hash3D(center) < weightedDeterminant) { // This is not principled. It just makes things faster. - gl_Position = vec4(1000.0, 0.0, 0.0, 1.0); + if (weightedDeterminant < 1.0 && hash3D(center) < weightedDeterminant) // This is not principled. It just makes things faster. return; - } vPosition = position.xy; gl_Position = vec4( @@ -147,7 +148,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( + position.y * v2 / viewport * 2.0, stablePos2d.z / stablePos2d.w, 1.); } `, - `precision highp float; + `precision mediump float; uniform vec2 viewport; uniform vec2 focal; @@ -159,7 +160,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( float A = -dot(vPosition, vPosition); if (A < -4.0) discard; float B = exp(A) * vRgba.a; - if ( B < 0.02 ) discard; // alphaTest. + if (B < 0.03) discard; // alphaTest. gl_FragColor = vec4(vRgba.rgb, B); }`, ); @@ -177,6 +178,8 @@ export default function GaussianSplats({ const initializedTextures = React.useRef(false); const [sortSynchronizedModelViewMatrix] = React.useState(new THREE.Matrix4()); + const splatContext = React.useContext(GaussianSplatsContext)!.current; + // We'll use the vanilla three.js API, which for our use case is more // flexible than the declarative version (particularly for operations like // dynamic updates to buffers and shader uniforms). @@ -238,27 +241,44 @@ export default function GaussianSplats({ const sortWorker = new SplatSortWorker(); sortWorker.onmessage = (e) => { sortedIndexAttribute.set(e.data.sortedIndices as Int32Array); - sortedIndexAttribute.needsUpdate = true; - material.uniforms.sortSynchronizedModelViewMatrix.value.copy( - sortSynchronizedModelViewMatrix, - ); - // A simple but reasonably effective heuristic for render ordering. - // - // To minimize artifacts: - // - When there are multiple splat objects, we want to render the closest - // ones *last*. This improves the likelihood of correct alpha - // compositing and reduces reliance on alpha testing. - // - We generally want to render other objects like meshes *before* - // Gaussians. They're usually opaque. - meshRef.current!.renderOrder = (-e.data.minDepth as number) + 1000.0; - - isSortingRef.current = false; - - // Trigger initial render. - if (!initializedTextures.current) { - material.uniforms.numGaussians.value = numGaussians; - bufferTexture.needsUpdate = true; - initializedTextures.current = true; + const synchronizedSortUpdateCallback = () => { + isSortingRef.current = false; + + // Wait for onmessage to be triggered for all Gaussians. + sortedIndexAttribute.needsUpdate = true; + material.uniforms.sortSynchronizedModelViewMatrix.value.copy( + sortSynchronizedModelViewMatrix, + ); + // A simple but reasonably effective heuristic for render ordering. + // + // To minimize artifacts: + // - When there are multiple splat objects, we want to render the closest + // ones *last*. This improves the likelihood of correct alpha + // compositing and reduces reliance on alpha testing. + // - We generally want to render other objects like meshes *before* + // Gaussians. They're usually opaque. + meshRef.current!.renderOrder = (-e.data.minDepth as number) + 1000.0; + + // Trigger initial render. + if (!initializedTextures.current) { + material.uniforms.numGaussians.value = numGaussians; + bufferTexture.needsUpdate = true; + initializedTextures.current = true; + } + }; + + // Synchronize sort updates across multiple Gaussian splats. This + // prevents z-fighting. + splatContext.numSorting -= 1; + if (splatContext.numSorting === 0) { + synchronizedSortUpdateCallback(); + console.log(splatContext.sortUpdateCallbacks.length); + for (const callback of splatContext.sortUpdateCallbacks) { + callback(); + } + splatContext.sortUpdateCallbacks.length = 0; + } else { + splatContext.sortUpdateCallbacks.push(synchronizedSortUpdateCallback); } }; sortWorker.postMessage({ @@ -301,7 +321,7 @@ export default function GaussianSplats({ const uniforms = material.uniforms; uniforms.transitionInState.value = Math.min( - uniforms.transitionInState.value + delta, + uniforms.transitionInState.value + delta * 2.0, 1.0, ); uniforms.focal.value = [fx, fy]; @@ -322,6 +342,7 @@ export default function GaussianSplats({ ) { sortSynchronizedModelViewMatrix.copy(T_camera_obj); sortWorker.postMessage({ setT_camera_obj: T_camera_obj.elements }); + splatContext.numSorting += 1; isSortingRef.current = true; prevT_camera_obj.copy(T_camera_obj); } diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index 173cc207b..34d0b898d 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -240,20 +240,16 @@ function useMessageHandler() { message.plane == "xz" ? new THREE.Euler(0.0, 0.0, 0.0) : message.plane == "xy" - ? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0) - : message.plane == "yx" - ? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0) - : message.plane == "yz" - ? new THREE.Euler(0.0, 0.0, Math.PI / 2.0) - : message.plane == "zx" - ? new THREE.Euler(0.0, Math.PI / 2.0, 0.0) - : message.plane == "zy" - ? new THREE.Euler( - -Math.PI / 2.0, - 0.0, - -Math.PI / 2.0, - ) - : undefined + ? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0) + : message.plane == "yx" + ? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0) + : message.plane == "yz" + ? new THREE.Euler(0.0, 0.0, Math.PI / 2.0) + : message.plane == "zx" + ? new THREE.Euler(0.0, Math.PI / 2.0, 0.0) + : message.plane == "zy" + ? new THREE.Euler(-Math.PI / 2.0, 0.0, -Math.PI / 2.0) + : undefined } /> @@ -340,16 +336,16 @@ function useMessageHandler() { message.material == "standard" || message.wireframe ? new THREE.MeshStandardMaterial(standardArgs) : message.material == "toon3" - ? new THREE.MeshToonMaterial({ - gradientMap: generateGradientMap(3), - ...standardArgs, - }) - : message.material == "toon5" - ? new THREE.MeshToonMaterial({ - gradientMap: generateGradientMap(5), - ...standardArgs, - }) - : assertUnreachable(message.material); + ? new THREE.MeshToonMaterial({ + gradientMap: generateGradientMap(3), + ...standardArgs, + }) + : message.material == "toon5" + ? new THREE.MeshToonMaterial({ + gradientMap: generateGradientMap(5), + ...standardArgs, + }) + : assertUnreachable(message.material); geometry.setAttribute( "position", new THREE.Float32BufferAttribute(