diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..f959f82 --- /dev/null +++ b/.clang-format @@ -0,0 +1,24 @@ +--- +BasedOnStyle: LLVM +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +NamespaceIndentation: None +SpacesInSquareBrackets: 'false' +Standard: Auto +FixNamespaceComments: 'true' +BreakBeforeBraces: Allman +SortIncludes: 'true' +SpaceBeforeParens: Never +AllowShortFunctionsOnASingleLine: Inline +BreakConstructorInitializers: BeforeColon +IndentWidth: 4 +PointerAlignment: Left +SpaceInEmptyParentheses: 'false' +SpacesBeforeTrailingComments: 1 +SpacesInAngles: 'false' +SpacesInParentheses: 'false' +AlwaysBreakTemplateDeclarations: 'true' +BreakBeforeInheritanceComma: 'false' +ColumnLimit: 120 +IndentCaseLabels: 'true' +... diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..834a9cc --- /dev/null +++ b/.editorconfig @@ -0,0 +1,12 @@ +# This is a standard to preconfigure editors +# check: https://editorconfig.org/ +root = true + +# 4 space indentation +[*.py] +charset = utf-8 +indent_style = space +indent_size = 4 +trim_trailing_whitespace = true +insert_final_newline = false +end_of_line = lf diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..2a951b0 --- /dev/null +++ b/.flake8 @@ -0,0 +1,12 @@ +[flake8] +ignore= + # line too long - is to be checked by black + E501, + # E203 (spaces around :) + E203, + # and W503 (line break before binary operator) are output as a result of Black formatting + W503 + +dictionaries=en_US,python,technical +docstring-convention=google +spellcheck-targets=comments diff --git a/.github/ISSUE_TEMPLATE/01_feature_request.yml b/.github/ISSUE_TEMPLATE/01_feature_request.yml new file mode 100644 index 0000000..e56ec68 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/01_feature_request.yml @@ -0,0 +1,33 @@ +name: 🚀 Feature Request +description: Submit a proposal/request for a new feature. 
+title: "feat: " +labels: ["feature"] +body: + - type: markdown + attributes: + value: | + Thanks for contributing! + - type: textarea + id: feature-description + attributes: + label: Feature description + description: A clear and concise description of the feature proposal + placeholder: Tell us what you want! + validations: + required: true + - type: textarea + id: feature-motivation + attributes: + label: Motivation + description: A clear and concise description of what the problem is, e.g., I'm always frustrated when [...] + placeholder: Why do you need this feature? + validations: + required: true + - type: textarea + id: feature-context + attributes: + label: Additional context + description: Add any other context or screenshots about the feature request here. + placeholder: Screenshots, code snippets, etc. + + diff --git a/.github/ISSUE_TEMPLATE/02_bug_report.yml b/.github/ISSUE_TEMPLATE/02_bug_report.yml new file mode 100644 index 0000000..8c43f4e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/02_bug_report.yml @@ -0,0 +1,31 @@ +name: 🐞 Bug Report +description: File a bug report +title: "bug: " +labels: ["bug"] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! + - type: textarea + id: what-happened + attributes: + label: What happened? + description: Also tell us, what did you expect to happen? + placeholder: Tell us what you see! + value: "A bug happened!" + validations: + required: true + - type: textarea + id : how-to-reproduce + attributes: + label: How can we reproduce it? + description: Please provide a code snippet to reproduce the bug. + placeholder: import dbally + render: python + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. 
+ render: shell \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..3d179b8 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,55 @@ +name: Semantic Release + +on: + workflow_dispatch: + inputs: + releaseType: + description: "version update type" + required: true + type: choice + default: "automatic" + options: + - "automatic" + - "major" + - "minor" + - "patch" + +jobs: + release: + runs-on: ubuntu-latest + concurrency: release + permissions: + id-token: write + contents: write + + environment: + name: pypi + url: https://test.pypi.org/project/ds-splat + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Python Semantic Release Manual + id: release_manual + if: ${{ github.event.inputs.releaseType != 'automatic' }} + uses: python-semantic-release/python-semantic-release@master + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + force: ${{ github.event.inputs.releaseType }} + changelog: false + + - name: Python Semantic Release Automatic + id: release_automatic + if: ${{ github.event.inputs.releaseType == 'automatic' }} + uses: python-semantic-release/python-semantic-release@master + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + changelog: false + + + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + if: steps.release_manual.outputs.released == 'true' || steps.release_automatic.outputs.released == 'true' diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..969cabb --- /dev/null +++ b/.gitignore @@ -0,0 +1,94 @@ +# Directories +.vscode/ +.idea/ +.neptune/ +.pytest_cache/ +.mypy_cache/ +venv/ +__pycache__/ +**.egg-info/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ 
+.installed.cfg +*.egg + +# Sphinx documentation +docs/_build/ +public/ +# autogenerated package license table +docs/licenses_table.rst + +# license dump file +licenses.txt + +# File formats +*.onnx +*.pyc +*.pt +*.pth +*.pkl +*.mar +*.torchscript +**/.ipynb_checkpoints +**/dist/ +**/checkpoints/ +**/outputs/ + +# Other env files +.python-version +pyvenv.cfg +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.hypothesis/ + +# dotenv +.env +src/dbally_benchmark/.env + +# coverage and pytest reports +coverage.xml +report.xml + +# CMake +cmake-build-*/ + +# Terraform +**/.terraform.lock.hcl +**/.terraform + +# experiments results +experiments/ + +# mkdocs generated files +site/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..541a8fb --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "cpp/gsplat/third_party/glm"] + path = cpp/gsplat/third_party/glm + url = https://github.com/g-truc/glm.git diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..f0ec85e --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,42 @@ +cmake_minimum_required(VERSION 3.18) +project(cuda_rasterizer LANGUAGES CXX CUDA) + +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) +option(WITH_TORCH "Builds also rasterizer compatible with torch tensors" ON) + +find_package(Thrust REQUIRED) + +add_library(ds_cuda_rasterizer + cpp/rasterizer_cuda.cu + cpp/rasterizer_kernels.cu) + +target_include_directories(ds_cuda_rasterizer PRIVATE ${TORCH_INCLUDE_DIRS}) +target_include_directories(ds_cuda_rasterizer PRIVATE ${CUDA_INCLUDE_DIRS}) +target_include_directories(ds_cuda_rasterizer PUBLIC cpp) + +target_link_libraries(ds_cuda_rasterizer ${CUDA_LIBRARIES} thrust::thrust) + +set_target_properties(ds_cuda_rasterizer PROPERTIES + CUDA_ARCHITECTURES "native") + +if(WITH_TORCH) + find_package(Torch REQUIRED) + find_package(Python3 COMPONENTS Interpreter Development 
REQUIRED) + + target_sources(ds_cuda_rasterizer PRIVATE + cpp/rasterizer_torch.cu + cpp/gsplat/backward.cu + cpp/gsplat/bindings.cu) + + target_link_libraries(ds_cuda_rasterizer ${TORCH_LIBRARIES} Python3::Python) + + if(BUILD_TESTING AND WITH_TORCH) + enable_testing() + find_package(GTest REQUIRED) + include(CTest) + + add_executable(rasterizer_test tests/rasterizer_test.cpp) + target_link_libraries(rasterizer_test ds_cuda_rasterizer gtest::gtest ${TORCH_LIBRARIES} ${CUDA_LIBRARIES}) + endif() +endif() diff --git a/LICENSE b/LICENSE index 830b3e6..7552000 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,21 @@ -MIT License +This project adapts and modifies code for backward CUDA kernels from Nerfstudio's project gsplat (modified) -Copyright (c) 2024 deepsense.ai +https://github.com/nerfstudio-project/gsplat -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +Nerfstudio gsplat is available under the Apache 2.0 License, which you can find at http://www.apache.org/licenses/LICENSE-2.0 -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. +The rest of the project, including forward CUDA calls, is the original work of the deepsense.ai team. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +Copyright (c) 2024 deepsense.ai sp. z o.o. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..5f4d1cb --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,4 @@ +recursive-include cpp * +include cpp/*.cpp +include cpp/*.cu +include cpp/gsplat/*.cu diff --git a/README.md b/README.md index b5969aa..5184ee1 100644 --- a/README.md +++ b/README.md @@ -1 +1,215 @@ -# ds-splat \ No newline at end of file +# 🖌️ deepsense.ai 3D Gaussian Splatting + +Fastest open-source implementation (Apache 2.0 License) of 3D Gaussian Splatting rasterizer function as forward/backward +CUDA kernels. Forward call is our original work and our backward code is based on +[nerfstudio's gsplat](https://github.com/nerfstudio-project/gsplat) implementation. +We are using the same api as Vanilla [graphdeco-inria 3D Gaussian Splatting implementation](https://github.com/graphdeco-inria/gaussian-splatting), +so it is very easy to replace original render calls simply by swapping the import. 
+ + +![Training Process](assets/training_640_5fps.gif) + +## Table of Contents +- [⚡ Get fastest open-source forward/backward kernels](#-get-fastest-open-source-forwardbackward-kernels) + - [📦 Get From PyPI](#-get-from-pypi) + - [💡 Integrated into Gaussian Splatting Lighting](#-integrated-into-gaussian-splatting-lighting) +- [🔧 Install from repository](#-install-from-repository) + - [🐍 Using Python Extension](#-using-python-extension) + - [🛠 CPP Project Initialization](#-cpp-project-initialization) +- [🔄 How to switch to open-source KNN](#-how-to-switch-to-open-source-knn) +- [📊 Benchmarks](#-benchmarks) + +## ⚡ Get fastest open-source forward/backward kernels + +Fastest open-source and easy to use replacement for those who are using the non-commercially licensed Vanilla +[graphdeco-inria 3D Gaussian Splatting implementation](https://github.com/graphdeco-inria/gaussian-splatting). + +- Forward and backward CUDA calls +- Fastest open-source +- Easy to integrate +- Thrust and Torch I/O API + +### 📦 Get From PyPI + +Follow this step if you are already using Vanilla's [graphdeco-inria 3D Gaussian Splatting implementation](https://github.com/graphdeco-inria/gaussian-splatting) +in your project and you want to replace forward/backward kernels +with deepsense.ai open-source kernels. + +In your environment, simply install: +```bash +pip install ds-splat +``` + +You are good to go just by swapping imports: +```diff +- from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer ++ from ds_splat import GaussianRasterizationSettings, GaussianRasterizer +``` + +After swapping to our code, you will keep 3D Gaussian Splatting functionality (backward and forward passes) and you +will use open-source code. If you also want to use open-source code for the KNN step in preprocessing, scroll down! 
+ +### 💡 Integrated into Gaussian Splatting Lighting + +If you are starting a project from scratch and are interested in an end-to-end environment, we recommend checking +our integration into the [gaussian-splatting-lightning](https://github.com/yzslab/gaussian-splatting-lightning) repository. +The gaussian-splatting-lightning repository is under the MIT License, but submodules like Vanilla's forward/backward kernels or +the KNN implementation have non-commercial licenses. You can use deepsense ds-splat as a backend and this way use the +fastest open-source forward/backward kernel calls. + + +## 🔧 Install from repository + +Instead of installing from PyPI, you can install the ds-splat package directly from this repository. + +### 🐍 Using Python Extension + +You can use pip install in the project's root directory: +```bash +pip install . +``` + +Via setup.py, this will compile the CUDA and C++ code and install the ds-splat package. + +### 🛠 CPP Project Initialization + +This is a bit more manual, and you don't have to do it if you installed from PyPI or with the above pip install. + +If you prefer to build the project from scratch, follow the instructions here. + +This project uses Conan for additional dependencies (e.g. GoogleTest and Thrust). To generate the CMake project, follow these instructions: + +```bash +cd cuda_rasterizer # make sure you are in the root directory +conan install . -of=conan/x86_reldebug --settings=build_type=RelWithDebInfo --profile=default +mkdir build_cpp; cd build_cpp +cmake -DCMAKE_PREFIX_PATH=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DCMAKE_TOOLCHAIN_FILE=../conan/x86_reldebug/build/RelWithDebInfo/generators/conan_toolchain.cmake -DBUILD_TESTING=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo .. +make +``` + +If there are any problems regarding runtime exceptions (e.g. `std::bad_alloc`) or link errors, make sure to edit your Conan profile to use a specific ABI. 
+The following Conan profile was tested: + + +``` +[settings] +arch=x86_64 +build_type=Release +compiler=gcc +compiler.cppstd=17 +compiler.libcxx=libstdc++ +compiler.version=11 +os=Linux +``` + +## 🔄 How to switch to open-source KNN + +If you are using, e.g., the [gaussian-splatting-lightning](https://github.com/yzslab/gaussian-splatting-lightning) repository, +then forward/backward CUDA kernels and KNN are under [Gaussian-Splatting License](https://github.com/graphdeco-inria/gaussian-splatting?tab=License-1-ov-file#readme). +When you switch to our code following the instructions above, you will use our open-source forward and backward calls. +Here, we provide instructions on how to also use an open-source KNN implementation via Faiss. These instructions are for +replacing the KNN implementation in the [gaussian-splatting-lightning](https://github.com/yzslab/gaussian-splatting-lightning) +repository. + +#### Install Faiss +https://github.com/facebookresearch/faiss/blob/main/INSTALL.md +For example, if you are using conda, in your environment install: +```bash +conda install -c pytorch -c nvidia -c rapidsai -c conda-forge faiss-gpu-raft=1.8.0 +``` + +#### Modify GaussianModel class + +1. Locate the gaussian_model.py file that contains the GaussianModel class +2. import faiss + ```python + import faiss + ``` +3. add method for averaged distances + ```python + def _get_averaged_distances(self, pcd_points_np: np.ndarray, method: str = "CPU_approx", device_id: int = 0, + k: int = 4, dim: int = 3, nlist: int = 200) -> np.ndarray: + """ + This method takes numpy array of points and returns averaged distances for k-nearest neighbours + for each query point (excluding query point). Database/reference points and query points are same set. + + Using Faiss as a backend. + + + Args: + pcd_points_np: pcd points as numpy array + method: how faiss create indices and what is target device for calc. 
{"CPU", "GPU", "CPU_approx", + "GPU_approx"} + device_id: GPU device id + k: k-nearest neighbours (including self) + dim: dimentionality of the dataset. 3 by default. + nlist: the number of clusters or cells in the inverted file (IVF) structure when using an IndexIVFFlat + index. Only relevant for approximated methods. + + Returns: + numpy array as mean from k-nearest neighbour (except self) for each query point + """ + valid_index_types = {"CPU", "GPU", "CPU_approx", "GPU_approx"} + pcd_points_float_32 = pcd_points_np.astype(np.float32) + + if method == "CPU": + index = faiss.IndexFlatL2(dim) + elif method == "GPU": + res = faiss.StandardGpuResources() + index = faiss.GpuIndexFlatL2(res, dim) + elif method == "CPU_approx": + quantizer = faiss.IndexFlatL2(3) # the other index + index = faiss.IndexIVFFlat(quantizer, dim, nlist) + elif method == "GPU_approx": + res = faiss.StandardGpuResources() + quantizer = faiss.IndexFlatL2(3) # the other index. Must be CPU as nested GPU indexes are not supported + index = faiss.index_cpu_to_gpu(res, device_id, faiss.IndexIVFFlat(quantizer, dim, nlist)) + else: + raise ValueError(f"Invalid index_type. Expected one of {valid_index_types}, but got {method}.") + + if method in {"CPU_approx", "GPU_approx"}: + index.train(pcd_points_float_32) + index.add(pcd_points_float_32) + + D, _ = index.search(pcd_points_float_32, k) + D_mean = np.mean(D[:, 1:], axis=1) + + return D_mean + ``` +4. localize create_from_pcd(...) method and modify it. + Replace lines: + ```diff + - dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001).to(deivce) + + dist_means_np = self._get_averaged_distances(pcd_points_np=pcd_points_np, method="CPU_approx") + + dist2 = torch.clamp_min(torch.tensor(dist_means_np), 0.0000001).to(deivce) + ``` + +This way you have modified the KNN method. 
Now it is independent from a +[licensed](https://github.com/graphdeco-inria/gaussian-splatting?tab=License-1-ov-file#readme) +submodule (distCUDA2 method) and now it is open source! + +## 📊 Benchmarks + +We have conducted a series of benchmarks, comparing deepsense implementation inference runtime to vanilla implementation +[graphdeco-inria 3D gaussian splatting implementation](https://github.com/graphdeco-inria/gaussian-splatting) +and to [nerfstudio's gsplat](https://github.com/nerfstudio-project/gsplat) implementation. + + +Below plots present inference time in ms measured for 120 frames as fly through a scene with zooming out to capture all +Gaussians. 6.1M Gaussians rendered in 1920x1080 with an NVIDIA 4070 Laptop GPU and 5.8M Gaussians rendered in 3840x2160 +with an NVIDIA 3090 GPU. + +![Inference time in ms. measured for 120 frames as fly through a scene with zooming out to capture all gaussians. 6.1M Gaussians rendered in 1920x1080 with NVIDIA 4070 Laptop GPU.](assets/bicycle_1920x1080_4070.png) +![Inference time in ms. measured for 120 frames as fly through a scene with zooming out to capture all gaussians. 5.8M Gaussians rendered in 3840x2160 with NVIDIA 3090 GPU.](assets/garden_3840x2160_3090.png) + + +For trained scenes, we have also compared PSNR (Peak Signal-to-Noise Ratio) for deepsense and gsplat methods to +Vanilla as ground truth. Using Vanilla's inria implementation, we rendered images when flying through a scene, +treating them as ground truth. For deepsense and gsplat implementations, we rendered scenes from the same camera +positions and compared them to Vanilla. This test shows how close our/gsplat implementation is to Vanilla's. Some +details are implementation-specific and result in slightly different outcomes, but both methods have very good PSNR in +this regard. Higher PSNR is better. + +📥 [Download](https://drive.google.com/file/d/1MADQzb6onTV6JBJQqTx8H9bN6It4Vcdj/view?usp=sharing) more benchmark plots from GDrive. 
+ +![deepsense/gsplat PSNR to Vanilla. Bicycle Scene.](assets/psnr_to_vanilla/bicycle_1280x720_psnr.png) \ No newline at end of file diff --git a/__version__.py b/__version__.py new file mode 100644 index 0000000..67510db --- /dev/null +++ b/__version__.py @@ -0,0 +1,3 @@ +"""Version information.""" + +__version__ = "0.0.0" diff --git a/assets/bicycle_1920x1080_4070.png b/assets/bicycle_1920x1080_4070.png new file mode 100644 index 0000000..5406121 Binary files /dev/null and b/assets/bicycle_1920x1080_4070.png differ diff --git a/assets/garden_3840x2160_3090.png b/assets/garden_3840x2160_3090.png new file mode 100644 index 0000000..682021f Binary files /dev/null and b/assets/garden_3840x2160_3090.png differ diff --git a/assets/psnr_to_vanilla/bicycle_1280x720_psnr.png b/assets/psnr_to_vanilla/bicycle_1280x720_psnr.png new file mode 100644 index 0000000..59e7872 Binary files /dev/null and b/assets/psnr_to_vanilla/bicycle_1280x720_psnr.png differ diff --git a/assets/training_640_5fps.gif b/assets/training_640_5fps.gif new file mode 100644 index 0000000..9370c2a Binary files /dev/null and b/assets/training_640_5fps.gif differ diff --git a/conanfile.py b/conanfile.py new file mode 100644 index 0000000..c6eded4 --- /dev/null +++ b/conanfile.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 + +from conan import ConanFile +from conan.tools.cmake import CMakeToolchain, CMake, cmake_layout, CMakeDeps + + +class DsCudaRasterizerRecipe(ConanFile): + name = "ds_cuda_rasterizer" + version = "1.0" + package_type = "library" + + # Binary configuration + settings = "os", "compiler", "build_type", "arch" + options = {"shared": [True, False], "fPIC": [True, False]} + default_options = {"shared": False, "fPIC": True} + + # Sources are located in the same place as this recipe, copy them to the recipe + # exports_sources = "CMakeLists.txt", "src/*", "include/*", "tests/*" + + def requirements(self): + self.requires("gtest/1.14.0") + self.requires("thrust/1.17.2") + + def 
config_options(self): + if self.settings.os == "Windows": + self.options.rm_safe("fPIC") + + def configure(self): + if self.options.shared: + self.options.rm_safe("fPIC") + self.options["thrust/*"].device_system = "cuda" + + def layout(self): + cmake_layout(self) + + def generate(self): + deps = CMakeDeps(self) + deps.generate() + tc = CMakeToolchain(self) + tc.generate() + + def build(self): + cmake = CMake(self) + cmake.configure(variables={ + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON" + }) + cmake.build() + + def package(self): + cmake = CMake(self) + cmake.install() + + def package_info(self): + self.cpp_info.libs = ["ds_cuda_rasterizer"] diff --git a/cpp/ds_cuda_rasterizer/rasterizer_cuda.hpp b/cpp/ds_cuda_rasterizer/rasterizer_cuda.hpp new file mode 100644 index 0000000..4ef9b8d --- /dev/null +++ b/cpp/ds_cuda_rasterizer/rasterizer_cuda.hpp @@ -0,0 +1,11 @@ +#ifndef RASTERIZER_CUDA_HPP_ +#define RASTERIZER_CUDA_HPP_ + +void rasterizer_forward_core_deepsense(const float* means_3d, const float* shs, const float* opacities, + const float* scales, const float* rotations, int num_of_gaussians, + const float* view_matrix, const float* proj_matrix, const float* camera_position, + int image_width, int image_height, float tan_fovx, float tan_fovy, + float scale_modifier, const int max_sh_degree, const int sh_degree, + float* output_image); + +#endif // RASTERIZER_CUDA_HPP_ diff --git a/cpp/ds_cuda_rasterizer/rasterizer_torch.hpp b/cpp/ds_cuda_rasterizer/rasterizer_torch.hpp new file mode 100644 index 0000000..30701f7 --- /dev/null +++ b/cpp/ds_cuda_rasterizer/rasterizer_torch.hpp @@ -0,0 +1,37 @@ +#ifndef DS_RASTERIZER_HPP_ +#define DS_RASTERIZER_HPP_ + +#include +#include + +struct TorchRasterizationSettings +{ + int image_height; + int image_width; + float tanfovx; + float tanfovy; + torch::Tensor bg; + float scale_modifier; + torch::Tensor view_matrix; + torch::Tensor proj_matrix; + int sh_degree; + int max_sh_degree; + torch::Tensor campos; + bool prefiltered; + bool 
debug; +}; + +std::vector<torch::Tensor> rasterizer_forward_deepsense(torch::Tensor means3D, torch::Tensor means2D, torch::Tensor shs, + torch::Tensor colors_precomp, torch::Tensor opacities, + torch::Tensor scales, torch::Tensor rotations, + torch::Tensor cov3D_precomp, + const TorchRasterizationSettings& settings); + +std::vector<torch::Tensor> rasterizer_backward_deepsense( + torch::Tensor means_3d, torch::Tensor means2d, torch::Tensor shs, torch::Tensor colors_precomp, + torch::Tensor opacities, torch::Tensor scales, torch::Tensor rotations, torch::Tensor cov3D_precomp, + torch::Tensor radii, torch::Tensor gaussians, torch::Tensor colors_clamped, torch::Tensor final_Ts, + torch::Tensor final_index, torch::Tensor cov3D, torch::Tensor indices, torch::Tensor tile_indices, + torch::Tensor grad_image, const TorchRasterizationSettings& settings); + +#endif // DS_RASTERIZER_HPP_ diff --git a/cpp/gsplat/backward.cu b/cpp/gsplat/backward.cu new file mode 100644 index 0000000..dd4ee24 --- /dev/null +++ b/cpp/gsplat/backward.cu @@ -0,0 +1,393 @@ +#include "backward.cuh" +#include "helpers.cuh" +#include +#include +#include +namespace cg = cooperative_groups; + +/* clang-format off */ + +inline __device__ void warpSum3(float3& val, cg::thread_block_tile<32>& tile){ /* in place: each component of val is replaced by cg::reduce(..., plus) over the tile */ + val.x = cg::reduce(tile, val.x, cg::plus<float>()); + val.y = cg::reduce(tile, val.y, cg::plus<float>()); + val.z = cg::reduce(tile, val.z, cg::plus<float>()); +} + +inline __device__ void warpSum2(float2& val, cg::thread_block_tile<32>& tile){ /* same as warpSum3, for the x and y components of a float2 */ + val.x = cg::reduce(tile, val.x, cg::plus<float>()); + val.y = cg::reduce(tile, val.y, cg::plus<float>()); +} + +inline __device__ void warpSum(float& val, cg::thread_block_tile<32>& tile){ /* same as warpSum3, for a single float */ + val = cg::reduce(tile, val, cg::plus<float>()); +} + +__global__ void rasterize_backward_kernel( + const dim3 tile_bounds, + const dim3 img_size, + const int32_t* __restrict__ gaussian_ids_sorted, + const int2* __restrict__ tile_bins, + const Gaussian* __restrict__ gaussians, + const float* __restrict__ opacities, + const float3& __restrict__ 
background, + const float* __restrict__ final_Ts, + const int* __restrict__ final_index, + const float3* __restrict__ v_output, + const float* __restrict__ v_output_alpha, + float2* __restrict__ v_xy, + float3* __restrict__ v_conic, + float3* __restrict__ v_rgb, + float* __restrict__ v_opacity +) { + auto block = cg::this_thread_block(); + int32_t tile_id = + block.group_index().y * tile_bounds.x + block.group_index().x; + unsigned i = + block.group_index().y * block.group_dim().y + block.thread_index().y; + unsigned j = + block.group_index().x * block.group_dim().x + block.thread_index().x; + + const float px = (float)j + 0.5; + const float py = (float)i + 0.5; + // clamp this value to the last pixel + const int32_t pix_id = min(i * img_size.x + j, img_size.x * img_size.y - 1); + + // keep not rasterizing threads around for reading data + const bool inside = (i < img_size.y && j < img_size.x); + + // this is the T AFTER the last gaussian in this pixel + float T_final = final_Ts[pix_id]; + float T = T_final; + // the contribution from gaussians behind the current one + float3 buffer = {0.f, 0.f, 0.f}; + // index of last gaussian to contribute to this pixel + const int bin_final = inside? 
final_index[pix_id] : 0; + + // have all threads in tile process the same gaussians in batches + // first collect gaussians between range.x and range.y in batches + // which gaussians to look through in this tile + const int2 range = tile_bins[tile_id]; + const int block_size = block.size(); + const int num_batches = (range.y - range.x + block_size - 1) / block_size; + + __shared__ int32_t id_batch[MAX_BLOCK_SIZE]; + __shared__ float3 xy_opacity_batch[MAX_BLOCK_SIZE]; + __shared__ float3 conic_batch[MAX_BLOCK_SIZE]; + __shared__ float3 rgbs_batch[MAX_BLOCK_SIZE]; + + // df/d_out for this pixel + const float3 v_out = v_output[pix_id]; + const float v_out_alpha = v_output_alpha[pix_id]; + + // collect and process batches of gaussians + // each thread loads one gaussian at a time before rasterizing + const int tr = block.thread_rank(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + const int warp_bin_final = cg::reduce(warp, bin_final, cg::greater<int>()); + for (int b = 0; b < num_batches; ++b) { + // resync all threads before writing next batch of shared mem + block.sync(); + + // each thread fetch 1 gaussian from back to front + // 0 index will be furthest back in batch + // index of gaussian to load + // batch end is the index of the last gaussian in the batch + const int batch_end = range.y - 1 - block_size * b; + int batch_size = min(block_size, batch_end + 1 - range.x); + const int idx = batch_end - tr; + if (idx >= range.x) { + int32_t g_id = gaussian_ids_sorted[idx]; + id_batch[tr] = g_id; + const auto& gaussian = gaussians[g_id]; + const float2 xy = {gaussian.means_2d.x, gaussian.means_2d.y}; + const float opac = opacities[g_id]; + xy_opacity_batch[tr] = {xy.x, xy.y, opac}; + conic_batch[tr] = gaussian.conic; + rgbs_batch[tr] = gaussian.color; + } + // wait for other threads to collect the gaussians in batch + block.sync(); + // process gaussians in the current batch for this pixel + // 0 index is the furthest back gaussian in the batch + 
for (int t = max(0,batch_end - warp_bin_final); t < batch_size; ++t) { + int valid = inside; + if (batch_end - t > bin_final) { + valid = 0; + } + float alpha; + float opac; + float2 delta; + float3 conic; + float vis; + if(valid){ + conic = conic_batch[t]; + float3 xy_opac = xy_opacity_batch[t]; + opac = xy_opac.z; + delta = {xy_opac.x - px, xy_opac.y - py}; + float sigma = 0.5f * (conic.x * delta.x * delta.x + + conic.z * delta.y * delta.y) + + conic.y * delta.x * delta.y; + vis = __expf(-sigma); + alpha = min(0.99f, opac * vis); + if (sigma < 0.f || alpha < 1.f / 255.f) { + valid = 0; + } + } + // if all threads are inactive in this warp, skip this loop + if(!warp.any(valid)){ + continue; + } + float3 v_rgb_local = {0.f, 0.f, 0.f}; + float3 v_conic_local = {0.f, 0.f, 0.f}; + float2 v_xy_local = {0.f, 0.f}; + float v_opacity_local = 0.f; + //initialize everything to 0, only set if the lane is valid + if(valid){ + // compute the current T for this gaussian + float ra = 1.f / (1.f - alpha); + T *= ra; + // update v_rgb for this gaussian + const float fac = alpha * T; + float v_alpha = 0.f; + v_rgb_local = {fac * v_out.x, fac * v_out.y, fac * v_out.z}; + + const float3 rgb = rgbs_batch[t]; + // contribution from this pixel + v_alpha += (rgb.x * T - buffer.x * ra) * v_out.x; + v_alpha += (rgb.y * T - buffer.y * ra) * v_out.y; + v_alpha += (rgb.z * T - buffer.z * ra) * v_out.z; + + v_alpha += T_final * ra * v_out_alpha; + // contribution from background pixel + v_alpha += -T_final * ra * background.x * v_out.x; + v_alpha += -T_final * ra * background.y * v_out.y; + v_alpha += -T_final * ra * background.z * v_out.z; + // update the running sum + buffer.x += rgb.x * fac; + buffer.y += rgb.y * fac; + buffer.z += rgb.z * fac; + + const float v_sigma = -opac * vis * v_alpha; + v_conic_local = {0.5f * v_sigma * delta.x * delta.x, + v_sigma * delta.x * delta.y, + 0.5f * v_sigma * delta.y * delta.y}; + v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y), + v_sigma 
* (conic.y * delta.x + conic.z * delta.y)}; + v_opacity_local = vis * v_alpha; + } + warpSum3(v_rgb_local, warp); + warpSum3(v_conic_local, warp); + warpSum2(v_xy_local, warp); + warpSum(v_opacity_local, warp); + if (warp.thread_rank() == 0) { + int32_t g = id_batch[t]; + float* v_rgb_ptr = (float*)(v_rgb); + atomicAdd(v_rgb_ptr + 3*g + 0, v_rgb_local.x); + atomicAdd(v_rgb_ptr + 3*g + 1, v_rgb_local.y); + atomicAdd(v_rgb_ptr + 3*g + 2, v_rgb_local.z); + + float* v_conic_ptr = (float*)(v_conic); + atomicAdd(v_conic_ptr + 3*g + 0, v_conic_local.x); + atomicAdd(v_conic_ptr + 3*g + 1, v_conic_local.y); + atomicAdd(v_conic_ptr + 3*g + 2, v_conic_local.z); + + float* v_xy_ptr = (float*)(v_xy); + atomicAdd(v_xy_ptr + 2*g + 0, v_xy_local.x); + atomicAdd(v_xy_ptr + 2*g + 1, v_xy_local.y); + + atomicAdd(v_opacity + g, v_opacity_local); + } + } + } +} + + +__device__ float calculate_compensation(const float3& conic) +{ + const float3 conic_no_add = {conic.x - 0.3f, conic.y, conic.z - 0.3f}; + const auto det_a = conic.x * conic.z - conic.y * conic.y; + const auto det_b = conic_no_add.x * conic_no_add.z - conic_no_add.y * conic_no_add.y; + const auto result = det_a / det_b; + + return sqrtf(fmaxf(result, 0.0f)); +} + + +__global__ void project_gaussians_backward_kernel( + const int num_points, + const float3* __restrict__ means3d, + const float3* __restrict__ scales, + const float glob_scale, + const float4* __restrict__ quats, + const float* __restrict__ viewmat, + const float4 intrins, + const dim3 img_size, + const Gaussian* __restrict__ gaussians, + const float* __restrict__ cov3d, + const int* __restrict__ radii, + const float2* __restrict__ v_xy, + const float* __restrict__ v_depth, + const float3* __restrict__ v_conic, + float3* __restrict__ v_cov2d, + float* __restrict__ v_cov3d, + float3* __restrict__ v_mean3d, + float3* __restrict__ v_scale, + float4* __restrict__ v_quat +) +{ + unsigned idx = cg::this_grid().thread_rank(); // idx of thread within grid + if (idx >= 
num_points || radii[idx] <= 0) { + return; + } + + + float3 p_world = means3d[idx]; + float fx = intrins.x; + float fy = intrins.y; + float3 p_view = transform_4x3(viewmat, p_world); + // get v_mean3d from v_xy + v_mean3d[idx] = transform_4x3_rot_only_transposed( + viewmat, + project_pix_vjp({fx, fy}, p_view, v_xy[idx])); + + // get z gradient contribution to mean3d gradient + // z = viemwat[8] * mean3d.x + viewmat[9] * mean3d.y + viewmat[10] * + // mean3d.z + viewmat[11] + float v_z = v_depth[idx]; + v_mean3d[idx].x += viewmat[8] * v_z; + v_mean3d[idx].y += viewmat[9] * v_z; + v_mean3d[idx].z += viewmat[10] * v_z; + + const auto current_conic = gaussians[idx].conic; + const auto current_v_compensation = 0.0f; + const auto current_compensation = calculate_compensation(current_conic); + // get v_cov2d + cov2d_to_conic_vjp(current_conic, v_conic[idx], v_cov2d[idx]); + cov2d_to_compensation_vjp(current_compensation, current_conic, current_v_compensation, v_cov2d[idx]); + // get v_cov3d (and v_mean3d contribution) + project_cov3d_ewa_vjp( + p_world, + &(cov3d[6 * idx]), + viewmat, + fx, + fy, + v_cov2d[idx], + v_mean3d[idx], + &(v_cov3d[6 * idx]) + ); + // get v_scale and v_quat + scale_rot_to_cov3d_vjp( + scales[idx], + glob_scale, + quats[idx], + &(v_cov3d[6 * idx]), + v_scale[idx], + v_quat[idx] + ); + +} + +// output space: 2D covariance, input space: cov3d +__device__ void project_cov3d_ewa_vjp( + const float3& __restrict__ mean3d, + const float* __restrict__ cov3d, + const float* __restrict__ viewmat, + const float fx, + const float fy, + const float3& __restrict__ v_cov2d, + float3& __restrict__ v_mean3d, + float* __restrict__ v_cov3d +) { + // viewmat is row major, glm is column major + // upper 3x3 submatrix + // clang-format off + glm::mat3 W = glm::mat3( + viewmat[0], viewmat[4], viewmat[8], + viewmat[1], viewmat[5], viewmat[9], + viewmat[2], viewmat[6], viewmat[10] + ); + // clang-format on + glm::vec3 p = glm::vec3(viewmat[3], viewmat[7], viewmat[11]); + 
glm::vec3 t = W * glm::vec3(mean3d.x, mean3d.y, mean3d.z) + p; + float rz = 1.f / t.z; + float rz2 = rz * rz; + + // column major + // we only care about the top 2x2 submatrix + // clang-format off + glm::mat3 J = glm::mat3( + fx * rz, 0.f, 0.f, + 0.f, fy * rz, 0.f, + -fx * t.x * rz2, -fy * t.y * rz2, 0.f + ); + glm::mat3 V = glm::mat3( + cov3d[0], cov3d[1], cov3d[2], + cov3d[1], cov3d[3], cov3d[4], + cov3d[2], cov3d[4], cov3d[5] + ); + // cov = T * V * Tt; G = df/dcov = v_cov + // -> d/dV = Tt * G * T + // -> df/dT = G * T * Vt + Gt * T * V + glm::mat3 v_cov = glm::mat3( + v_cov2d.x, 0.5f * v_cov2d.y, 0.f, + 0.5f * v_cov2d.y, v_cov2d.z, 0.f, + 0.f, 0.f, 0.f + ); + // clang-format on + + glm::mat3 T = J * W; + glm::mat3 Tt = glm::transpose(T); + glm::mat3 Vt = glm::transpose(V); + glm::mat3 v_V = Tt * v_cov * T; + glm::mat3 v_T = v_cov * T * Vt + glm::transpose(v_cov) * T * V; + + // vjp of cov3d parameters + // v_cov3d_i = v_V : dV/d_cov3d_i + // where : is frobenius inner product + v_cov3d[0] = v_V[0][0]; + v_cov3d[1] = v_V[0][1] + v_V[1][0]; + v_cov3d[2] = v_V[0][2] + v_V[2][0]; + v_cov3d[3] = v_V[1][1]; + v_cov3d[4] = v_V[1][2] + v_V[2][1]; + v_cov3d[5] = v_V[2][2]; + + // compute df/d_mean3d + // T = J * W + glm::mat3 v_J = v_T * glm::transpose(W); + float rz3 = rz2 * rz; + glm::vec3 v_t = glm::vec3(-fx * rz2 * v_J[2][0], -fy * rz2 * v_J[2][1], + -fx * rz2 * v_J[0][0] + 2.f * fx * t.x * rz3 * v_J[2][0] - fy * rz2 * v_J[1][1] + + 2.f * fy * t.y * rz3 * v_J[2][1]); + // printf("v_t %.2f %.2f %.2f\n", v_t[0], v_t[1], v_t[2]); + // printf("W %.2f %.2f %.2f\n", W[0][0], W[0][1], W[0][2]); + v_mean3d.x += (float)glm::dot(v_t, W[0]); + v_mean3d.y += (float)glm::dot(v_t, W[1]); + v_mean3d.z += (float)glm::dot(v_t, W[2]); +} + +// given cotangent v in output space (e.g. 
d_L/d_cov3d) in R(6) +// compute vJp for scale and rotation +__device__ void scale_rot_to_cov3d_vjp(const float3 scale, const float glob_scale, const float4 quat, + const float* __restrict__ v_cov3d, float3& __restrict__ v_scale, + float4& __restrict__ v_quat) +{ + // cov3d is upper triangular elements of matrix + // off-diagonal elements count grads from both ij and ji elements, + // must halve when expanding back into symmetric matrix + glm::mat3 v_V = glm::mat3(v_cov3d[0], 0.5 * v_cov3d[1], 0.5 * v_cov3d[2], 0.5 * v_cov3d[1], v_cov3d[3], + 0.5 * v_cov3d[4], 0.5 * v_cov3d[2], 0.5 * v_cov3d[4], v_cov3d[5]); + glm::mat3 R = quat_to_rotmat(quat); + glm::mat3 S = scale_to_mat(scale, glob_scale); + glm::mat3 M = R * S; + // https://math.stackexchange.com/a/3850121 + // for D = W * X, G = df/dD + // df/dW = G * XT, df/dX = WT * G + glm::mat3 v_M = 2.f * v_V * M; + // glm::mat3 v_S = glm::transpose(R) * v_M; + v_scale.x = (float)glm::dot(R[0], v_M[0]) * glob_scale; + v_scale.y = (float)glm::dot(R[1], v_M[1]) * glob_scale; + v_scale.z = (float)glm::dot(R[2], v_M[2]) * glob_scale; + + glm::mat3 v_R = v_M * S; + v_quat = quat_to_rotmat_vjp(quat, v_R); +} diff --git a/cpp/gsplat/backward.cuh b/cpp/gsplat/backward.cuh new file mode 100644 index 0000000..6720819 --- /dev/null +++ b/cpp/gsplat/backward.cuh @@ -0,0 +1,69 @@ +#include +#include +#include + +#include "../rasterizer_kernels.cuh" + +// TODO +/* clang-format off */ + +__global__ void project_gaussians_backward_kernel( + const int num_points, + const float3* __restrict__ means3d, + const float3* __restrict__ scales, + const float glob_scale, + const float4* __restrict__ quats, + const float* __restrict__ viewmat, + const float4 intrins, + const dim3 img_size, + const Gaussian* __restrict__ gaussians, + const float* __restrict__ cov3d, + const int* __restrict__ radii, + const float2* __restrict__ v_xy, + const float* __restrict__ v_depth, + const float3* __restrict__ v_conic, + float3* __restrict__ v_cov2d, + float* 
__restrict__ v_cov3d, + float3* __restrict__ v_mean3d, + float3* __restrict__ v_scale, + float4* __restrict__ v_quat +); + + +__global__ void rasterize_backward_kernel( + const dim3 tile_bounds, + const dim3 img_size, + const int32_t* __restrict__ gaussian_ids_sorted, + const int2* __restrict__ tile_bins, + const Gaussian* __restrict__ gaussians, + const float* __restrict__ opacities, + const float3& __restrict__ background, + const float* __restrict__ final_Ts, + const int* __restrict__ final_index, + const float3* __restrict__ v_output, + const float* __restrict__ v_output_alpha, + float2* __restrict__ v_xy, + float3* __restrict__ v_conic, + float3* __restrict__ v_rgb, + float* __restrict__ v_opacity +); + +__device__ void project_cov3d_ewa_vjp( + const float3 &mean3d, + const float *cov3d, + const float *viewmat, + const float fx, + const float fy, + const float3 &v_cov2d, + float3 &v_mean3d, + float *v_cov3d +); + +__device__ void scale_rot_to_cov3d_vjp( + const float3 scale, + const float glob_scale, + const float4 quat, + const float *v_cov3d, + float3 &v_scale, + float4 &v_quat +); diff --git a/cpp/gsplat/bindings.cu b/cpp/gsplat/bindings.cu new file mode 100644 index 0000000..9d999fc --- /dev/null +++ b/cpp/gsplat/bindings.cu @@ -0,0 +1,194 @@ +/* clang-format off */ + +#include "backward.cuh" +#include "bindings.h" +#include "helpers.cuh" +#include "sh.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +namespace cg = cooperative_groups; + +torch::Tensor compute_sh_backward_tensor( + const unsigned num_points, + const unsigned degree, + const unsigned degrees_to_use, + torch::Tensor &viewdirs, + torch::Tensor &v_colors) { + DEVICE_GUARD(viewdirs); + if (viewdirs.ndimension() != 2 || viewdirs.size(0) != num_points || + viewdirs.size(1) != 3) { + AT_ERROR("viewdirs must have dimensions (N, 3)"); + } + if (v_colors.ndimension() != 2 || v_colors.size(0) != num_points || + v_colors.size(1) != 3) { + 
AT_ERROR("v_colors must have dimensions (N, 3)"); + } + unsigned num_bases = num_sh_bases(degree); + torch::Tensor v_coeffs = + torch::zeros({num_points, num_bases, 3}, v_colors.options()); + compute_sh_backward_kernel<<< + (num_points + N_THREADS - 1) / N_THREADS, + N_THREADS>>>( + num_points, + degree, + degrees_to_use, + (float3 *)viewdirs.contiguous().data_ptr(), + v_colors.contiguous().data_ptr(), + v_coeffs.contiguous().data_ptr() + ); + return v_coeffs; +} + + + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +project_gaussians_backward_tensor( + const int num_points, + torch::Tensor &means3d, + torch::Tensor &scales, + const float glob_scale, + torch::Tensor &quats, + torch::Tensor &viewmat, + const float fx, + const float fy, + const float cx, + const float cy, + const unsigned img_height, + const unsigned img_width, + torch::Tensor & gaussians, + torch::Tensor &cov3d, + torch::Tensor &radii, + torch::Tensor &compensation, + torch::Tensor &v_xy, + torch::Tensor &v_depth, + torch::Tensor &v_conic +) +{ + DEVICE_GUARD(means3d); + dim3 img_size_dim3; + img_size_dim3.x = img_width; + img_size_dim3.y = img_height; + + float4 intrins = {fx, fy, cx, cy}; + + const auto num_cov3d = num_points * 6; + + // Triangular covariance. 
+ torch::Tensor v_cov2d = + torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor v_cov3d = + torch::zeros({num_points, 6}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor v_mean3d = + torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor v_scale = + torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor v_quat = + torch::zeros({num_points, 4}, means3d.options().dtype(torch::kFloat32)); + + project_gaussians_backward_kernel<<< + (num_points + N_THREADS - 1) / N_THREADS, + N_THREADS>>>( + num_points, + (float3 *)means3d.contiguous().data_ptr(), + (float3 *)scales.contiguous().data_ptr(), + glob_scale, + (float4 *)quats.contiguous().data_ptr(), + viewmat.contiguous().data_ptr(), + intrins, + img_size_dim3, + reinterpret_cast(gaussians.data_ptr()), + cov3d.contiguous().data_ptr(), + radii.contiguous().data_ptr(), + (float2 *)v_xy.contiguous().data_ptr(), + v_depth.contiguous().data_ptr(), + (float3 *)v_conic.contiguous().data_ptr(), + // Outputs. 
+ (float3 *)v_cov2d.contiguous().data_ptr(), + v_cov3d.contiguous().data_ptr(), + (float3 *)v_mean3d.contiguous().data_ptr(), + (float3 *)v_scale.contiguous().data_ptr(), + (float4 *)v_quat.contiguous().data_ptr() + ); + + return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat); + +} + +std:: + tuple< + torch::Tensor, // dL_dxy + torch::Tensor, // dL_dconic + torch::Tensor, // dL_dcolors + torch::Tensor // dL_dopacity + > + rasterize_backward_tensor( + const unsigned img_height, + const unsigned img_width, + const unsigned block_width, + const torch::Tensor &gaussians_ids_sorted, + const torch::Tensor &tile_bins, + const torch::Tensor &gaussians, + const torch::Tensor &opacities, + const torch::Tensor &background, + const torch::Tensor &final_Ts, + const torch::Tensor &final_idx, + const torch::Tensor &v_output, // dL_dout_color + const torch::Tensor &v_output_alpha // dL_dout_alpha + ) { + + CHECK_INPUT(gaussians); + + const int num_points = opacities.size(0); + const dim3 tile_bounds = { + (img_width + block_width - 1) / block_width, + (img_height + block_width - 1) / block_width, + 1 + }; + const dim3 block(block_width, block_width, 1); + const dim3 img_size = {img_width, img_height, 1}; + const int channels = 3; //colors.size(1); + + torch::Tensor v_xy = torch::zeros({num_points, 2}, opacities.options()); + torch::Tensor v_conic = torch::zeros({num_points, 3}, opacities.options()); + torch::Tensor v_colors = + torch::zeros({num_points, channels}, opacities.options()); + torch::Tensor v_opacity = torch::zeros({num_points, 1}, opacities.options()); + + const Gaussian* gaussian_data_ptr = reinterpret_cast(gaussians.data_ptr()); + + rasterize_backward_kernel<<>>( + tile_bounds, + img_size, + gaussians_ids_sorted.contiguous().data_ptr(), + (int2 *)tile_bins.contiguous().data_ptr(), + gaussian_data_ptr, + opacities.contiguous().data_ptr(), + *(float3 *)background.contiguous().data_ptr(), + final_Ts.contiguous().data_ptr(), + 
final_idx.contiguous().data_ptr(), + (float3 *)v_output.contiguous().data_ptr(), + v_output_alpha.contiguous().data_ptr(), + (float2 *)v_xy.contiguous().data_ptr(), + (float3 *)v_conic.contiguous().data_ptr(), + (float3 *)v_colors.contiguous().data_ptr(), + v_opacity.contiguous().data_ptr() + ); + + return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); +} diff --git a/cpp/gsplat/bindings.h b/cpp/gsplat/bindings.h new file mode 100644 index 0000000..65147f2 --- /dev/null +++ b/cpp/gsplat/bindings.h @@ -0,0 +1,76 @@ +/* clang-format off */ + +#include "cuda_runtime.h" +#include +#include +#include +#include +#include +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) +#define DEVICE_GUARD(_ten) \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten)); + +torch::Tensor compute_sh_backward_tensor( + unsigned num_points, + unsigned degree, + unsigned degrees_to_use, + torch::Tensor &viewdirs, + torch::Tensor &v_colors +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +project_gaussians_backward_tensor( + const int num_points, + torch::Tensor &means3d, + torch::Tensor &scales, + const float glob_scale, + torch::Tensor &quats, + torch::Tensor &viewmat, + const float fx, + const float fy, + const float cx, + const float cy, + const unsigned img_height, + const unsigned img_width, + torch::Tensor &cov3d, + torch::Tensor &radii, + torch::Tensor &conics, + torch::Tensor &compensation, + torch::Tensor &v_xy, + torch::Tensor &v_depth, + torch::Tensor &v_conic +); + +std:: + tuple< + torch::Tensor, // dL_dxy + torch::Tensor, // dL_dconic + torch::Tensor, // dL_dcolors + torch::Tensor // dL_dopacity + > + rasterize_backward_tensor( + const unsigned img_height, + const unsigned img_width, + const unsigned block_width, + 
const torch::Tensor &gaussians_ids_sorted, + const torch::Tensor &tile_bins, + const torch::Tensor &gaussians, + const torch::Tensor &opacities, + const torch::Tensor &background, + const torch::Tensor &final_Ts, + const torch::Tensor &final_idx, + const torch::Tensor &v_output, // dL_dout_color + const torch::Tensor &v_output_alpha + ); diff --git a/cpp/gsplat/config.h b/cpp/gsplat/config.h new file mode 100644 index 0000000..ad73a02 --- /dev/null +++ b/cpp/gsplat/config.h @@ -0,0 +1,17 @@ +#define MAX_BLOCK_SIZE ( 16 * 16 ) +#define N_THREADS 256 + +#define MAX_REGISTER_CHANNELS 3 + +#define CUDA_CALL(x) \ + do { \ + if ((x) != cudaSuccess) { \ + printf( \ + "Error at %s:%d - %s\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(cudaGetLastError()) \ + ); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) diff --git a/cpp/gsplat/helpers.cuh b/cpp/gsplat/helpers.cuh new file mode 100644 index 0000000..4a8b357 --- /dev/null +++ b/cpp/gsplat/helpers.cuh @@ -0,0 +1,231 @@ +#include "config.h" +#include +#include "third_party/glm/glm/glm.hpp" +#include "third_party/glm/glm/gtc/type_ptr.hpp" +#include + +inline __device__ void get_bbox( + const float2 center, + const float2 dims, + const dim3 img_size, + uint2 &bb_min, + uint2 &bb_max +) { + // get bounding box with center and dims, within bounds + // bounding box coords returned in tile coords, inclusive min, exclusive max + // clamp between 0 and tile bounds + bb_min.x = min(max(0, (int)(center.x - dims.x)), img_size.x); + bb_max.x = min(max(0, (int)(center.x + dims.x + 1)), img_size.x); + bb_min.y = min(max(0, (int)(center.y - dims.y)), img_size.y); + bb_max.y = min(max(0, (int)(center.y + dims.y + 1)), img_size.y); +} + +inline __device__ void get_tile_bbox( + const float2 pix_center, + const float pix_radius, + const dim3 tile_bounds, + uint2 &tile_min, + uint2 &tile_max, + const int block_size +) { + // gets gaussian dimensions in tile space, i.e. 
the span of a gaussian in + // tile_grid (image divided into tiles) + float2 tile_center = { + pix_center.x / (float)block_size, pix_center.y / (float)block_size + }; + float2 tile_radius = { + pix_radius / (float)block_size, pix_radius / (float)block_size + }; + get_bbox(tile_center, tile_radius, tile_bounds, tile_min, tile_max); +} + +inline __device__ bool +compute_cov2d_bounds(const float3 cov2d, float3 &conic, float &radius) { + // find eigenvalues of 2d covariance matrix + // expects upper triangular values of cov matrix as float3 + // then compute the radius and conic dimensions + // the conic is the inverse cov2d matrix, represented here with upper + // triangular values. + float det = cov2d.x * cov2d.z - cov2d.y * cov2d.y; + if (det == 0.f) + return false; + float inv_det = 1.f / det; + + // inverse of 2x2 cov2d matrix + conic.x = cov2d.z * inv_det; + conic.y = -cov2d.y * inv_det; + conic.z = cov2d.x * inv_det; + + float b = 0.5f * (cov2d.x + cov2d.z); + float v1 = b + sqrt(max(0.1f, b * b - det)); + float v2 = b - sqrt(max(0.1f, b * b - det)); + // take 3 sigma of covariance + radius = ceil(3.f * sqrt(max(v1, v2))); + return true; +} + +// compute vjp from df/d_conic to df/c_cov2d +inline __device__ void cov2d_to_conic_vjp( + const float3 &conic, const float3 &v_conic, float3 &v_cov2d +) { + // conic = inverse cov2d + // df/d_cov2d = -conic * df/d_conic * conic + glm::mat2 X = glm::mat2(conic.x, conic.y, conic.y, conic.z); + glm::mat2 G = glm::mat2(v_conic.x, v_conic.y / 2.f, v_conic.y / 2.f, v_conic.z); + glm::mat2 v_Sigma = -X * G * X; + v_cov2d.x = v_Sigma[0][0]; + v_cov2d.y = v_Sigma[1][0] + v_Sigma[0][1]; + v_cov2d.z = v_Sigma[1][1]; +} + +inline __device__ void cov2d_to_compensation_vjp( + const float compensation, const float3 &conic, const float v_compensation, float3 &v_cov2d +) { + // comp = sqrt(det(cov2d - 0.3 I) / det(cov2d)) + // conic = inverse(cov2d) + // df / d_cov2d = df / d comp * 0.5 / comp * [ d comp^2 / d cov2d ] + // d comp^2 / d 
cov2d = (1 - comp^2) * conic - 0.3 I * det(conic) + float inv_det = conic.x * conic.z - conic.y * conic.y; + float one_minus_sqr_comp = 1 - compensation * compensation; + float v_sqr_comp = v_compensation * 0.5 / (compensation + 1e-6); + v_cov2d.x += v_sqr_comp * (one_minus_sqr_comp * conic.x - 0.3 * inv_det); + v_cov2d.y += 2 * v_sqr_comp * (one_minus_sqr_comp * conic.y); + v_cov2d.z += v_sqr_comp * (one_minus_sqr_comp * conic.z - 0.3 * inv_det); +} + +// helper for applying R^T * p for a ROW MAJOR 4x3 matrix [R, t], ignoring t +inline __device__ float3 transform_4x3_rot_only_transposed(const float *mat, const float3 p) { + float3 out = { + mat[0] * p.x + mat[4] * p.y + mat[8] * p.z, + mat[1] * p.x + mat[5] * p.y + mat[9] * p.z, + mat[2] * p.x + mat[6] * p.y + mat[10] * p.z, + }; + return out; +} + +// helper for applying R * p + T, expect mat to be ROW MAJOR +inline __device__ float3 transform_4x3(const float *mat, const float3 p) { + float3 out = { + mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3], + mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7], + mat[8] * p.x + mat[9] * p.y + mat[10] * p.z + mat[11], + }; + return out; +} + +// helper to apply 4x4 transform to 3d vector, return homo coords +// expects mat to be ROW MAJOR +inline __device__ float4 transform_4x4(const float *mat, const float3 p) { + float4 out = { + mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3], + mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7], + mat[8] * p.x + mat[9] * p.y + mat[10] * p.z + mat[11], + mat[12] * p.x + mat[13] * p.y + mat[14] * p.z + mat[15], + }; + return out; +} + +inline __device__ float2 project_pix( + const float2 fxfy, const float3 p_view, const float2 pp +) { + float rw = 1.f / (p_view.z + 1e-6f); + float2 p_proj = { p_view.x * rw, p_view.y * rw }; + float2 p_pix = { p_proj.x * fxfy.x + pp.x, p_proj.y * fxfy.y + pp.y }; + return p_pix; +} + +// given v_xy_pix, get v_xyz +inline __device__ float3 project_pix_vjp( + const float2 fxfy, const float3 
p_view, const float2 v_xy +) { + float rw = 1.f / (p_view.z + 1e-6f); + float2 v_proj = { fxfy.x * v_xy.x, fxfy.y * v_xy.y }; + float3 v_view = { + v_proj.x * rw, v_proj.y * rw, -(v_proj.x * p_view.x + v_proj.y * p_view.y) * rw * rw + }; + return v_view; +} + +inline __device__ glm::mat3 quat_to_rotmat(const float4 quat) { + // quat to rotation matrix + float w = quat.x; + float x = quat.y; + float y = quat.z; + float z = quat.w; + + // glm matrices are column-major + return glm::mat3( + 1.f - 2.f * (y * y + z * z), + 2.f * (x * y + w * z), + 2.f * (x * z - w * y), + 2.f * (x * y - w * z), + 1.f - 2.f * (x * x + z * z), + 2.f * (y * z + w * x), + 2.f * (x * z + w * y), + 2.f * (y * z - w * x), + 1.f - 2.f * (x * x + y * y) + ); +} + +inline __device__ float4 +quat_to_rotmat_vjp(const float4 quat, const glm::mat3 v_R) { + float w = quat.x; + float x = quat.y; + float y = quat.z; + float z = quat.w; + + float4 v_quat; + // v_R is COLUMN MAJOR + // w element stored in x field + v_quat.x = + 2.f * ( + // v_quat.w = 2.f * ( + x * (v_R[1][2] - v_R[2][1]) + y * (v_R[2][0] - v_R[0][2]) + + z * (v_R[0][1] - v_R[1][0]) + ); + // x element in y field + v_quat.y = + 2.f * + ( + // v_quat.x = 2.f * ( + -2.f * x * (v_R[1][1] + v_R[2][2]) + y * (v_R[0][1] + v_R[1][0]) + + z * (v_R[0][2] + v_R[2][0]) + w * (v_R[1][2] - v_R[2][1]) + ); + // y element in z field + v_quat.z = + 2.f * + ( + // v_quat.y = 2.f * ( + x * (v_R[0][1] + v_R[1][0]) - 2.f * y * (v_R[0][0] + v_R[2][2]) + + z * (v_R[1][2] + v_R[2][1]) + w * (v_R[2][0] - v_R[0][2]) + ); + // z element in w field + v_quat.w = + 2.f * + ( + // v_quat.z = 2.f * ( + x * (v_R[0][2] + v_R[2][0]) + y * (v_R[1][2] + v_R[2][1]) - + 2.f * z * (v_R[0][0] + v_R[1][1]) + w * (v_R[0][1] - v_R[1][0]) + ); + return v_quat; +} + +inline __device__ glm::mat3 +scale_to_mat(const float3 scale, const float glob_scale) { + glm::mat3 S = glm::mat3(1.f); + S[0][0] = glob_scale * scale.x; + S[1][1] = glob_scale * scale.y; + S[2][2] = glob_scale * 
scale.z; + return S; +} + +// device helper for culling near points +inline __device__ bool clip_near_plane( + const float3 p, const float *viewmat, float3 &p_view, float thresh +) { + p_view = transform_4x3(viewmat, p); + if (p_view.z <= thresh) { + return true; + } + return false; +} diff --git a/cpp/gsplat/sh.cuh b/cpp/gsplat/sh.cuh new file mode 100644 index 0000000..0036e0c --- /dev/null +++ b/cpp/gsplat/sh.cuh @@ -0,0 +1,203 @@ +#include +#include +#define CHANNELS 3 +namespace cg = cooperative_groups; + +__device__ __constant__ float SH_C0 = 0.28209479177387814f; +__device__ __constant__ float SH_C1 = 0.4886025119029199f; +__device__ __constant__ float SH_C2[] = {1.0925484305920792f, -1.0925484305920792f, 0.31539156525252005f, + -1.0925484305920792f, 0.5462742152960396f}; +__device__ __constant__ float SH_C3[] = {-0.5900435899266435f, 2.890611442640554f, -0.4570457994644658f, + 0.3731763325901154f, -0.4570457994644658f, 1.445305721320277f, + -0.5900435899266435f}; +__device__ __constant__ float SH_C4[] = {2.5033429417967046f, -1.7701307697799304, 0.9461746957575601f, + -0.6690465435572892f, 0.10578554691520431f, -0.6690465435572892f, + 0.47308734787878004f, -1.7701307697799304f, 0.6258357354491761f}; + +// This function is used in both host and device code +__host__ __device__ unsigned num_sh_bases(const unsigned degree) +{ + if(degree == 0) + return 1; + if(degree == 1) + return 4; + if(degree == 2) + return 9; + if(degree == 3) + return 16; + return 25; +} + +__device__ void sh_coeffs_to_color(const unsigned degree, const float3& viewdir, const float* coeffs, float* colors) +{ + // Expects v_colors to be len CHANNELS + // and v_coeffs to be num_bases * CHANNELS + for(int c = 0; c < CHANNELS; ++c) + { + colors[c] = SH_C0 * coeffs[c]; + } + if(degree < 1) + { + return; + } + + float norm = sqrt(viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z); + float x = viewdir.x / norm; + float y = viewdir.y / norm; + float z = viewdir.z / norm; + + 
float xx = x * x; + float xy = x * y; + float xz = x * z; + float yy = y * y; + float yz = y * z; + float zz = z * z; + // expects CHANNELS * num_bases coefficients + // supports up to num_bases = 25 + for(int c = 0; c < CHANNELS; ++c) + { + colors[c] += + SH_C1 * (-y * coeffs[1 * CHANNELS + c] + z * coeffs[2 * CHANNELS + c] - x * coeffs[3 * CHANNELS + c]); + if(degree < 2) + { + continue; + } + colors[c] += (SH_C2[0] * xy * coeffs[4 * CHANNELS + c] + SH_C2[1] * yz * coeffs[5 * CHANNELS + c] + + SH_C2[2] * (2.f * zz - xx - yy) * coeffs[6 * CHANNELS + c] + + SH_C2[3] * xz * coeffs[7 * CHANNELS + c] + SH_C2[4] * (xx - yy) * coeffs[8 * CHANNELS + c]); + if(degree < 3) + { + continue; + } + colors[c] += + (SH_C3[0] * y * (3.f * xx - yy) * coeffs[9 * CHANNELS + c] + SH_C3[1] * xy * z * coeffs[10 * CHANNELS + c] + + SH_C3[2] * y * (4.f * zz - xx - yy) * coeffs[11 * CHANNELS + c] + + SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy) * coeffs[12 * CHANNELS + c] + + SH_C3[4] * x * (4.f * zz - xx - yy) * coeffs[13 * CHANNELS + c] + + SH_C3[5] * z * (xx - yy) * coeffs[14 * CHANNELS + c] + + SH_C3[6] * x * (xx - 3.f * yy) * coeffs[15 * CHANNELS + c]); + if(degree < 4) + { + continue; + } + colors[c] += (SH_C4[0] * xy * (xx - yy) * coeffs[16 * CHANNELS + c] + + SH_C4[1] * yz * (3.f * xx - yy) * coeffs[17 * CHANNELS + c] + + SH_C4[2] * xy * (7.f * zz - 1.f) * coeffs[18 * CHANNELS + c] + + SH_C4[3] * yz * (7.f * zz - 3.f) * coeffs[19 * CHANNELS + c] + + SH_C4[4] * (zz * (35.f * zz - 30.f) + 3.f) * coeffs[20 * CHANNELS + c] + + SH_C4[5] * xz * (7.f * zz - 3.f) * coeffs[21 * CHANNELS + c] + + SH_C4[6] * (xx - yy) * (7.f * zz - 1.f) * coeffs[22 * CHANNELS + c] + + SH_C4[7] * xz * (xx - 3.f * yy) * coeffs[23 * CHANNELS + c] + + SH_C4[8] * (xx * (xx - 3.f * yy) - yy * (3.f * xx - yy)) * coeffs[24 * CHANNELS + c]); + } +} + +__device__ void sh_coeffs_to_color_vjp(const unsigned degree, const float3& viewdir, const float* v_colors, + float* v_coeffs) +{ +// Expects v_colors to be len 
CHANNELS +// and v_coeffs to be num_bases * CHANNELS +#pragma unroll + for(int c = 0; c < CHANNELS; ++c) + { + v_coeffs[c] = SH_C0 * v_colors[c]; + } + if(degree < 1) + { + return; + } + + float norm = sqrt(viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z); + float x = viewdir.x / norm; + float y = viewdir.y / norm; + float z = viewdir.z / norm; + + float xx = x * x; + float xy = x * y; + float xz = x * z; + float yy = y * y; + float yz = y * z; + float zz = z * z; + +#pragma unroll + for(int c = 0; c < CHANNELS; ++c) + { + float v1 = -SH_C1 * y; + float v2 = SH_C1 * z; + float v3 = -SH_C1 * x; + v_coeffs[1 * CHANNELS + c] = v1 * v_colors[c]; + v_coeffs[2 * CHANNELS + c] = v2 * v_colors[c]; + v_coeffs[3 * CHANNELS + c] = v3 * v_colors[c]; + if(degree < 2) + { + continue; + } + float v4 = SH_C2[0] * xy; + float v5 = SH_C2[1] * yz; + float v6 = SH_C2[2] * (2.f * zz - xx - yy); + float v7 = SH_C2[3] * xz; + float v8 = SH_C2[4] * (xx - yy); + v_coeffs[4 * CHANNELS + c] = v4 * v_colors[c]; + v_coeffs[5 * CHANNELS + c] = v5 * v_colors[c]; + v_coeffs[6 * CHANNELS + c] = v6 * v_colors[c]; + v_coeffs[7 * CHANNELS + c] = v7 * v_colors[c]; + v_coeffs[8 * CHANNELS + c] = v8 * v_colors[c]; + if(degree < 3) + { + continue; + } + float v9 = SH_C3[0] * y * (3.f * xx - yy); + float v10 = SH_C3[1] * xy * z; + float v11 = SH_C3[2] * y * (4.f * zz - xx - yy); + float v12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy); + float v13 = SH_C3[4] * x * (4.f * zz - xx - yy); + float v14 = SH_C3[5] * z * (xx - yy); + float v15 = SH_C3[6] * x * (xx - 3.f * yy); + v_coeffs[9 * CHANNELS + c] = v9 * v_colors[c]; + v_coeffs[10 * CHANNELS + c] = v10 * v_colors[c]; + v_coeffs[11 * CHANNELS + c] = v11 * v_colors[c]; + v_coeffs[12 * CHANNELS + c] = v12 * v_colors[c]; + v_coeffs[13 * CHANNELS + c] = v13 * v_colors[c]; + v_coeffs[14 * CHANNELS + c] = v14 * v_colors[c]; + v_coeffs[15 * CHANNELS + c] = v15 * v_colors[c]; + if(degree < 4) + { + continue; + } + float v16 = SH_C4[0] * 
xy * (xx - yy); + float v17 = SH_C4[1] * yz * (3.f * xx - yy); + float v18 = SH_C4[2] * xy * (7.f * zz - 1.f); + float v19 = SH_C4[3] * yz * (7.f * zz - 3.f); + float v20 = SH_C4[4] * (zz * (35.f * zz - 30.f) + 3.f); + float v21 = SH_C4[5] * xz * (7.f * zz - 3.f); + float v22 = SH_C4[6] * (xx - yy) * (7.f * zz - 1.f); + float v23 = SH_C4[7] * xz * (xx - 3.f * yy); + float v24 = SH_C4[8] * (xx * (xx - 3.f * yy) - yy * (3.f * xx - yy)); + v_coeffs[16 * CHANNELS + c] = v16 * v_colors[c]; + v_coeffs[17 * CHANNELS + c] = v17 * v_colors[c]; + v_coeffs[18 * CHANNELS + c] = v18 * v_colors[c]; + v_coeffs[19 * CHANNELS + c] = v19 * v_colors[c]; + v_coeffs[20 * CHANNELS + c] = v20 * v_colors[c]; + v_coeffs[21 * CHANNELS + c] = v21 * v_colors[c]; + v_coeffs[22 * CHANNELS + c] = v22 * v_colors[c]; + v_coeffs[23 * CHANNELS + c] = v23 * v_colors[c]; + v_coeffs[24 * CHANNELS + c] = v24 * v_colors[c]; + } +} + +__global__ void compute_sh_backward_kernel(const unsigned num_points, const unsigned degree, + const unsigned degrees_to_use, const float3* __restrict__ viewdirs, + const float* __restrict__ v_colors, float* __restrict__ v_coeffs) +{ + unsigned idx = cg::this_grid().thread_rank(); + if(idx >= num_points) + { + return; + } + const unsigned num_channels = 3; + unsigned num_bases = num_sh_bases(degree); + unsigned idx_sh = num_bases * num_channels * idx; + unsigned idx_col = num_channels * idx; + + sh_coeffs_to_color_vjp(degrees_to_use, viewdirs[idx], &(v_colors[idx_col]), &(v_coeffs[idx_sh])); +} diff --git a/cpp/gsplat/third_party/glm b/cpp/gsplat/third_party/glm new file mode 160000 index 0000000..33b4a62 --- /dev/null +++ b/cpp/gsplat/third_party/glm @@ -0,0 +1 @@ +Subproject commit 33b4a621a697a305bc3a7610d290677b96beb181 diff --git a/cpp/rasterizer_cuda.cu b/cpp/rasterizer_cuda.cu new file mode 100644 index 0000000..9342ccf --- /dev/null +++ b/cpp/rasterizer_cuda.cu @@ -0,0 +1,158 @@ +#include "rasterizer_kernels.cuh" +#include + +#include +#include + +#include +#include 
// NOTE(review): this span comes from a flattened patch in which everything between angle brackets was
// stripped (include targets, template arguments, kernel launch configurations). Those tokens are
// reconstructed below from the surrounding code; places where the reconstruction is a guess are marked.

// NOTE(review): the three header names were lost; they are inferred from the thrust calls used in this
// file (thrust::maximum, thrust::inclusive_scan / exclusive_scan, thrust::sort_by_key) - confirm.
#include <thrust/functional.h>
#include <thrust/scan.h>
#include <thrust/sort.h>

// Sorts `keys` in ascending order on the device and applies the same permutation to `indices`.
//
// indices          - device pointer to the values that travel with the keys
// keys             - device pointer to the 64-bit sort keys
// num_of_gaussians - number of key/index pairs to sort, i.e. the length of both arrays. Despite its name
//                    this must be the number of duplicated keys, not the number of Gaussians.
void sort_by_keys(int32_t* indices, int64_t* keys, const int num_of_gaussians)
{
    if(num_of_gaussians <= 0)
    {
        return; // nothing to sort
    }

    auto device_indices = thrust::device_pointer_cast(indices);
    auto device_keys = thrust::device_pointer_cast(keys);

    thrust::sort_by_key(device_keys, device_keys + num_of_gaussians, device_indices);
}

// Builds the per-tile ranges into the sorted key array.
//
// The result has (tiles_x * tiles_y + 1) entries; tile t owns the sorted entries [result[t], result[t + 1]).
//
// keys             - device pointer to the keys, already sorted; the tile id is read from the upper 32 bits
// image_width/height - image size in pixels, rounded up to whole tiles of TILE_SIZE
// num_of_gaussians - number of entries in `keys` (number of duplicated keys)
thrust::device_vector<int32_t> identify_tile_ranges(const int64_t* keys, const int image_width,
                                                    const int image_height, const int num_of_gaussians)
{
    const int threads = 768;
    const int blocks = (num_of_gaussians + threads - 1) / threads;
    const int tiles_width = (image_width + TILE_SIZE - 1) / TILE_SIZE;
    const int tiles_height = (image_height + TILE_SIZE - 1) / TILE_SIZE;
    const int size = tiles_width * tiles_height + 1;

    thrust::device_vector<int32_t> indices(size, 0);
    indices[size - 1] = num_of_gaussians;

    if(num_of_gaussians <= 0)
    {
        // No keys: every tile keeps the empty range [0, 0). Launching a kernel with zero blocks
        // would be an invalid launch configuration, so stop here.
        return indices;
    }

    auto indices_device_ptr = thrust::raw_pointer_cast(indices.data());

    // The kernel writes result[tile + 1] only where two neighbouring keys belong to different tiles.
    identify_tile_ranges_kernel<<<blocks, threads>>>(keys, indices_device_ptr, num_of_gaussians);

    // The last key has no successor, so the kernel never closes the range of the last occupied tile.
    // Without this, the max-scan below would leave that tile empty and hand its entries to the final
    // tile of the image instead.
    int64_t last_key = 0;
    cudaMemcpy(&last_key, keys + (num_of_gaussians - 1), sizeof(int64_t), cudaMemcpyDeviceToHost);
    const int64_t last_tile = last_key >> 32;
    if(last_tile >= 0 && last_tile + 1 < size)
    {
        indices[last_tile + 1] = num_of_gaussians;
    }

    // Entries of tiles without any key are still 0; a running maximum propagates the previous
    // boundary forward so that such tiles get an empty range.
    thrust::inclusive_scan(indices.begin(), indices.end(), indices.begin(), thrust::maximum<int32_t>{});

    return indices;
}

// Launches render_image_kernel with one TILE_SIZE x TILE_SIZE thread block per image tile.
//
// gaussian_tensor - device pointer to the preprocessed Gaussians
// indices         - device pointer to the Gaussian indices, sorted by key
// tile_indices    - device pointer to the per-tile ranges produced by identify_tile_ranges
// output_image    - device pointer the kernel writes the image into
void rasterizer_cuda_forward(Gaussian* gaussian_tensor, const int32_t* indices, const int32_t* tile_indices,
                             int image_width, int image_height, float* output_image)
{
    if(image_width <= 0 || image_height <= 0)
    {
        return; // an empty grid is not a valid launch configuration
    }

    dim3 block_dim(TILE_SIZE, TILE_SIZE);
    const dim3 grid_dim((image_width + block_dim.x - 1) / block_dim.x, (image_height + block_dim.y - 1) / block_dim.y);
    render_image_kernel<<<grid_dim, block_dim>>>(gaussian_tensor, indices, tile_indices, output_image, image_width,
                                                 image_height);
}

// Allocates and fills one (key, Gaussian index) pair per tile overlapped by each Gaussian.
//
// block_sums              - device pointer to the per-block sums written by the preprocessing kernel;
//                           overwritten in place with their exclusive prefix sum
// gaussian_tensor         - device pointer to the preprocessed Gaussians (carry the block-local sums)
// radii                   - device pointer to the per-Gaussian radii
// numBlocks_in_preprocess / threads_in_preprocess - launch configuration of the preprocessing kernel
//
// Returns the tuple (keys, indices); both vectors have one entry per (Gaussian, tile) pair.
std::tuple<thrust::device_vector<int64_t>, thrust::device_vector<int32_t>>
duplicate_keys(int32_t* block_sums, Gaussian* gaussian_tensor, const int32_t* radii, int numBlocks_in_preprocess,
               int threads_in_preprocess, int num_of_gaussians, int image_width, int image_height)
{
    // reuse numBlocks and threads count for calculate_keys_and_indices_kernel, so we are sure
    // restoring global exclusive sum will work as intended

    auto block_sums_device_ptr = thrust::device_pointer_cast(block_sums);
    // calculate exclusive sum per each block
    thrust::exclusive_scan(block_sums_device_ptr, block_sums_device_ptr + numBlocks_in_preprocess,
                           block_sums_device_ptr);

    // After the scan the last entry is the global offset of the last block.
    int32_t last_element = 0;
    if(numBlocks_in_preprocess > 0)
    {
        cudaMemcpy(&last_element, block_sums + numBlocks_in_preprocess - 1, sizeof(int32_t), cudaMemcpyDeviceToHost);
    }

    // The total number of pairs is the *inclusive* sum of the last Gaussian plus the offset of its block.
    // Reading the exclusive sum (size_occupied_tiles) here would leave no room for the last Gaussian:
    // the key kernel gives it the range [start, total), which would then be empty.
    // NOTE(review): assumes size_occupied_tiles_next is a 32-bit integer field of Gaussian - confirm.
    int32_t total_size = 0;
    if(num_of_gaussians > 0)
    {
        const size_t offset_to_last_inclusive_sum =
            static_cast<size_t>(num_of_gaussians - 1) * sizeof(Gaussian) + offsetof(Gaussian, size_occupied_tiles_next);
        cudaMemcpy(&total_size, reinterpret_cast<const char*>(gaussian_tensor) + offset_to_last_inclusive_sum,
                   sizeof(int32_t), cudaMemcpyDeviceToHost);
    }
    total_size += last_element;

    thrust::device_vector<int64_t> keys(total_size);
    thrust::device_vector<int32_t> indices(total_size);

    if(num_of_gaussians > 0 && numBlocks_in_preprocess > 0)
    {
        calculate_keys_and_indices_kernel<<<numBlocks_in_preprocess, threads_in_preprocess>>>(
            gaussian_tensor, block_sums, radii, thrust::raw_pointer_cast(keys.data()),
            thrust::raw_pointer_cast(indices.data()), num_of_gaussians, total_size, image_width, image_height);
    }

    // NOTE(review): make_tuple copies both device vectors; moving them would avoid the extra copy.
    return std::make_tuple(keys, indices);
}

// Runs the preprocessing kernel over all Gaussians.
//
// All pointer arguments are forwarded unchanged to rasterizer_preprocessing_kernel; `shs`,
// `colors_precomp` and `cov3D_precomp` may be nullptr.
//
// Returns the tuple (radii, gaussians, block_sums, number of blocks, threads per block). The last two
// values are the launch configuration, which duplicate_keys has to reuse.
std::tuple<thrust::device_vector<int32_t>, thrust::device_vector<Gaussian>, thrust::device_vector<int32_t>, int, int>
rasterizer_forward_preprocessing(const float* means_3d, const float* shs, const float* colors_precomp,
                                 const float* opacities, const float* scales, const float* proj_matrix,
                                 const float* view_matrix, const float* camera_position, const float* rotations,
                                 const float* cov3D_precomp, int image_width, int image_height, float tan_fovx,
                                 float tan_fovy, const float scale_modifier, const int max_sh_degree,
                                 const int sh_degree, int num_of_gaussians)
{
    const int max_coeff = (max_sh_degree + 1) * (max_sh_degree + 1);

    if(shs != nullptr)
    {
        assert(sh_degree >= 0 && sh_degree <= 4);
        // todo: pass the size of shs to this method and assert it holds (sh_degree + 1)^2 coefficients
        assert(max_sh_degree == 3); // @TODO other max degrees not supported for now (they need separate cuda kernels)
    }

    const int threads = 768;
    const int numBlocks = (num_of_gaussians + threads - 1) / threads;

    thrust::device_vector<int32_t> radii(num_of_gaussians);
    thrust::device_vector<Gaussian> gaussian_tensor(num_of_gaussians);
    Gaussian* gaussian_data_ptr = thrust::raw_pointer_cast(gaussian_tensor.data());

    const auto focal_y = image_height / (2 * tan_fovy);
    const auto focal_x = image_width / (2 * tan_fovx);
    const auto c_x = image_width / 2.0f;
    const auto c_y = image_height / 2.0f;

    // here we will keep all sizes sum per each block
    thrust::device_vector<int32_t> block_sums(numBlocks);

    if(numBlocks > 0)
    {
        // NOTE(review): the launch configuration was lost in extraction. The kernel indexes an
        // `extern __shared__ int` array by threadIdx.x, hence one int of dynamic shared memory per
        // thread - confirm against the original.
        rasterizer_preprocessing_kernel<<<numBlocks, threads, threads * sizeof(int)>>>(
            means_3d, shs, scales, rotations, proj_matrix, view_matrix, camera_position, opacities, cov3D_precomp,
            colors_precomp, scale_modifier, max_coeff, sh_degree, num_of_gaussians, focal_x, focal_y, tan_fovx,
            tan_fovy, c_x, c_y, image_width, image_height, thrust::raw_pointer_cast(radii.data()),
            thrust::raw_pointer_cast(block_sums.data()), gaussian_data_ptr);
    }

    // NOTE(review): make_tuple copies the three device vectors; moving them would avoid the extra copies.
    return std::make_tuple(radii, gaussian_tensor, block_sums, numBlocks, threads);
}

// Full forward pass: preprocess the Gaussians, create one key per (Gaussian, tile) pair, sort the keys,
// derive the per-tile ranges and render into `output_image`.
//
// Colors are always evaluated from `shs` and covariances from `scales` / `rotations`; the precomputed
// alternatives are passed as nullptr.
void rasterizer_forward_core_deepsense(const float* means_3d, const float* shs, const float* opacities,
                                       const float* scales, const float* rotations, int num_of_gaussians,
                                       const float* view_matrix, const float* proj_matrix, const float* camera_position,
                                       int image_width, int image_height, float tan_fovx, float tan_fovy,
                                       float scale_modifier, const int max_sh_degree, const int sh_degree,
                                       float* output_image)
{

    const float* colors_precomp = nullptr;
    const float* cov3D_precomp = nullptr;

    auto [radii, gaussian_tensor, block_sums, num_blocks_preprocessing, num_threads_preprocessing] =
        rasterizer_forward_preprocessing(means_3d, shs, colors_precomp, opacities, scales, proj_matrix, view_matrix,
                                         camera_position, rotations, cov3D_precomp, image_width, image_height, tan_fovx,
                                         tan_fovy, scale_modifier, max_sh_degree, sh_degree, num_of_gaussians);

    auto [keys, indices] =
        duplicate_keys(thrust::raw_pointer_cast(block_sums.data()), thrust::raw_pointer_cast(gaussian_tensor.data()),
                       thrust::raw_pointer_cast(radii.data()), num_blocks_preprocessing, num_threads_preprocessing,
                       num_of_gaussians, image_width, image_height);

    auto indices_ptr = thrust::raw_pointer_cast(indices.data());
    auto keys_ptr = thrust::raw_pointer_cast(keys.data());

    // A Gaussian contributes one key per tile it overlaps, so the key buffer does not have
    // num_of_gaussians entries. Sorting and range detection must cover the whole buffer.
    const int num_of_keys = static_cast<int>(keys.size());
    sort_by_keys(indices_ptr, keys_ptr, num_of_keys);

    auto tile_indices = identify_tile_ranges(keys_ptr, image_width, image_height, num_of_keys);
    auto tile_indices_raw_ptr = thrust::raw_pointer_cast(tile_indices.data());

    rasterizer_cuda_forward(thrust::raw_pointer_cast(gaussian_tensor.data()), indices_ptr, tile_indices_raw_ptr,
                            image_width, image_height, output_image);
}

// ---------------------------------------------------------------------------------------------------
// Patch boundary, kept verbatim from the diff this span was taken from:
//   diff --git a/cpp/rasterizer_kernels.cu b/cpp/rasterizer_kernels.cu
//   new file mode 100644
//   index 0000000..9efcc3d
//   --- /dev/null
//   +++ b/cpp/rasterizer_kernels.cu
//   @@ -0,0 +1,675 @@
// ---------------------------------------------------------------------------------------------------
#include "rasterizer_kernels.cuh"
// NOTE(review): both header names below were lost in extraction. <cooperative_groups.h> follows from the
// namespace alias; the second one is a guess - confirm.
#include <cooperative_groups.h>
#include <cuda_runtime.h>

namespace cg = cooperative_groups;

// Returns 1 when the point lies inside the view frustum and 0 otherwise.
//
// projection_matrix is read as 16 floats: elements 0-3, 4-7 and 8-11 produce x, y and z, elements 12-15
// produce w. After dividing by w the point passes when x and y are strictly inside (-1, 1) and z is
// greater than 0.1. There is no guard against w == 0.
__device__ int32_t is_gaussian_in_frustum(const float3& means_3d, const float* projection_matrix)
{
    const auto inv_w = 1.f / (projection_matrix[12] * means_3d.x + projection_matrix[13] * means_3d.y +
                              projection_matrix[14] * means_3d.z + projection_matrix[15]);

    const auto x = inv_w * (projection_matrix[0] * means_3d.x + projection_matrix[1] * means_3d.y +
                            projection_matrix[2] * means_3d.z + projection_matrix[3]);

    const auto y = inv_w * (projection_matrix[4] * means_3d.x + projection_matrix[5] * means_3d.y +
                            projection_matrix[6] * means_3d.z + projection_matrix[7]);

    const auto z = inv_w * (projection_matrix[8] * means_3d.x + projection_matrix[9] * means_3d.y +
                            projection_matrix[10] * means_3d.z + projection_matrix[11]);

    const auto limits = 1.f;

    return x > -limits && x < limits && y > -limits && y < limits && z > 0.1f;
}

+__device__ float3 switch_shs_func(const float* features, const float3& means3d, int deg, const float3& camera_center, + int index_features, int active_sh_degree) +{ + if(active_sh_degree == 0) + return calc_shs_deg_0(features, means3d, index_features); + + if(active_sh_degree == 1) + return calc_shs_deg_1(features, means3d, camera_center, index_features); + + if(active_sh_degree == 2) + return calc_shs_deg_2(features, means3d, camera_center, index_features); + + if(active_sh_degree == 3) + return calc_shs_deg_3(features, means3d, camera_center, index_features); + + return {0.0f, 0.0f, 0.0f}; +} + +__device__ float6 make_float6(float x, float y, float z, float w, float v, float u) +{ + return {x, y, z, w, v, u}; +} + +__device__ float3 calculate_conic(const float3& cov_2d) +{ + const auto det = cov_2d.x * cov_2d.z - cov_2d.y * cov_2d.y; + const auto det_inv = 1.0f / (det + 1e-6f); + + return {cov_2d.z * det_inv, -cov_2d.y * det_inv, cov_2d.x * det_inv}; +} + +__global__ void rasterizer_preprocessing_kernel( + const float* __restrict__ means3d, const float* features, const float* __restrict__ scales, + const float* __restrict__ rotations, const float* __restrict__ projection_matrix, + const float* __restrict__ view_matrix, const float* __restrict__ camera_center, const float* __restrict__ opacities, + const float* __restrict__ cov3D_precomp, const float* __restrict__ colors_precomp, float scale_modifier, + int max_coeff, int active_sh_degree, int num_of_gaussians, float focal_x, float focal_y, float tan_fovx, + float tan_fovy, float c_x, float c_y, unsigned int img_width, int img_height, int32_t* __restrict__ radii, + int* __restrict__ block_sums, Gaussian* __restrict__ gaussians, float* __restrict__ cov_3d_output) +{ + __shared__ float _projection_matrix[16]; + __shared__ float _view_matrix[12]; + __shared__ float _camera_center[3]; + extern __shared__ int shared_inclusive_sum[]; + + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int bdim = 
blockDim.x; + const int index = blockIdx.x * blockDim.x + threadIdx.x; + const int index_4 = index * 4; + const int index_3 = index_4 - index; // index*3 + const int index_6 = index * 6; + int index_features = max_coeff * index_3; + + if(threadIdx.x < 16) + _projection_matrix[threadIdx.x] = projection_matrix[threadIdx.x]; + if(threadIdx.x < 12) + _view_matrix[threadIdx.x] = view_matrix[threadIdx.x]; + if(threadIdx.x < 3) + _camera_center[threadIdx.x] = camera_center[threadIdx.x]; + + __syncthreads(); + + if(index < num_of_gaussians) + { + const float6 _cov3D = + cov3D_precomp == nullptr + ? calc_cov3d(&scales[index_3], &rotations[index_4], scale_modifier) + : make_float6(cov3D_precomp[index_6], cov3D_precomp[index_6 + 1], cov3D_precomp[index_6 + 2], + cov3D_precomp[index_6 + 3], cov3D_precomp[index_6 + 4], cov3D_precomp[index_6 + 5]); + + if(cov_3d_output && cov3D_precomp == nullptr) + { + cov_3d_output[index_6] = _cov3D.x; + cov_3d_output[index_6 + 1] = _cov3D.y; + cov_3d_output[index_6 + 2] = _cov3D.z; + cov_3d_output[index_6 + 3] = _cov3D.w; + cov_3d_output[index_6 + 4] = _cov3D.v; + cov_3d_output[index_6 + 5] = _cov3D.u; + } + + const float3 _means_3d = float3{means3d[index_3], means3d[index_3 + 1], means3d[index_3 + 2]}; + const float3 _campos = float3{_camera_center[0], _camera_center[1], _camera_center[2]}; + + const float3 colors = + (colors_precomp == nullptr) + ? 
switch_shs_func(features, _means_3d, active_sh_degree, _campos, index_features, active_sh_degree) + : make_float3(colors_precomp[index_3], colors_precomp[index_3 + 1], colors_precomp[index_3 + 2]); + + int32_t mask_radii = is_gaussian_in_frustum(_means_3d, _projection_matrix); + float3 _means_2d = project_point(_means_3d, _view_matrix); + float3 _cov_2d = + calculate_cov_2d(focal_x, focal_y, tan_fovx, tan_fovy, _means_2d, _view_matrix, _cov3D, num_of_gaussians); + const float2 _means_2d_update = project_to_image(_means_2d, focal_x, focal_y, c_x, c_y, num_of_gaussians); + + _means_2d.x = _means_2d_update.x; + _means_2d.y = _means_2d_update.y; + + mask_radii = calculate_radii(_cov_2d, mask_radii); + + const int _size = calculate_sizes(_means_2d, mask_radii, img_width, img_height); + radii[index] = mask_radii; // global mem write needed for keys preparation next kernel + + const float _opacity = opacities[index]; + + shared_inclusive_sum[tid] = (tid < num_of_gaussians) ? _size : 0; // todo simplify + __syncthreads(); + + // Intra-block inclusive scan using shared memory + for(int d = 1; d < bdim; d *= 2) + { + int t = (tid >= d) ? shared_inclusive_sum[tid - d] : 0; + __syncthreads(); + if(tid >= d) + shared_inclusive_sum[tid] += t; + __syncthreads(); + } + + const float3 conic = calculate_conic(_cov_2d); + + gaussians[index].means_2d = _means_2d; + gaussians[index].conic = conic; + gaussians[index].color = colors; + gaussians[index].alpha = _opacity; + gaussians[index].size_occupied_tiles = (tid > 0) ? 
shared_inclusive_sum[tid - 1] : 0; // local exclusive sum + gaussians[index].size_occupied_tiles_next = + shared_inclusive_sum[tid]; // local inclusive sum (or next thread local exlcusive sum) + } + + if(tid == bdim - 1) + { + // write entire block inclusive sum; we will use it to calculate exclusive sum between blocks + // and later restore global exclusive sum + block_sums[bid] = shared_inclusive_sum[tid]; + } +} + +__device__ int32_t calculate_radii(const float3& cov2d, int32_t _radii) +{ + const float det = cov2d.x * cov2d.z - cov2d.y * cov2d.y; + const float mid = 0.5f * (cov2d.x + cov2d.z); + const float diff = mid * mid - det; + const float lambda1 = mid + sqrt(max(0.1f, diff)); + const float lambda2 = mid - sqrt(max(0.1f, diff)); + + return ceilf(3 * sqrtf(max(lambda1, lambda2))) * (det > 1e-6f && _radii > 0); +} + +__device__ float3 calculate_cov_2d(float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float3& means_2d, + const float* __restrict__ view_matrix, const float6& cov3D, int num_points) +{ + const float limx = 1.3f * tan_fovx; + const float limy = 1.3f * tan_fovy; + const float txtz = means_2d.x / means_2d.z; + const float tytz = means_2d.y / means_2d.z; + const float t[3] = {min(limx, max(-limx, txtz)) * means_2d.z, min(limy, max(-limy, tytz)) * means_2d.z, means_2d.z}; + + const float J[3][3] = {{focal_x / t[2], 0, -(focal_x * t[0]) / (t[2] * t[2])}, + {0, focal_y / t[2], -(focal_y * t[1]) / (t[2] * t[2])}, + {0, 0, 0}}; + + const float W[3][3] = {{view_matrix[0], view_matrix[4], view_matrix[8]}, + {view_matrix[1], view_matrix[5], view_matrix[9]}, + {view_matrix[2], view_matrix[6], view_matrix[10]}}; + + float T[3][3] = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; + for(int i = 0; i < 3; ++i) + { + for(int j = 0; j < 3; ++j) + { + for(int k = 0; k < 3; ++k) + { + T[i][j] += W[i][k] * J[j][k]; + } + } + } + + const float Vrk[3][3] = {{cov3D.x, cov3D.y, cov3D.z}, {cov3D.y, cov3D.w, cov3D.v}, {cov3D.z, cov3D.v, cov3D.u}}; + + float cov[2][2] = 
{{0, 0}, {0, 0}}; + float temp[3][3] = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; + + for(int i = 0; i < 3; ++i) + { + for(int j = 0; j < 3; ++j) + { + for(int k = 0; k < 3; ++k) + { + temp[i][j] += T[k][i] * Vrk[k][j]; + } + } + } + + for(int i = 0; i < 2; ++i) + { + for(int j = 0; j < 2; ++j) + { + for(int k = 0; k < 3; ++k) + { + cov[i][j] += temp[i][k] * T[k][j]; + } + } + } + + return {cov[0][0] + 0.3f, cov[0][1], cov[1][1] + 0.3f}; +} + +__device__ int calculate_sizes(const float3& xyd, int radii, unsigned int img_width, int img_height) +{ + const auto x = xyd.x; + const auto y = xyd.y; + const auto radius = radii; + + const float2 center = {x / TILE_SIZE, y / TILE_SIZE}; + const float2 dims = {radius / (float)TILE_SIZE, radius / (float)TILE_SIZE}; + const dim3 img_size = {static_cast((img_width + TILE_SIZE - 1) / TILE_SIZE), + static_cast((img_height + TILE_SIZE - 1) / TILE_SIZE)}; + + const int start_x = min(max(0, (int)(center.x - dims.x)), img_size.x); + const int end_x = min(max(0, (int)(center.x + dims.x + 1)), img_size.x); + const int start_y = min(max(0, (int)(center.y - dims.y)), img_size.y); + const int end_y = min(max(0, (int)(center.y + dims.y + 1)), img_size.y); + + return (end_x - start_x) * (end_y - start_y) * (radius > 0); +} + +__global__ void calculate_keys_and_indices_kernel(const Gaussian* __restrict__ gaussians, + const int32_t* __restrict__ block_sums, + const int32_t* __restrict__ radii, int64_t* __restrict__ keys, + int32_t* __restrict__ indices, int num_ndc, int num_indices, + unsigned int img_width, unsigned int img_height) +{ + const int index = blockIdx.x * blockDim.x + threadIdx.x; + const int bid = blockIdx.x; + + // store exclusive sums for this thread and next thread. 
+ // Next thread block sums might (and probably is) different if current thread is last thread in the block + __shared__ int blocks_prefix[1]; + + if(threadIdx.x == 0) + { + blocks_prefix[0] = block_sums[bid]; + } + __syncthreads(); + + if(index < num_ndc) + { + Gaussian gaussian = gaussians[index]; + // restore global exclusive sum for start and end indices + auto start_index = gaussian.size_occupied_tiles + blocks_prefix[0]; + auto end_index = (index + 1) < num_ndc ? gaussian.size_occupied_tiles_next + blocks_prefix[0] : num_indices; + + if(end_index - start_index > 0) + { + const auto x = gaussian.means_2d.x; + const auto y = gaussian.means_2d.y; + const auto radius = radii[index]; + const int64_t depth = (int64_t) * (int32_t*)&(gaussian.means_2d.z); + const int tiles_per_width = (img_width + TILE_SIZE - 1) / TILE_SIZE; + + const float2 center = {x / TILE_SIZE, y / TILE_SIZE}; + const float2 dims = {radius / (float)TILE_SIZE, radius / (float)TILE_SIZE}; + const dim3 img_size = {(img_width + TILE_SIZE - 1) / TILE_SIZE, (img_height + TILE_SIZE - 1) / TILE_SIZE}; + const int start_x = min(max(0, (int)(center.x - dims.x)), img_size.x); + const int end_x = min(max(0, (int)(center.x + dims.x + 1)), img_size.x); + const int start_y = min(max(0, (int)(center.y - dims.y)), img_size.y); + const int end_y = min(max(0, (int)(center.y + dims.y + 1)), img_size.y); + + int local_index = start_index; + for(int tile_y = start_y; tile_y < end_y; ++tile_y) + { + for(int tile_x = start_x; tile_x < end_x; ++tile_x, ++local_index) + { + const int64_t tile_index = tile_y * tiles_per_width + tile_x; + keys[local_index] = (tile_index << 32) | depth; + indices[local_index] = index; + } + } + } + } +} + +__global__ void identify_tile_ranges_kernel(const int64_t* __restrict__ keys, int32_t* __restrict__ tile_indices, + int size) +{ + const int index = blockIdx.x * blockDim.x + threadIdx.x; + + if(index < size - 1) + { + const int64_t current_key = keys[index]; + const int64_t next_key = 
keys[index + 1]; + const int32_t current_tile = current_key >> 32; + const int32_t next_tile = next_key >> 32; + + if(current_tile != next_tile) + tile_indices[current_tile + 1] = index + 1; + } +} + +__device__ float3 project_point(const float3& points, const float* __restrict__ projection) +{ + return {projection[0] * points.x + projection[1] * points.y + projection[2] * points.z + projection[3], + projection[4] * points.x + projection[5] * points.y + projection[6] * points.z + projection[7], + projection[8] * points.x + projection[9] * points.y + projection[10] * points.z + projection[11]}; +} + +__device__ float2 project_to_image(float3& points, float fx, float fy, float cx, float cy, int numPoints) +{ + const auto z_inv = 1.f / (points.z + 1e-6f); + + return {fx * points.x * z_inv + cx, fy * points.y * z_inv + cy}; +} + +static constexpr float alpha_threshold = 1.0f / 255.0f; + +__global__ void render_image_kernel(const Gaussian* __restrict__ gaussians_global, const int32_t* __restrict__ indices, + const int32_t* __restrict__ tile_indices, float* __restrict__ output_image, + int image_width, int image_height, float* __restrict__ final_Ts, + int32_t* __restrict__ final_idx) +{ + const int x = blockIdx.x * blockDim.x + threadIdx.x; + const int y = blockIdx.y * blockDim.y + threadIdx.y; + + __shared__ Gaussian gaussians[num_of_gaussians_in_shared_memory]; + + const auto tile_x = blockIdx.x; + const auto tile_y = blockIdx.y; + const auto tile_index = tile_y * gridDim.x + tile_x; + const auto start_index = tile_indices[tile_index]; + const auto end_index = tile_indices[tile_index + 1]; + const auto num_of_gaussians = end_index - start_index; + + float4 blendedColor = make_float4(0.0f, 0.0f, 0.0f, 1.0f); + int32_t curr_idx = 0; + + const int num_of_syncs = (num_of_gaussians + num_of_threads_in_block - 1) / num_of_threads_in_block; + + bool is_ok = true; + for(int s = 0; s < num_of_syncs; ++s) + { + if(__syncthreads_count(!is_ok) >= num_of_threads_in_block) + { + 
break; + } + + const int thread_idx_in_block = threadIdx.y * blockDim.x + threadIdx.x; + Gaussian& gaussian_to_init = gaussians[thread_idx_in_block]; + const int current_index_for_thread = start_index + s * num_of_gaussians_in_shared_memory + thread_idx_in_block; + + if(current_index_for_thread < end_index) + { + const auto gaussian_index = indices[current_index_for_thread]; + gaussian_to_init = gaussians_global[gaussian_index]; + } + + __syncthreads(); + + const int gaussians_in_current_sync = + min(num_of_gaussians_in_shared_memory, num_of_gaussians - s * num_of_gaussians_in_shared_memory); + + for(int i = 0; i < gaussians_in_current_sync && is_ok; ++i) + { + const Gaussian& gaussian = gaussians[i]; + const float3& conic = gaussian.conic; + const float2 d = {x + 0.5f - gaussian.means_2d.x, y + 0.5f - gaussian.means_2d.y}; + const float power = -0.5f * (conic.x * d.x * d.x + conic.z * d.y * d.y) - conic.y * d.x * d.y; + const float alpha = min(0.999f, gaussian.alpha * expf(power)); + + if(power >= 0.f || alpha < alpha_threshold) + { + continue; + } + + const float T = (1.0f - alpha) * blendedColor.w; + + is_ok = T >= 0.0001f; + + if(is_ok) + { + const float vis = alpha * blendedColor.w; + + blendedColor.x += vis * gaussian.color.x; + blendedColor.y += vis * gaussian.color.y; + blendedColor.z += vis * gaussian.color.z; + blendedColor.w = T; + + curr_idx = start_index + s * num_of_gaussians_in_shared_memory + i; + } + } + } + + if(x < image_width && y < image_height) + { + const auto pixel_idx = 3 * (y * image_width + x); + output_image[pixel_idx] = blendedColor.x; + output_image[pixel_idx + 1] = blendedColor.y; + output_image[pixel_idx + 2] = blendedColor.z; + + if(final_Ts && final_idx) + { + const auto final_index = y * image_width + x; + final_Ts[final_index] = blendedColor.w; + final_idx[final_index] = curr_idx; + } + } +} + +__device__ void add_shs_deg_0(const float* __restrict__ features, float3& result) +{ + const float sh0_R = features[0]; + const float 
sh0_G = features[1]; + const float sh0_B = features[2]; + + const float C0 = 0.28209479177387814f; + result.x += C0 * sh0_R; + result.y += C0 * sh0_G; + result.z += C0 * sh0_B; +} + +__device__ void add_shs_deg_1(const float* __restrict__ features, const float3& viewdir, float3& result) +{ + const float sh1_R = features[3]; + const float sh1_G = features[4]; + const float sh1_B = features[5]; + const float sh2_R = features[6]; + const float sh2_G = features[7]; + const float sh2_B = features[8]; + const float sh3_R = features[9]; + const float sh3_G = features[10]; + const float sh3_B = features[11]; + + const float C1 = 0.4886025119029199f; + result.x += C1 * (-viewdir.y * sh1_R + viewdir.z * sh2_R - viewdir.x * sh3_R); + result.y += C1 * (-viewdir.y * sh1_G + viewdir.z * sh2_G - viewdir.x * sh3_G); + result.z += C1 * (-viewdir.y * sh1_B + viewdir.z * sh2_B - viewdir.x * sh3_B); +} + +__device__ void add_shs_deg_2(const float* __restrict__ features, const float3& viewdir, float3& result) +{ + const float sh4_R = features[12]; + const float sh4_G = features[13]; + const float sh4_B = features[14]; + const float sh5_R = features[15]; + const float sh5_G = features[16]; + const float sh5_B = features[17]; + const float sh6_R = features[18]; + const float sh6_G = features[19]; + const float sh6_B = features[20]; + const float sh7_R = features[21]; + const float sh7_G = features[22]; + const float sh7_B = features[23]; + const float sh8_R = features[24]; + const float sh8_G = features[25]; + const float sh8_B = features[26]; + + const float xx = viewdir.x * viewdir.x; + const float yy = viewdir.y * viewdir.y; + const float zz = viewdir.z * viewdir.z; + const float xy = viewdir.x * viewdir.y; + const float xz = viewdir.x * viewdir.z; + const float yz = viewdir.y * viewdir.z; + + const float C2_0 = 1.0925484305920792f; + const float C2_1 = -1.0925484305920792f; + const float C2_2 = 0.31539156525252005f; + const float C2_3 = -1.0925484305920792f; + const float C2_4 = 
0.5462742152960396f; + + result.x += C2_0 * xy * sh4_R + C2_1 * yz * sh5_R + C2_2 * (2.0f * zz - xx - yy) * sh6_R + C2_3 * xz * sh7_R + + C2_4 * (xx - yy) * sh8_R; + result.y += C2_0 * xy * sh4_G + C2_1 * yz * sh5_G + C2_2 * (2.0f * zz - xx - yy) * sh6_G + C2_3 * xz * sh7_G + + C2_4 * (xx - yy) * sh8_G; + result.z += C2_0 * xy * sh4_B + C2_1 * yz * sh5_B + C2_2 * (2.0f * zz - xx - yy) * sh6_B + C2_3 * xz * sh7_B + + C2_4 * (xx - yy) * sh8_B; +} + +__device__ void add_shs_deg_3(const float* __restrict__ features, const float3& viewdir, float3& result) +{ + const float sh9_R = features[27]; + const float sh9_G = features[28]; + const float sh9_B = features[29]; + const float sh10_R = features[30]; + const float sh10_G = features[31]; + const float sh10_B = features[32]; + const float sh11_R = features[33]; + const float sh11_G = features[34]; + const float sh11_B = features[35]; + const float sh12_R = features[36]; + const float sh12_G = features[37]; + const float sh12_B = features[38]; + const float sh13_R = features[39]; + const float sh13_G = features[40]; + const float sh13_B = features[41]; + const float sh14_R = features[42]; + const float sh14_G = features[43]; + const float sh14_B = features[44]; + const float sh15_R = features[45]; + const float sh15_G = features[46]; + const float sh15_B = features[47]; + + const float xx = viewdir.x * viewdir.x; + const float yy = viewdir.y * viewdir.y; + const float zz = viewdir.z * viewdir.z; + const float xy = viewdir.x * viewdir.y; + + const float C3_0 = -0.5900435899266435f; + const float C3_1 = 2.890611442640554f; + const float C3_2 = -0.4570457994644658f; + const float C3_3 = 0.3731763325901154f; + const float C3_4 = -0.4570457994644658f; + const float C3_5 = 1.445305721320277f; + const float C3_6 = -0.5900435899266435f; + + result.x += C3_0 * viewdir.y * (3 * xx - yy) * sh9_R + C3_1 * xy * viewdir.z * sh10_R + + C3_2 * viewdir.y * (4 * zz - xx - yy) * sh11_R + + C3_3 * viewdir.z * (2 * zz - 3 * xx - 3 * yy) * 
sh12_R + + C3_4 * viewdir.x * (4 * zz - xx - yy) * sh13_R + C3_5 * viewdir.z * (xx - yy) * sh14_R + + C3_6 * viewdir.x * (xx - 3 * yy) * sh15_R; + result.y += C3_0 * viewdir.y * (3 * xx - yy) * sh9_G + C3_1 * xy * viewdir.z * sh10_G + + C3_2 * viewdir.y * (4 * zz - xx - yy) * sh11_G + + C3_3 * viewdir.z * (2 * zz - 3 * xx - 3 * yy) * sh12_G + + C3_4 * viewdir.x * (4 * zz - xx - yy) * sh13_G + C3_5 * viewdir.z * (xx - yy) * sh14_G + + C3_6 * viewdir.x * (xx - 3 * yy) * sh15_G; + result.z += C3_0 * viewdir.y * (3 * xx - yy) * sh9_B + C3_1 * xy * viewdir.z * sh10_B + + C3_2 * viewdir.y * (4 * zz - xx - yy) * sh11_B + + C3_3 * viewdir.z * (2 * zz - 3 * xx - 3 * yy) * sh12_B + + C3_4 * viewdir.x * (4 * zz - xx - yy) * sh13_B + C3_5 * viewdir.z * (xx - yy) * sh14_B + + C3_6 * viewdir.x * (xx - 3 * yy) * sh15_B; +} + +__device__ float3 calc_shs_deg_0(const float* __restrict__ features, const float3& means3d, int index_features) +{ + float3 result = {0.0f, 0.0f, 0.0f}; + + add_shs_deg_0(&features[index_features], result); + + return {fmaxf(result.x + 0.5f, 0.0f), fmaxf(result.y + 0.5f, 0.0f), fmaxf(result.z + 0.5f, 0.0f)}; +} + +__device__ float3 calculate_viewdir(const float3& means_3d, const float3& camera_center) +{ + const float3 viewdir = {means_3d.x - camera_center.x, means_3d.y - camera_center.y, means_3d.z - camera_center.z}; + const float inv_norm = rsqrtf(viewdir.x * viewdir.x + viewdir.y * viewdir.y + viewdir.z * viewdir.z); + + return {viewdir.x * inv_norm, viewdir.y * inv_norm, viewdir.z * inv_norm}; +} + +__device__ float3 calc_shs_deg_1(const float* __restrict__ features, const float3& means_3d, + const float3& camera_center, int index_features) +{ + const auto viewdir = calculate_viewdir(means_3d, camera_center); + + float3 result = {0.0f, 0.0f, 0.0f}; + + add_shs_deg_0(&features[index_features], result); + add_shs_deg_1(&features[index_features], viewdir, result); + + return {fmaxf(result.x + 0.5f, 0.0f), fmaxf(result.y + 0.5f, 0.0f), fmaxf(result.z + 
0.5f, 0.0f)}; +} + +__device__ float3 calc_shs_deg_2(const float* __restrict__ features, const float3& means_3d, + const float3& camera_center, int index_features) +{ + const auto viewdir = calculate_viewdir(means_3d, camera_center); + + float3 result = {0.0f, 0.0f, 0.0f}; + + add_shs_deg_0(&features[index_features], result); + add_shs_deg_1(&features[index_features], viewdir, result); + add_shs_deg_2(&features[index_features], viewdir, result); + + return {fmaxf(result.x + 0.5f, 0.0f), fmaxf(result.y + 0.5f, 0.0f), fmaxf(result.z + 0.5f, 0.0f)}; +} + +__device__ float3 calc_shs_deg_3(const float* __restrict__ features, const float3& means_3d, + const float3& camera_center, int index_features) +{ + const auto viewdir = calculate_viewdir(means_3d, camera_center); + + float3 result = {0.0f, 0.0f, 0.0f}; + + add_shs_deg_0(&features[index_features], result); + add_shs_deg_1(&features[index_features], viewdir, result); + add_shs_deg_2(&features[index_features], viewdir, result); + add_shs_deg_3(&features[index_features], viewdir, result); + + return {fmaxf(result.x + 0.5f, 0.0f), fmaxf(result.y + 0.5f, 0.0f), fmaxf(result.z + 0.5f, 0.0f)}; +} + +__device__ float6 calc_cov3d(const float* __restrict__ scales, const float* rotations, float scale_modifier) +{ + float r = rotations[0]; + float x = rotations[1]; + float y = rotations[2]; + float z = rotations[3]; + + // normalize + const float inv_norm = rsqrtf(r * r + x * x + y * y + z * z); + r = r * inv_norm; + x = x * inv_norm; + y = y * inv_norm; + z = z * inv_norm; + + const float3 scale = {scales[0] * scale_modifier, scales[1] * scale_modifier, scales[2] * scale_modifier}; + + const float r00 = 1.0f - 2.0f * (y * y + z * z); + const float r01 = 2.0f * (x * y - r * z); + const float r02 = 2.0f * (x * z + r * y); + const float r10 = 2.0f * (x * y + r * z); + const float r11 = 1.0f - 2.0f * (x * x + z * z); + const float r12 = 2.0f * (y * z - r * x); + const float r20 = 2.0f * (x * z - r * y); + const float r21 = 2.0f * 
(y * z + r * x); + const float r22 = 1.0f - 2.0f * (x * x + y * y); + + const float m00 = scale.x * r00; + const float m01 = scale.y * r01; + const float m02 = scale.z * r02; + const float m10 = scale.x * r10; + const float m11 = scale.y * r11; + const float m12 = scale.z * r12; + const float m20 = scale.x * r20; + const float m21 = scale.y * r21; + const float m22 = scale.z * r22; + + const float cov00 = m00 * m00 + m01 * m01 + m02 * m02; + const float cov01 = m00 * m10 + m01 * m11 + m02 * m12; + const float cov02 = m00 * m20 + m01 * m21 + m02 * m22; + const float cov11 = m10 * m10 + m11 * m11 + m12 * m12; + const float cov12 = m10 * m20 + m11 * m21 + m12 * m22; + const float cov22 = m20 * m20 + m21 * m21 + m22 * m22; + + return {cov00, cov01, cov02, cov11, cov12, cov22}; +} + +__global__ void tensors_from_gaussians_kernel(const Gaussian* __restrict__ gaussian_data_ptr, float2* __restrict__ xy, + float3* __restrict__ colors, float3* __restrict__ conic, + const int num_of_gaussians) +{ + const auto index = cg::this_grid().thread_rank(); + + if(index < num_of_gaussians) + { + const auto& gaussian = gaussian_data_ptr[index]; + + xy[index] = {gaussian.means_2d.x, gaussian.means_2d.y}; + conic[index] = gaussian.conic; + colors[index] = gaussian.color; + } +} diff --git a/cpp/rasterizer_kernels.cuh b/cpp/rasterizer_kernels.cuh new file mode 100644 index 0000000..db4b041 --- /dev/null +++ b/cpp/rasterizer_kernels.cuh @@ -0,0 +1,75 @@ +#ifndef RASTERIZER_KERNELS_CUH +#define RASTERIZER_KERNELS_CUH + +#include + +constexpr auto TILE_SIZE = 16; +constexpr auto num_of_threads_in_block = TILE_SIZE * TILE_SIZE; +constexpr auto num_of_gaussians_in_shared_memory = num_of_threads_in_block; + +__global__ void check_gaussians_in_frustum(const float* means_3d, const float* scales_3d, int32_t* results, + const float* projection, int num_gaussians); + +struct __align__(16) float6 +{ + float x, y, z, w, v, u; +}; + +struct __align__(16) Gaussian +{ + float3 means_2d; // 0-12 + float3 
conic; // 12-24 + float3 color; // 24-36 + float alpha; // 36-40 + int size_occupied_tiles; // 40-44 + int size_occupied_tiles_next; // 4 padding to 48 +}; + +__device__ float6 calc_cov3d(const float* __restrict__ scales, const float* __restrict__ rotations, + float scale_modifier); +__device__ int32_t calculate_radii(const float3& cov2d, int32_t _radii); +__device__ float3 calc_shs_deg_0(const float* features, const float3& means3d, int index_features); +__device__ float3 calc_shs_deg_1(const float* features, const float3& means3d, const float3& camera_center, + int index_features); +__device__ float3 calc_shs_deg_2(const float* features, const float3& means3d, const float3& camera_center, + int index_features); +__device__ float3 calc_shs_deg_3(const float* features, const float3& means3d, const float3& camera_center, + int index_features); +__device__ float3 project_point(const float3& points, const float* __restrict__ projection); +__device__ int32_t is_gaussian_in_frustum(const float3& means_3d, const float* projection_matrix); +__device__ float3 calculate_cov_2d(float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float3& means_2d, + const float* __restrict__ view_matrix, const float6& cov3D, int num_points); +__device__ float2 project_to_image(float3& points, float fx, float fy, float cx, float cy, int numPoints); +__device__ int calculate_sizes(const float3& xyd, int radii, unsigned int img_width, int img_height); + +__device__ float3 switch_shs_func(const float* features, const float3& means3d, int deg, const float3& camera_center, + int index_features, int active_sh_degree); +__device__ float6 make_float6(float x, float y, float z, float w, float v, float u); + +__global__ void rasterizer_preprocessing_kernel( + const float* __restrict__ means3d, const float* features, const float* __restrict__ scales, + const float* __restrict__ rotations, const float* __restrict__ projection_matrix, + const float* __restrict__ view_matrix, const float* 
__restrict__ camera_center, const float* __restrict__ opacities, + const float* __restrict__ cov3D_precomp, const float* __restrict__ colors_precomp, float scale_modifier, + int max_coeff, int active_sh_degree, int num_of_gaussians, float focal_x, float focal_y, float tan_fovx, + float tan_fovy, float c_x, float c_y, unsigned int img_width, int img_height, int32_t* __restrict__ radii, + int* __restrict__ block_sums, Gaussian* __restrict__ gaussians, float* __restrict__ cov_3d_output = nullptr); + +__global__ void calculate_sizes_kernel(const float* __restrict__ ndc, const int32_t* __restrict__ radii, + int* __restrict__ sizes, unsigned int num_ndc, unsigned int img_width, + int img_height); + +__global__ void calculate_keys_and_indices_kernel(const Gaussian* __restrict__ gaussians, + const int32_t* __restrict__ block_sums, + const int32_t* __restrict__ radii, int64_t* __restrict__ keys, + int32_t* __restrict__ indices, int num_ndc, int num_indices, + unsigned int img_width, unsigned int img_height); + +__global__ void identify_tile_ranges_kernel(const int64_t* __restrict__ keys, int32_t* __restrict__ tile_indices, + int size); +__global__ void render_image_kernel(const Gaussian* __restrict__ gaussians_global, const int32_t* __restrict__ indices, + const int32_t* __restrict__ tile_indices, float* __restrict__ output_image, + int image_width, int image_height, float* __restrict__ final_Ts = nullptr, + int32_t* __restrict__ final_idx = nullptr); + +#endif diff --git a/cpp/rasterizer_python.cpp b/cpp/rasterizer_python.cpp new file mode 100644 index 0000000..54360ff --- /dev/null +++ b/cpp/rasterizer_python.cpp @@ -0,0 +1,23 @@ +#include "ds_cuda_rasterizer/rasterizer_torch.hpp" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + py::class_(m, "RasterizationSettings") + .def(py::init<>()) + .def_readwrite("image_height", &TorchRasterizationSettings::image_height) + .def_readwrite("image_width", &TorchRasterizationSettings::image_width) + .def_readwrite("tanfovx", 
&TorchRasterizationSettings::tanfovx) + .def_readwrite("tanfovy", &TorchRasterizationSettings::tanfovy) + .def_readwrite("bg", &TorchRasterizationSettings::bg) + .def_readwrite("scale_modifier", &TorchRasterizationSettings::scale_modifier) + .def_readwrite("view_matrix", &TorchRasterizationSettings::view_matrix) + .def_readwrite("proj_matrix", &TorchRasterizationSettings::proj_matrix) + .def_readwrite("sh_degree", &TorchRasterizationSettings::sh_degree) + .def_readwrite("max_sh_degree", &TorchRasterizationSettings::max_sh_degree) + .def_readwrite("campos", &TorchRasterizationSettings::campos) + .def_readwrite("prefiltered", &TorchRasterizationSettings::prefiltered) + .def_readwrite("debug", &TorchRasterizationSettings::debug); + + m.def("forward_deepsense", &rasterizer_forward_deepsense, "rasterizer forward (CUDA)"); + m.def("backward_deepsense", &rasterizer_backward_deepsense, "rasterizer backward (CUDA)"); +} diff --git a/cpp/rasterizer_torch.cu b/cpp/rasterizer_torch.cu new file mode 100644 index 0000000..664f0ac --- /dev/null +++ b/cpp/rasterizer_torch.cu @@ -0,0 +1,316 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "ds_cuda_rasterizer/rasterizer_torch.hpp" +#include "gsplat/bindings.h" +#include "rasterizer_kernels.cuh" + +#define CHECK_FLOAT32(x) TORCH_CHECK(x.dtype() == torch::kFloat32, #x " must be float32") +#define CHECK_FLOAT_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x); \ + CHECK_FLOAT32(x); + +__global__ void tensors_from_gaussians_kernel(const Gaussian* __restrict__ gaussian_data_ptr, float2* __restrict__ xy, + float3* __restrict__ colors, float3* __restrict__ conic, + const int num_of_gaussians); + +torch::Tensor identify_tile_ranges(torch::Tensor keys, const TorchRasterizationSettings& settings) +{ + const int threads = 768; + const int blocks = (keys.sizes()[0] + threads - 1) / threads; + const int tiles_width = (settings.image_width + TILE_SIZE - 1) / TILE_SIZE; + const int tiles_height = 
(settings.image_height + TILE_SIZE - 1) / TILE_SIZE; + const int size = tiles_width * tiles_height + 1; + + torch::Tensor indices = torch::zeros({size}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + indices[size - 1] = keys.numel(); + + identify_tile_ranges_kernel<<>>(keys.data_ptr(), indices.data_ptr(), + keys.numel()); + + thrust::device_ptr begin_it = thrust::device_pointer_cast(indices.data_ptr()); + auto end_it = begin_it + indices.numel(); + thrust::inclusive_scan(begin_it, end_it, begin_it, thrust::maximum{}); + + return indices; +} + +std::tuple +rasterizer_cuda_forward(torch::Tensor gaussian_tensor, torch::Tensor indices, torch::Tensor tile_indices, + const TorchRasterizationSettings& settings) +{ + dim3 block_dim(TILE_SIZE, TILE_SIZE); + dim3 grid_dim((settings.image_width + block_dim.x - 1) / block_dim.x, + (settings.image_height + block_dim.y - 1) / block_dim.y); + + torch::Tensor output_image = torch::empty({settings.image_height, settings.image_width, 3}, + torch::device(torch::kCUDA).dtype(torch::kFloat32)); + + torch::Tensor final_Ts = torch::empty({settings.image_height, settings.image_width}, output_image.options()); + + torch::Tensor final_idx = + torch::empty({settings.image_height, settings.image_width}, output_image.options().dtype(torch::kInt32)); + + Gaussian* gaussian_data_ptr = reinterpret_cast(gaussian_tensor.data_ptr()); + + render_image_kernel<<>>(gaussian_data_ptr, indices.data_ptr(), + tile_indices.data_ptr(), output_image.data_ptr(), + settings.image_width, settings.image_height, + final_Ts.data_ptr(), final_idx.data_ptr()); + + return std::make_tuple(output_image, final_Ts, final_idx); +} + +void sort_by_keys(torch::Tensor indices, torch::Tensor keys) +{ + auto device_indices = thrust::device_pointer_cast(indices.data_ptr()); + auto device_keys = thrust::device_pointer_cast(keys.data_ptr()); + + thrust::sort_by_key(device_keys, device_keys + keys.numel(), device_indices); +} + +float fov_to_focal(float fov, float pixels) +{ + 
return pixels / (2.0f * tan(fov / 2.0f)); +} + +std::tuple duplicate_keys(torch::Tensor block_sums, torch::Tensor gaussian_tensor, + torch::Tensor radii, int numBlocks_in_preprocess, + int threads_in_preprocess, + const TorchRasterizationSettings& settings) +{ + // reuse numBlocks and threads count for calculate_keys_and_indices_kernel, so we are sure + // restoring global exclusive sum will work as intended + const auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + int num_of_gaussians = radii.numel(); + + auto block_sums_device_ptr = thrust::device_pointer_cast(block_sums.data_ptr()); + + thrust::exclusive_scan(block_sums_device_ptr, block_sums_device_ptr + block_sums.numel(), block_sums_device_ptr); + + Gaussian* gaussian_data_ptr = reinterpret_cast(gaussian_tensor.data_ptr()); + const uint8_t* data_ptr = gaussian_tensor.data_ptr(); + const size_t offset_to_last_size_occupied_tiles = + (num_of_gaussians - 1) * sizeof(Gaussian) + offsetof(Gaussian, size_occupied_tiles); + int32_t total_size = 0; + if(num_of_gaussians > 0) + cudaMemcpy(&total_size, data_ptr + offset_to_last_size_occupied_tiles, sizeof(int32_t), cudaMemcpyDeviceToHost); + total_size += (block_sums.numel() > 0) ? 
block_sums[block_sums.numel() - 1].item() : 0; + + torch::Tensor keys = torch::empty({total_size}, options_int32.dtype(torch::kInt64)); + torch::Tensor indices = torch::empty({total_size}, options_int32.dtype(torch::kInt32)); + calculate_keys_and_indices_kernel<<>>( + gaussian_data_ptr, block_sums.data_ptr(), radii.data_ptr(), keys.data_ptr(), + indices.data_ptr(), radii.sizes()[0], total_size, settings.image_width, settings.image_height); + + return std::make_tuple(keys, indices); +} + +std::tuple +rasterizer_forward_preprocessing(torch::Tensor means_3d, torch::Tensor shs, torch::Tensor colors_precomp, + torch::Tensor opacities, torch::Tensor scales, torch::Tensor rotations, + torch::Tensor cov3D_precomp, const TorchRasterizationSettings& settings) +{ + const float scale_modifier = float(settings.scale_modifier); + const int max_coeff = (settings.max_sh_degree + 1) * (settings.max_sh_degree + 1); + const int coeff = (settings.sh_degree + 1) * (settings.sh_degree + 1); + + if(shs.numel()) + { + assert(settings.sh_degree >= 0 && settings.sh_degree <= 4); + assert(shs.size(1) >= coeff); + assert(settings.max_sh_degree == 3); + } + + const int threads = 768; + const int num_of_gaussians = means_3d.numel() / 3; + const int numBlocks = (num_of_gaussians + threads - 1) / threads; + + const auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + const int size_of_gaussian = sizeof(Gaussian); + + const auto options_byte = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + torch::Tensor radii = torch::empty({num_of_gaussians}, options_int32); + torch::Tensor gaussian_tensor = torch::empty({num_of_gaussians * size_of_gaussian}, options_byte); + + Gaussian* gaussian_data_ptr = reinterpret_cast(gaussian_tensor.data_ptr()); + + const auto focal_y = settings.image_height / (2 * settings.tanfovy); + const auto focal_x = settings.image_width / (2 * settings.tanfovx); + const auto c_x = settings.image_width / 2.0f; + const auto c_y 
= settings.image_height / 2.0f; + + // here we will keep all sizes sum per each block + torch::Tensor block_sums = torch::empty({numBlocks}, options_int32.dtype(torch::kInt32)); + + const float* cov3D_precomp_data_ptr = cov3D_precomp.numel() > 0 ? cov3D_precomp.data_ptr() : nullptr; + const float* colors_precomp_data_ptr = colors_precomp.numel() > 0 ? colors_precomp.data_ptr() : nullptr; + + torch::Tensor cov_3d = cov3D_precomp.numel() + ? cov3D_precomp + : torch::empty({num_of_gaussians, 6}, torch::dtype(torch::kFloat).device(torch::kCUDA)); + + rasterizer_preprocessing_kernel<<>>( + means_3d.data_ptr(), shs.data_ptr(), scales.data_ptr(), rotations.data_ptr(), + settings.proj_matrix.data_ptr(), settings.view_matrix.data_ptr(), + settings.campos.data_ptr(), opacities.data_ptr(), cov3D_precomp_data_ptr, colors_precomp_data_ptr, + scale_modifier, max_coeff, settings.sh_degree, num_of_gaussians, focal_x, focal_y, settings.tanfovx, + settings.tanfovy, c_x, c_y, settings.image_width, settings.image_height, radii.data_ptr(), + block_sums.data_ptr(), gaussian_data_ptr, cov_3d.data_ptr()); + + return std::make_tuple(radii, gaussian_tensor, cov_3d, block_sums, numBlocks, threads); +} + +std::vector rasterizer_forward_deepsense(torch::Tensor means_3d, torch::Tensor means_2d, + torch::Tensor shs, torch::Tensor colors_precomp, + torch::Tensor opacities, torch::Tensor scales, + torch::Tensor rotations, torch::Tensor cov3D_precomp, + const TorchRasterizationSettings& settings) +{ + CHECK_INPUT(means_3d); + CHECK_INPUT(means_2d); + CHECK_INPUT(opacities); + + TORCH_CHECK((shs.numel() > 0 && shs.is_contiguous() && shs.is_cuda()) || shs.numel() == 0, "shs error"); + TORCH_CHECK((scales.numel() > 0 && scales.is_contiguous() && scales.is_cuda()) || scales.numel() == 0, + "scales error"); + TORCH_CHECK((rotations.numel() > 0 && rotations.is_contiguous() && rotations.is_cuda()) || rotations.numel() == 0, + "rotations error"); + TORCH_CHECK((cov3D_precomp.numel() > 0 && 
cov3D_precomp.is_contiguous() && cov3D_precomp.is_cuda()) || + (cov3D_precomp.numel() == 0 && scales.numel() > 0 && scales.sizes()[0] == rotations.sizes()[0]), + "cov3D error"); + TORCH_CHECK((colors_precomp.numel() > 0 && colors_precomp.is_contiguous() && colors_precomp.is_cuda()) || + (colors_precomp.numel() == 0 && shs.numel() > 0), + "SHS error"); + + auto [radii, gaussian_tensor, cov_3d, block_sums, num_blocks_preprocessing, num_threads_preprocessing] = + rasterizer_forward_preprocessing(means_3d, shs, colors_precomp, opacities, scales, rotations, cov3D_precomp, + settings); + + auto [keys, indices] = duplicate_keys(block_sums, gaussian_tensor, radii, num_blocks_preprocessing, + num_threads_preprocessing, settings); + + sort_by_keys(indices, keys); + + torch::Tensor tile_indices = identify_tile_ranges(keys, settings); + auto [rendered_image, final_Ts, final_idx] = + rasterizer_cuda_forward(gaussian_tensor, indices, tile_indices, settings); + + torch::Tensor clamped_colors = torch::ones({means_3d.size(0), 3}, means_3d.options()); + + return {rendered_image, radii, gaussian_tensor, final_Ts, final_idx, clamped_colors, cov_3d, indices, tile_indices}; +} + +torch::Tensor calculate_compensation(torch::Tensor a) +{ + auto b = a.clone(); + b.select(1, 0) -= 0.3f; + b.select(1, 2) -= 0.3f; + + auto det_a = a.select(1, 0) * a.select(1, 2) - torch::pow(a.select(1, 1), 2); + auto det_b = b.select(1, 0) * b.select(1, 2) - torch::pow(b.select(1, 1), 2); + + auto result = det_a / det_b; + + return torch::sqrt(torch::max(result, torch::zeros(result.sizes(), result.options()))); +} + +int sh_degree(int num_bases) +{ + switch(num_bases) + { + case 1: + return 0; + case 4: + return 1; + case 9: + return 2; + case 16: + return 3; + default: + return 3; + } +} + +std::tuple tensors_from_gaussians(torch::Tensor gaussians) +{ + const Gaussian* gaussian_data_ptr = reinterpret_cast(gaussians.data_ptr()); + const int num_of_gaussians = gaussians.numel() / sizeof(Gaussian); + + const 
auto options = torch::dtype(torch::kFloat).device(torch::kCUDA); + + torch::Tensor xy = torch::empty({num_of_gaussians, 2}, options); + torch::Tensor colors = torch::empty({num_of_gaussians, 3}, options); + torch::Tensor conic = torch::empty({num_of_gaussians, 3}, options); + + const auto threads = 768; + const int blocks = (num_of_gaussians + threads - 1) / threads; + + tensors_from_gaussians_kernel<<>>( + gaussian_data_ptr, reinterpret_cast(xy.data_ptr()), + reinterpret_cast(colors.data_ptr()), reinterpret_cast(conic.data_ptr()), + num_of_gaussians); + + return std::make_tuple(xy, colors, conic); +} + +std::vector rasterizer_backward_deepsense( + torch::Tensor means_3d, torch::Tensor means2d, torch::Tensor shs, torch::Tensor colors_precomp, + torch::Tensor opacities, torch::Tensor scales, torch::Tensor rotations, torch::Tensor cov3D_precomp, + torch::Tensor gaussians, torch::Tensor radii, torch::Tensor colors_clamped, torch::Tensor final_Ts, + torch::Tensor final_index, torch::Tensor cov3D, torch::Tensor indices, torch::Tensor tile_indices, + torch::Tensor v_image, const TorchRasterizationSettings& settings) +{ + const auto num_of_gaussians = means_3d.size(0); + + const dim3 block_dim(TILE_SIZE, TILE_SIZE, 1); + const dim3 grid_dim((settings.image_width + block_dim.x - 1) / block_dim.x, + (settings.image_height + block_dim.y - 1) / block_dim.y, 1); + + const auto num_of_tiles = grid_dim.x * grid_dim.y; + torch::Tensor tile_bins = torch::empty({num_of_tiles, 2}, means_3d.options().dtype(torch::kInt32)); + + using torch::indexing::None; + using torch::indexing::Slice; + + tile_bins.index_put_({Slice(0, num_of_tiles), 0}, tile_indices.index({Slice(0, num_of_tiles)})); + tile_bins.index_put_({Slice(0, num_of_tiles), 1}, tile_indices.index({Slice(1, num_of_tiles + 1)})); + + torch::Tensor v_alpha = torch::zeros_like(v_image.index({"...", 0})); + + auto [v_xy, v_conic, v_colors, v_opacity] = + rasterize_backward_tensor(settings.image_height, settings.image_width, 
TILE_SIZE, indices, tile_bins, gaussians, + opacities, settings.bg, final_Ts, final_index, v_image, v_alpha); + + torch::Tensor viewdirs = means_3d.detach() - settings.campos; + v_colors = v_colors * colors_clamped; + + auto v_shs = + compute_sh_backward_tensor(num_of_gaussians, sh_degree(shs.size(-2)), settings.sh_degree, viewdirs, v_colors); + + const auto focal_x = fov_to_focal(2 * atan(settings.tanfovx), settings.image_width); + const auto focal_y = fov_to_focal(2 * atan(settings.tanfovy), settings.image_height); + const auto c_x = settings.image_width / 2.0f; + const auto c_y = settings.image_height / 2.0f; + + // TODO: remove compensation + // TODO: compensation in + auto [xy, colors, conics] = tensors_from_gaussians(gaussians); + auto compensation = calculate_compensation(conics.detach()); + torch::Tensor v_depth = torch::zeros({num_of_gaussians, 1}, v_xy.options()); + torch::Tensor view_matrix = settings.view_matrix; + + auto [v_cov2d, v_cov3d, v_mean3d, v_scale, v_rotation] = project_gaussians_backward_tensor( + num_of_gaussians, means_3d, scales, settings.scale_modifier, rotations, view_matrix, focal_x, focal_y, c_x, c_y, + settings.image_height, settings.image_width, gaussians, cov3D, radii, compensation, v_xy, v_depth, v_conic); + + v_xy = torch::cat({v_xy, torch::zeros({v_xy.size(0), 1}, v_xy.options())}, 1); + + return {v_mean3d, v_xy, v_shs, v_colors, v_opacity, v_scale, v_rotation, v_cov3d}; +} diff --git a/ds_splat/__init__.py b/ds_splat/__init__.py new file mode 100644 index 0000000..07769e0 --- /dev/null +++ b/ds_splat/__init__.py @@ -0,0 +1,4 @@ +import torch # noqa: F401 +from .rasterizer import GaussianRasterizer, GaussianRasterizationSettings + +__all__ = ["GaussianRasterizer", "GaussianRasterizationSettings"] diff --git a/ds_splat/rasterizer.py b/ds_splat/rasterizer.py new file mode 100644 index 0000000..a2d6c10 --- /dev/null +++ b/ds_splat/rasterizer.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +from dataclasses import dataclass +import 
torch +import ds_splat_cuda as _cuda_impl + + +class DsRasterizerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + means_3d, + means_2d, + shs, + colors_precomp, + opacities, + scales, + rotations, + cov_3d_precomp, + raster_settings, + ): + + outputs = _cuda_impl.forward_deepsense( + means_3d, + means_2d, + shs, + colors_precomp, + opacities, + scales, + rotations, + cov_3d_precomp, + raster_settings, + ) + + ( + rendered_image, + radii, + gaussian_tensor, + final_Ts, + final_idx, + clamped_colors, + cov_3d, + indices, + tile_indices, + ) = outputs + + ctx.raster_settings = raster_settings + ctx.save_for_backward( + means_3d, + means_2d, + shs, + colors_precomp, + opacities, + scales, + rotations, + cov_3d_precomp, + radii, + final_Ts, + final_idx, + clamped_colors, + gaussian_tensor, + cov_3d, + indices, + tile_indices, + ) + + return rendered_image, radii + + @staticmethod + def backward(ctx, rendered_img_grad, _): + ( + means_3d, + means_2d, + shs, + colors_precomp, + opacities, + scales, + rotations, + cov_3d_precomp, + radii, + final_Ts, + final_idx, + clamped_colors, + gaussian_tensor, + cov_3d, + indices, + tile_indices, + ) = ctx.saved_tensors + + raster_settings = ctx.raster_settings + + output = _cuda_impl.backward_deepsense( + means_3d, + means_2d, + shs, + colors_precomp, + opacities, + scales, + rotations, + cov_3d_precomp, + gaussian_tensor, + radii, + clamped_colors, + final_Ts, + final_idx, + cov_3d, + indices, + tile_indices, + rendered_img_grad, + raster_settings, + ) + + return (*output, None) + + +@dataclass +class GaussianRasterizationSettings: + image_height: int + image_width: int + tanfovx: float + tanfovy: float + bg: torch.Tensor + scale_modifier: float + viewmatrix: torch.Tensor + projmatrix: torch.Tensor + sh_degree: int + max_sh_degree: int + campos: torch.Tensor + prefiltered: bool + debug: bool + + +class GaussianRasterizer(torch.nn.Module): + def __init__(self, raster_settings: 
GaussianRasterizationSettings): + super(GaussianRasterizer, self).__init__() + self._settings = raster_settings + + @property + def settings_for_cpp_code(self): + settings = _cuda_impl.RasterizationSettings() + + settings.image_height = self._settings.image_height + settings.image_width = self._settings.image_width + settings.tanfovx = self._settings.tanfovx + settings.tanfovy = self._settings.tanfovy + settings.bg = self._settings.bg + settings.scale_modifier = self._settings.scale_modifier + settings.view_matrix = self._settings.viewmatrix.t().contiguous() + settings.proj_matrix = self._settings.projmatrix.t().contiguous() + settings.sh_degree = self._settings.sh_degree + settings.max_sh_degree = self._settings.max_sh_degree + settings.campos = self._settings.campos + settings.prefiltered = self._settings.prefiltered + settings.debug = self._settings.debug + + return settings + + def forward( + self, + means_3d: torch.Tensor, + means_2d: torch.Tensor, + shs: torch.Tensor, + colors_precomp: torch.Tensor, + opacities: torch.Tensor, + scales: torch.Tensor, + rotations: torch.Tensor, + cov_3d_precomp: torch.Tensor, + ): + + if cov_3d_precomp is None: + cov_3d_precomp = torch.Tensor().cuda() + + if colors_precomp is None: + colors_precomp = torch.Tensor().cuda() + + if scales is None: + scales = torch.Tensor().cuda() + + if rotations is None: + rotations = torch.Tensor().cuda() + + if shs is None: + shs = torch.Tensor().cuda() + + return DsRasterizerFunction.apply( + means_3d, + means_2d, + shs, + colors_precomp, + opacities, + scales, + rotations, + cov_3d_precomp, + self.settings_for_cpp_code, + ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..45efecc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,82 @@ +[build-system] +requires = ["setuptools", "wheel", "torch"] +build-backend = "setuptools.build_meta" + + +[project] +name = "ds-splat" +version = "0.0.0" +description = "A CUDA-based gaussian splatting rasterizer extension for 
PyTorch." +readme = "README.md" +authors = [{ name = "deepsense.ai", email = "contact@deepsense.ai" }] +requires-python = ">=3.7" +keywords = ["pytorch", "cuda", "rasterizer", "deep learning"] +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] + +[project.optional-dependencies] +test = ["torch"] + +[tool.setuptools] +include-package-data = true + +[tool.semantic_release] +assets = [] +build_command_env = [] +commit_message = "{version}\n\nAutomatically generated by python-semantic-release" +commit_parser = "angular" +logging_use_named_masks = false +major_on_zero = true +allow_zero_version = true +no_git_verify = false +tag_format = "v{version}" +version_variables = ["setup.py:__version__"] +build_command = "pip install build && python -m build" + +[tool.semantic_release.branches.main] +match = "(main|master)" +prerelease_token = "rc" +prerelease = false + +[tool.semantic_release.changelog] +template_dir = "templates" +changelog_file = "CHANGELOG.md" +exclude_commit_patterns = [] + +[tool.semantic_release.changelog.environment] +block_start_string = "{%" +block_end_string = "%}" +variable_start_string = "{{" +variable_end_string = "}}" +comment_start_string = "{#" +comment_end_string = "#}" +trim_blocks = false +lstrip_blocks = false +newline_sequence = "\n" +keep_trailing_newline = false +extensions = [] +autoescape = true + +[tool.semantic_release.commit_author] +env = "GIT_COMMIT_AUTHOR" +default = "semantic-release " + +[tool.semantic_release.commit_parser_options] +allowed_tags = ["build", "chore", "ci", "docs", "feat", "fix", "perf", "style", "refactor", "test"] +minor_tags = ["feat"] +patch_tags = ["fix", "perf"] +default_bump_level = 0 + +[tool.semantic_release.remote] +name = "origin" +type = "github" +ignore_token_for_push = false +insecure = false + +[tool.semantic_release.publish] +dist_glob_patterns = ["dist/*"] +upload_to_vcs_release = true + diff --git 
a/setup.py b/setup.py
new file mode 100644
index 0000000..1833f26
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+"""Build script for the ``ds-splat`` package and its ``ds_splat_cuda`` CUDA extension."""
+
+from setuptools import find_packages, setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+__version__ = "0.0.0"
+
+
+def get_cuda_extension():
+    """Describe the ``ds_splat_cuda`` extension built from the sources under ``cpp/``.
+
+    Returns:
+        A ``CUDAExtension`` whose host-compiler and nvcc flags both include ``-O3``;
+        nvcc additionally receives ``--use_fast_math``.
+    """
+    extra_compile_args = {"cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"]}
+
+    return CUDAExtension(
+        "ds_splat_cuda",
+        sources=[
+            "cpp/rasterizer_python.cpp",
+            "cpp/rasterizer_torch.cu",
+            "cpp/rasterizer_cuda.cu",
+            "cpp/rasterizer_kernels.cu",
+            "cpp/gsplat/backward.cu",
+            "cpp/gsplat/bindings.cu",
+        ],
+        extra_compile_args=extra_compile_args,
+    )
+
+
+setup(
+    name="ds-splat",
+    # Keep the distribution version in sync with the module-level __version__.
+    version=__version__,
+    ext_modules=[get_cuda_extension()],
+    cmdclass={"build_ext": BuildExtension},
+    packages=find_packages(),
+)
diff --git a/tests/rasterizer_test.cpp b/tests/rasterizer_test.cpp
new file mode 100644
index 0000000..b03af91
--- /dev/null
+++ b/tests/rasterizer_test.cpp
@@ -0,0 +1,231 @@
+// NOTE(review): the include targets of this file were lost when the patch was extracted. The headers below are
+// the ones the code visibly relies on; confirm them against the repository before applying.
+#include <gmock/gmock.h>
+// NOTE(review): CATCH_CONFIG_MAIN is a Catch2 switch, yet every test below uses GoogleTest macros. It is kept as
+// found; confirm whether it is still needed.
+#define CATCH_CONFIG_MAIN
+
+#include "ds_cuda_rasterizer/rasterizer_cuda.hpp"
+#include "ds_cuda_rasterizer/rasterizer_torch.hpp"
+
+#include <gtest/gtest.h>
+#include <torch/script.h>
+#include <torch/torch.h>
+
+#include <array>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <iostream>
+#include <string_view>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+// Functions under test. They are not defined in this file, so they are declared here.
+// NOTE(review): the tuple's element types were lost in extraction; two tensors are assumed because the disabled
+// test below unpacks the result into `keys` and `indices` — confirm.
+torch::Tensor identify_tile_ranges(torch::Tensor keys, const TorchRasterizationSettings& settings);
+std::tuple<torch::Tensor, torch::Tensor> prepare_keys(torch::Tensor means_2d, torch::Tensor radii,
+                                                      const TorchRasterizationSettings& settings);
+void sort_by_keys(torch::Tensor indices, torch::Tensor keys);
+
+/// Number of timed invocations performed by each profiling test.
+constexpr int PROFILING_RUNS = 10;
+
+/// Packs a tile index and a depth into one 64-bit key: the tile index occupies the upper 32 bits and the depth
+/// the lower 32 bits.
+std::int64_t get_key(std::int32_t tile_index, std::int32_t depth)
+{
+    // Reinterpret the depth as unsigned first, so a negative value cannot sign-extend into the tile-index bits.
+    const auto depth_bits = static_cast<std::uint32_t>(depth);
+    return (static_cast<std::int64_t>(tile_index) << 32) | static_cast<std::int64_t>(depth_bits);
+}
+
+// Disabled test, kept for reference.
+// NOTE(review): ndc_array holds 30 values but is wrapped as an {N, 4} tensor with N = 10 (40 elements); resolve
+// this before re-enabling the test.
+//
+// TEST(Rasterizer, PrepareKeysIndicesSimple)
+// {
+//     static constexpr int N = 10;
+
+//     auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+
+//     std::array ndc_array = {1437, 185, 315, 1087, 507, 29, 591, 807, 1429, 209,
+//                             164, 979, 119, 687, 104, 1088, 968, 139, 730, 607,
+//                             1216, 679, 1368, 441, 1485, 995, 893, 690, 21, 835};
+
+//     std::array radii_array = {20, 6, 10, 4, 4, 8, 11, 17, 11, 31};
+//     torch::Tensor ndc = torch::from_blob(ndc_array.data(), {N, 4}, torch::dtype(torch::kFloat32)).cuda();
+//     torch::Tensor radii = torch::from_blob(radii_array.data(), {N}, torch::dtype(torch::kInt32)).cuda();
+
+//     const TorchRasterizationSettings raster_settings = {
+//         .image_height = 1084, .image_width = 1920, .tanfovx = 0.7673294196293707, .tanfovy = 0.4332052624853496};
+
+//     auto [keys, indices] = prepare_keys(ndc, radii, raster_settings);
+
+//     constexpr auto KEYS_NUM = 40;
+//     EXPECT_EQ(keys.numel(), KEYS_NUM);
+//     EXPECT_EQ(keys.numel(), indices.numel());
+
+//     const std::array ref_indices = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
+//                                     2, 2, 3, 3, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 8, 8, 8, 8};
+
+//     auto indices_cpu = indices.cpu();
+//     const std::span indices_span(indices_cpu.data_ptr(), KEYS_NUM);
+//     EXPECT_THAT(indices_span, testing::ElementsAreArray(ref_indices));
+
+//     const std::array ref_keys = {
+//         5533052272640, 5537347239936, 5541642207232, 5545937174528, 6048448348160, 6052743315456, 6057038282752,
+//         6061333250048, 6563844423680, 6568139390976, 6572434358272, 6576729325568, 16266146873344,
+//         16270441840640, 16781542948864, 16785837916160, 25410179080192, 25414474047488, 25925575155712,
+//         25929870123008, 26440971231232, 26445266198528, 5206648864768, 5210943832064, 21677820870656,
+//         22193216946176, 31212652134400, 31216947101696, 31728048209920, 31732343177216, 19259784167424,
+//         19264079134720, 19268374102016, 19775180242944, 19779475210240, 19783770177536, 31835444690944,
+//         31839739658240, 32350840766464, 32355135733760};
+
+//     auto keys_cpu = keys.cpu();
+//     const std::span keys_span(keys_cpu.data_ptr(), KEYS_NUM);
+//     EXPECT_THAT(keys_span, testing::ElementsAreArray(ref_keys));
+// }
+
+/// Copies `keys` to the GPU, runs identify_tile_ranges on them for a 40x30 (width x height) image and returns the
+/// values of the resulting tensor as 64-bit integers on the host.
+template <std::size_t N>
+std::vector<std::int64_t> compute_tile_ranges(std::array<std::int64_t, N> keys)
+{
+    const torch::Tensor device_keys =
+        torch::from_blob(keys.data(), {static_cast<std::int64_t>(N)}, torch::dtype(torch::kInt64)).cuda();
+
+    TorchRasterizationSettings raster_settings = {};
+    raster_settings.image_height = 30;
+    raster_settings.image_width = 40;
+
+    // Widening to int64 on the host keeps the comparison independent of the integer dtype that is returned.
+    const torch::Tensor ranges =
+        identify_tile_ranges(device_keys, raster_settings).to(torch::kCPU, torch::kInt64).contiguous();
+    const std::int64_t* first = ranges.data_ptr<std::int64_t>();
+    return std::vector<std::int64_t>(first, first + ranges.numel());
+}
+
+TEST(Rasterizer, IdentifyTileRanges)
+{
+    // Tiles 0 through 5 each own at least one key.
+    const std::array<std::int64_t, 10> keys = {get_key(0, 40), get_key(1, 20), get_key(1, 50), get_key(2, 60),
+                                               get_key(3, 10), get_key(3, 20), get_key(4, 20), get_key(5, 40),
+                                               get_key(5, 30), get_key(5, 50)};
+    const std::array<std::int64_t, 7> expected_ranges = {0, 1, 3, 4, 6, 7, 10};
+
+    EXPECT_THAT(compute_tile_ranges(keys), testing::ElementsAreArray(expected_ranges));
+}
+
+TEST(Rasterizer, IdentifyTileRangesNoTileEntries)
+{
+    // Tiles 2 and 4 own no key at all.
+    const std::array<std::int64_t, 10> keys = {get_key(0, 40), get_key(1, 20), get_key(1, 50), get_key(3, 60),
+                                               get_key(3, 10), get_key(3, 20), get_key(3, 20), get_key(5, 40),
+                                               get_key(5, 30), get_key(5, 50)};
+    const std::array<std::int64_t, 7> expected_ranges = {0, 1, 3, 3, 7, 7, 10};
+
+    EXPECT_THAT(compute_tile_ranges(keys), testing::ElementsAreArray(expected_ranges));
+}
+
+/// Calls `function(args...)`, prints the elapsed time in milliseconds (measured with high_resolution_clock) to
+/// stdout under the label `step_name`, and returns the call's result (nothing when `function` returns void).
+template <typename Func, typename... Args>
+auto with_time_report(std::string_view step_name, Func function, Args&&... args)
+{
+    using Clock = std::chrono::high_resolution_clock;
+
+    // Shared by both branches below so the report format is written only once.
+    const auto report = [step_name](Clock::time_point start_ts, Clock::time_point end_ts)
+    {
+        const double duration = std::chrono::duration<double, std::milli>(end_ts - start_ts).count();
+        std::cout << "Timing step:" << step_name << " took: " << duration << "ms\n";
+    };
+
+    const auto start_ts = Clock::now();
+    if constexpr(std::is_void_v<std::invoke_result_t<Func&, Args&&...>>)
+    {
+        function(std::forward<Args>(args)...);
+        report(start_ts, Clock::now());
+    }
+    else
+    {
+        const auto result = function(std::forward<Args>(args)...);
+        report(start_ts, Clock::now());
+        return result;
+    }
+}
+
+/// Builds the rasterization settings from the attributes stored in `input_data`. The background is three zeros on
+/// the CUDA device, and the view and projection matrices are transposed and made contiguous.
+TorchRasterizationSettings load_settings(const torch::jit::script::Module& input_data)
+{
+    using Settings = TorchRasterizationSettings;
+
+    // Each cast targets the declared type of the member it initialises (via decltype), because the target types
+    // written in the original casts were lost in extraction.
+    return Settings{
+        .image_height = static_cast<decltype(Settings::image_height)>(input_data.attr("image_height").toInt()),
+        .image_width = static_cast<decltype(Settings::image_width)>(input_data.attr("image_width").toInt()),
+        .tanfovx = static_cast<decltype(Settings::tanfovx)>(input_data.attr("tanfovx").toDouble()),
+        .tanfovy = static_cast<decltype(Settings::tanfovy)>(input_data.attr("tanfovy").toDouble()),
+        .bg = torch::zeros({3}, torch::device(torch::kCUDA)),
+        .scale_modifier =
+            static_cast<decltype(Settings::scale_modifier)>(input_data.attr("scale_modifier").toDouble()),
+        .view_matrix = input_data.attr("viewmatrix").toTensor().t().contiguous(),
+        .proj_matrix = input_data.attr("projmatrix").toTensor().t().contiguous(),
+        .sh_degree = static_cast<decltype(Settings::sh_degree)>(input_data.attr("sh_degree").toInt()),
+        .campos = input_data.attr("campos").toTensor()};
+}
+
+TEST(Rasterizer, DeepsensePyTorchPipelineProfiling)
+{
+    torch::jit::script::Module input_data = torch::jit::load("input_data_garden.pt");
+    torch::Tensor means_3d = input_data.attr("means_3d").toTensor();
+    torch::Tensor means_2d = input_data.attr("means_2d").toTensor();
+    torch::Tensor shs = input_data.attr("shs").toTensor();
+    torch::Tensor opacities = input_data.attr("opacities").toTensor();
+    torch::Tensor scales = input_data.attr("scales").toTensor();
+    torch::Tensor rotations = input_data.attr("rotations").toTensor();
+    // Default-constructed tensors: no precomputed colours or covariances are passed to the rasterizer.
+    torch::Tensor colors_precomp;
+    torch::Tensor cov_3d_precomp;
+    const TorchRasterizationSettings settings = load_settings(input_data);
+
+    for(auto i = 0; i < PROFILING_RUNS; ++i)
+    {
+        [[maybe_unused]] const auto result =
+            with_time_report("rasterizer_deepsense", &rasterizer_forward_deepsense, means_3d, means_2d, shs,
+                             colors_precomp, opacities, scales, rotations, cov_3d_precomp, settings);
+    }
+}
+
+TEST(Rasterizer, DeepsenseCorePipelineProfiling)
+{
+    torch::jit::script::Module input_data = torch::jit::load("input_data_garden.pt");
+    torch::Tensor means_3d = input_data.attr("means_3d").toTensor();
+    torch::Tensor shs = input_data.attr("shs").toTensor();
+    torch::Tensor opacities = input_data.attr("opacities").toTensor();
+    torch::Tensor scales = input_data.attr("scales").toTensor();
+    torch::Tensor rotations = input_data.attr("rotations").toTensor();
+    const TorchRasterizationSettings settings = load_settings(input_data);
+
+    torch::Tensor output_image = torch::empty({settings.image_height, settings.image_width, 3},
+                                              torch::device(torch::kCUDA).dtype(torch::kFloat32));
+
+    // NOTE(review): the element type of every data_ptr<> call was lost in extraction; float is assumed to match
+    // the float32 output image — confirm against the signature of rasterizer_forward_core_deepsense.
+    for(auto i = 0; i < PROFILING_RUNS; ++i)
+    {
+        with_time_report("rasterizer_core_deepsense", &rasterizer_forward_core_deepsense, means_3d.data_ptr<float>(),
+                         shs.data_ptr<float>(), opacities.data_ptr<float>(), scales.data_ptr<float>(),
+                         rotations.data_ptr<float>(), means_3d.sizes()[0], settings.view_matrix.data_ptr<float>(),
+                         settings.proj_matrix.data_ptr<float>(), settings.campos.data_ptr<float>(),
+                         settings.image_width, settings.image_height, settings.tanfovx, settings.tanfovy,
+                         settings.scale_modifier, 3, settings.sh_degree, output_image.data_ptr<float>());
+    }
+}