diff --git a/CMakeLists.txt b/CMakeLists.txt index 308566a..21687eb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,17 +6,17 @@ set(GPU_RUNTIME "CUDA" CACHE STRING "HIP or CUDA") set(OPENCV_DIR "OPENCV_DIR-NOTFOUND" CACHE PATH "Path to the OPENCV installation directory") if(NOT CMAKE_BUILD_TYPE) -set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel." FORCE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel." FORCE) endif() -enable_language(${GPU_RUNTIME}) -set(CMAKE_${GPU_RUNTIME}_STANDARD 17) -set(${GPU_RUNTIME}_STANDARD 17) - if(GPU_RUNTIME STREQUAL "CUDA") set(CMAKE_CUDA_ARCHITECTURES 70 75) - find_package(CUDAToolkit REQUIRED) -else() + find_package(CUDAToolkit) + if (CUDAToolkit-NOTFOUND) + message(WARNING "CUDA toolkit not found, building with CPU support only") + set(GPU_RUNTIME "CPU") + endif() +elseif(GPU_RUNTIME STREQUAL "HIP") set(USE_HIP ON CACHE BOOL "Use HIP for GPU acceleration") if(NOT DEFINED HIP_PATH) @@ -40,6 +40,13 @@ else() list(APPEND CMAKE_PREFIX_PATH "${ROCM_ROOT}") endif() +set(CMAKE_CXX_STANDARD 17) +if((GPU_RUNTIME STREQUAL "CUDA") OR (GPU_RUNTIME STREQUAL "HIP")) + enable_language(${GPU_RUNTIME}) + set(CMAKE_${GPU_RUNTIME}_STANDARD 17) + set(${GPU_RUNTIME}_STANDARD 17) +endif() + if (NOT WIN32 AND NOT APPLE) set(STDPPFS_LIBRARY stdc++fs) endif() @@ -52,38 +59,48 @@ if (NOT WIN32 AND NOT APPLE) endif() set(OpenCV_LIBS opencv_core opencv_imgproc opencv_highgui opencv_calib3d) -add_library(gsplat vendor/gsplat/forward.cu vendor/gsplat/backward.cu vendor/gsplat/bindings.cu vendor/gsplat/helpers.cuh) -if(GPU_RUNTIME STREQUAL "CUDA") - set(GPU_LIBRARIES "cuda") - target_link_libraries(gsplat PUBLIC cuda) -else(GPU_RUNTIME STREQUAL "HIP") - set(GPU_INCLUDE_DIRS "${ROCM_ROOT}/include") - target_compile_definitions(gsplat PRIVATE USE_HIP __HIP_PLATFORM_AMD__) +set(GSPLAT_LIBS gsplat_cpu) +if((GPU_RUNTIME STREQUAL "CUDA") OR (GPU_RUNTIME STREQUAL "HIP")) + add_library(gsplat vendor/gsplat/forward.cu vendor/gsplat/backward.cu vendor/gsplat/bindings.cu vendor/gsplat/helpers.cuh) + list(APPEND GSPLAT_LIBS gsplat) + if(GPU_RUNTIME STREQUAL "CUDA") + set(GPU_LIBRARIES "cuda") + target_link_libraries(gsplat PUBLIC cuda) + set_target_properties(gsplat PROPERTIES CUDA_ARCHITECTURES "70;75") + else(GPU_RUNTIME STREQUAL "HIP") + set(GPU_INCLUDE_DIRS "${ROCM_ROOT}/include") + target_compile_definitions(gsplat PRIVATE USE_HIP __HIP_PLATFORM_AMD__) + endif() + target_include_directories(gsplat PRIVATE + ${PROJECT_SOURCE_DIR}/vendor/glm + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + ${TORCH_INCLUDE_DIRS} + ) + set_target_properties(gsplat PROPERTIES LINKER_LANGUAGE CXX) endif() -target_include_directories(gsplat PRIVATE - ${PROJECT_SOURCE_DIR}/vendor/glm - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} - ${TORCH_INCLUDE_DIRS} -) -set_target_properties(gsplat PROPERTIES LINKER_LANGUAGE CXX) -set_target_properties(gsplat PROPERTIES CUDA_ARCHITECTURES "70;75") +add_library(gsplat_cpu vendor/gsplat-cpu/gsplat_cpu.cpp) +target_include_directories(gsplat_cpu PRIVATE ${TORCH_INCLUDE_DIRS}) add_executable(opensplat opensplat.cpp point_io.cpp nerfstudio.cpp model.cpp kdtree_tensor.cpp spherical_harmonics.cpp cv_utils.cpp utils.cpp project_gaussians.cpp rasterize_gaussians.cpp ssim.cpp optim_scheduler.cpp colmap.cpp input_data.cpp tensor_math.cpp) set_property(TARGET opensplat PROPERTY CXX_STANDARD 17) target_include_directories(opensplat PRIVATE ${PROJECT_SOURCE_DIR}/vendor/glm ${GPU_INCLUDE_DIRS}) -target_link_libraries(opensplat PUBLIC ${STDPPFS_LIBRARY} ${GPU_LIBRARIES} gsplat ${TORCH_LIBRARIES} ${OpenCV_LIBS}) +target_link_libraries(opensplat PUBLIC ${STDPPFS_LIBRARY} ${GPU_LIBRARIES} ${GSPLAT_LIBS} ${TORCH_LIBRARIES} ${OpenCV_LIBS}) if(GPU_RUNTIME STREQUAL "HIP") target_compile_definitions(opensplat PRIVATE USE_HIP __HIP_PLATFORM_AMD__) +elseif(GPU_RUNTIME STREQUAL "CUDA") + target_compile_definitions(opensplat PRIVATE USE_CUDA) endif() if(OPENSPLAT_BUILD_SIMPLE_TRAINER) - add_executable(simple_trainer simple_trainer.cpp project_gaussians.cpp rasterize_gaussians.cpp) + add_executable(simple_trainer simple_trainer.cpp project_gaussians.cpp rasterize_gaussians.cpp cv_utils.cpp) target_include_directories(simple_trainer PRIVATE ${PROJECT_SOURCE_DIR}/vendor/glm ${GPU_INCLUDE_DIRS}) - target_link_libraries(simple_trainer PUBLIC ${GPU_LIBRARIES} gsplat ${TORCH_LIBRARIES} ${OpenCV_LIBS}) + target_link_libraries(simple_trainer PUBLIC ${GPU_LIBRARIES} ${GSPLAT_LIBS} ${TORCH_LIBRARIES} ${OpenCV_LIBS}) set_property(TARGET simple_trainer PROPERTY CXX_STANDARD 17) if(GPU_RUNTIME STREQUAL "HIP") target_compile_definitions(simple_trainer PRIVATE USE_HIP __HIP_PLATFORM_AMD__) + elseif(GPU_RUNTIME STREQUAL "CUDA") + target_compile_definitions(simple_trainer PRIVATE USE_CUDA) endif() endif() diff --git a/README.md b/README.md index 3d8d43c..a882108 100644 --- a/README.md +++ b/README.md @@ -6,15 +6,37 @@ A free and open source implementation of 3D gaussian splatting written in C++, f OpenSplat takes camera poses + sparse points in [COLMAP](https://colmap.github.io/) or [nerfstudio](https://docs.nerf.studio/quickstart/custom_dataset.html) project format and computes a [scene file](https://drive.google.com/file/d/1w-CBxyWNXF3omA8B_IeOsRmSJel3iwyr/view?usp=sharing) (.ply) that can be later imported for viewing, editing and rendering in other [software](https://github.com/MrNeRF/awesome-3D-gaussian-splatting?tab=readme-ov-file#open-source-implementations). +Graphics card recommended, but not required! OpenSplat runs the fastest on NVIDIA and AMD GPUs, but can also run entirely on CPU power (~100x slower). + Commercial use allowed and encouraged under the terms of the [AGPLv3](https://www.tldrlegal.com/license/gnu-affero-general-public-license-v3-agpl-3-0). ✅ -## Build (CUDA) +## Build Requirements: - * **CUDA**: Make sure you have the CUDA compiler (`nvcc`) in your PATH and that `nvidia-smi` is working. https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html - * **libtorch**: Visit https://pytorch.org/get-started/locally/ and select your OS, for package select "LibTorch". Make sure to match your version of CUDA if you want to leverage GPU support in libtorch. * **OpenCV**: `sudo apt install libopencv-dev` should do it. + * **libtorch**: See instructions below + +### CPU + + **libtorch**: Visit https://pytorch.org/get-started/locally/ and select your OS, for package select "LibTorch". For compute platform you can select "CPU". + + Then: + + ```bash + git clone https://github.com/pierotofy/OpenSplat OpenSplat + cd OpenSplat + mkdir build && cd build + cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch/ .. && make -j$(nproc) + ``` + +### CUDA + +Additional requirement: + + * **CUDA**: Make sure you have the CUDA compiler (`nvcc`) in your PATH and that `nvidia-smi` is working. https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html + + **libtorch**: Visit https://pytorch.org/get-started/locally/ and select your OS, for package select "LibTorch". Make sure to match your version of CUDA if you want to leverage GPU support in libtorch. Then: @@ -27,12 +49,13 @@ Requirements: The software has been tested on Ubuntu 20.04 and Windows. With some changes it could run on macOS (help us by opening a PR?). -## Build (ROCm via HIP) -Requirements: +### ROCm via HIP + +Additional requirement: * **ROCm**: Make sure you have the ROCm installed at /opt/rocm. https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html -* **libtorch**: Visit https://pytorch.org/get-started/locally/ and select your OS, for package select "LibTorch". Make sure to match your version of ROCm (5.7) if you want to leverage AMD GPU support in libtorch. -* **OpenCV**: `sudo apt install libopencv-dev` should do it. + +**libtorch**: Visit https://pytorch.org/get-started/locally/ and select your OS, for package select "LibTorch". Make sure to match your version of ROCm (5.7) if you want to leverage AMD GPU support in libtorch. Then: @@ -44,13 +67,18 @@ Then: cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch/ -DGPU_RUNTIME="HIP" -DHIP_ROOT_DIR=/opt/rocm -DOPENSPLAT_BUILD_SIMPLE_TRAINER=ON .. make ``` + In addition, you can leverage Jinja to build the project - ``` - cmake -GNinja -DCMAKE_PREFIX_PATH=/path/to/libtorch/ -DGPU_RUNTIME="HIP" -DHIP_ROOT_DIR=/opt/rocm -DOPENSPLAT_BUILD_SIMPLE_TRAINER=ON .. - jinja - ``` -## Docker Build (CUDA) +```bash +cmake -GNinja -DCMAKE_PREFIX_PATH=/path/to/libtorch/ -DGPU_RUNTIME="HIP" -DHIP_ROOT_DIR=/opt/rocm -DOPENSPLAT_BUILD_SIMPLE_TRAINER=ON .. +jinja +``` + +## Docker Build + +### CUDA + Navigate to the root directory of OpenSplat repo that has Dockerfile and run the following command to build the Docker image: ```bash @@ -70,7 +98,8 @@ docker build \ --build-arg CMAKE_BUILD_TYPE=Release . ``` -## Docker Build (ROCm via HIP) +### ROCm via HIP + Navigate to the root directory of OpenSplat repo that has Dockerfile and run the following command to build the Docker image: ```bash docker build \ @@ -138,11 +167,11 @@ cd /code/build We recently released OpenSplat, so there's lots of work to do. * Support for running on AMD cards (more testing needed) - * Support for running on CPU-only * Improve speed / reduce memory usage * Distributed computation using multiple machines * Real-time training viewer output * Compressed scene outputs + * Automatic filtering * Your ideas? https://github.com/pierotofy/OpenSplat/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement diff --git a/gsplat.hpp b/gsplat.hpp new file mode 100644 index 0000000..d427655 --- /dev/null +++ b/gsplat.hpp @@ -0,0 +1,12 @@ +#ifndef GSPLAT_H +#define GSPLAT_H + +#include "vendor/gsplat/config.h" + +#if defined(USE_HIP) || defined(USE_CUDA) +#include "vendor/gsplat/bindings.h" +#endif + +#include "vendor/gsplat-cpu/bindings.h" + +#endif \ No newline at end of file diff --git a/model.cpp b/model.cpp index f6f835e..86a6d2f 100644 --- a/model.cpp +++ b/model.cpp @@ -4,10 +4,11 @@ #include "project_gaussians.hpp" #include "rasterize_gaussians.hpp" #include "tensor_math.hpp" -#include "vendor/gsplat/config.h" +#include "gsplat.hpp" + #ifdef USE_HIP #include -#else +#elif defined(USE_CUDA) #include #endif @@ -79,31 +80,62 @@ torch::Tensor Model::forward(Camera& cam, int step){ float fovY = 2.0f * std::atan(height / (2.0f * fy)); torch::Tensor projMat = projectionMatrix(0.001f, 1000.0f, fovX, fovY, device); - - TileBounds tileBounds = std::make_tuple((width + BLOCK_X - 1) / BLOCK_X, - (height + BLOCK_Y - 1) / BLOCK_Y, - 1); - torch::Tensor colors = torch::cat({featuresDc.index({Slice(), None, Slice()}), featuresRest}, 1); - auto p = ProjectGaussians::apply(means, - torch::exp(scales), - 1, - quats / quats.norm(2, {-1}, true), - viewMat, - torch::matmul(projMat, viewMat), - fx, - fy, - cx, - cy, - height, - width, - tileBounds); - xys = p[0]; - torch::Tensor depths = p[1]; - radii = p[2]; - torch::Tensor conics = p[3]; - torch::Tensor numTilesHit = p[4]; + torch::Tensor conics; + torch::Tensor depths; // GPU-only + torch::Tensor numTilesHit; // GPU-only + torch::Tensor cov2d; // CPU-only + torch::Tensor camDepths; // CPU-only + torch::Tensor rgb; + + if (device == torch::kCPU){ + auto p = ProjectGaussiansCPU::apply(means, + torch::exp(scales), + 1, + quats / quats.norm(2, {-1}, true), + viewMat, + torch::matmul(projMat, viewMat), + fx, + fy, + cx, + cy, + height, + width); + xys = p[0]; + radii = p[1]; + conics = p[2]; + cov2d = p[3]; + camDepths = p[4]; + }else{ + #if defined(USE_HIP) || defined(USE_CUDA) + + TileBounds tileBounds = std::make_tuple((width + BLOCK_X - 1) / BLOCK_X, + (height + BLOCK_Y - 1) / BLOCK_Y, + 1); + auto p = ProjectGaussians::apply(means, + torch::exp(scales), + 1, + quats / quats.norm(2, {-1}, true), + viewMat, + torch::matmul(projMat, viewMat), + fx, + fy, + cx, + cy, + height, + width, + tileBounds); + + xys = p[0]; + depths = p[1]; + radii = p[2]; + conics = p[3]; + numTilesHit = p[4]; + #else + throw std::runtime_error("GPU support not built, use --cpu"); + #endif + } if (radii.sum().item() == 0.0f) @@ -115,22 +147,47 @@ torch::Tensor Model::forward(Camera& cam, int step){ torch::Tensor viewDirs = means.detach() - T.transpose(0, 1).to(device); viewDirs = viewDirs / viewDirs.norm(2, {-1}, true); int degreesToUse = (std::min)(step / shDegreeInterval, shDegree); - torch::Tensor rgbs = SphericalHarmonics::apply(degreesToUse, viewDirs, colors); - rgbs = torch::clamp_min(rgbs + 0.5f, 0.0f); - + torch::Tensor rgbs; - torch::Tensor rgb = RasterizeGaussians::apply( - xys, - depths, - radii, - conics, - numTilesHit, - rgbs, // TODO: why not sigmod? - torch::sigmoid(opacities), - height, - width, - backgroundColor); + if (device == torch::kCPU){ + rgbs = SphericalHarmonicsCPU::apply(degreesToUse, viewDirs, colors); + }else{ + #if defined(USE_HIP) || defined(USE_CUDA) + rgbs = SphericalHarmonics::apply(degreesToUse, viewDirs, colors); + #endif + } + rgbs = torch::clamp_min(rgbs + 0.5f, 0.0f); + + + if (device == torch::kCPU){ + rgb = RasterizeGaussiansCPU::apply( + xys, + radii, + conics, + rgbs, + torch::sigmoid(opacities), + cov2d, + camDepths, + height, + width, + backgroundColor); + }else{ + #if defined(USE_HIP) || defined(USE_CUDA) + rgb = RasterizeGaussians::apply( + xys, + depths, + radii, + conics, + numTilesHit, + rgbs, + torch::sigmoid(opacities), + height, + width, + backgroundColor); + #endif + } + rgb = torch::clamp_max(rgb, 1.0f); return rgb; @@ -391,11 +448,14 @@ void Model::afterTrain(int step){ xysGradNorm = torch::Tensor(); visCounts = torch::Tensor(); max2DSize = torch::Tensor(); -#ifdef USE_HIP - c10::hip::HIPCachingAllocator::emptyCache(); -#else - c10::cuda::CUDACachingAllocator::emptyCache(); -#endif + + if (device != torch::kCPU){ + #ifdef USE_HIP + c10::hip::HIPCachingAllocator::emptyCache(); + #elif defined(USE_CUDA) + c10::cuda::CUDACachingAllocator::emptyCache(); + #endif + } } } diff --git a/opensplat.cpp b/opensplat.cpp index 22eb8a6..e0c410d 100644 --- a/opensplat.cpp +++ b/opensplat.cpp @@ -17,6 +17,7 @@ int main(int argc, char *argv[]){ ("s,save-every", "Save output scene every these many steps (set to -1 to disable)", cxxopts::value()->default_value("-1")) ("val", "Withhold a camera shot for validating the scene loss") ("val-image", "Filename of the image to withhold for validating scene loss", cxxopts::value()->default_value("random")) + ("cpu", "Force CPU execution") ("n,num-iters", "Number of iterations to run", cxxopts::value()->default_value("30000")) ("d,downscale-factor", "Scale input images by this factor.", cxxopts::value()->default_value("1")) @@ -37,7 +38,7 @@ int main(int argc, char *argv[]){ ("h,help", "Print usage") ; options.parse_positional({ "input" }); - options.positional_help("[nerfstudio project path]"); + options.positional_help("[colmap or nerfstudio project path]"); cxxopts::ParseResult result; try { result = options.parse(argc, argv); @@ -76,10 +77,12 @@ int main(int argc, char *argv[]){ const float splitScreenSize = result["split-screen-size"].as(); torch::Device device = torch::kCPU; + int displayStep = 1; - if (torch::cuda::is_available()) { + if (torch::cuda::is_available() && result.count("cpu") == 0) { std::cout << "Using CUDA" << std::endl; device = torch::kCUDA; + displayStep = 10; }else{ std::cout << "Using CPU" << std::endl; } @@ -90,8 +93,6 @@ int main(int argc, char *argv[]){ cam.loadImage(downScaleFactor); } - - // Withhold a validation camera if necessary auto t = inputData.getCameras(validate, valImage); std::vector cams = std::get<0>(t); @@ -121,7 +122,7 @@ int main(int argc, char *argv[]){ torch::Tensor mainLoss = model.mainLoss(rgb, gt, ssimWeight); mainLoss.backward(); - if (step % 10 == 0) std::cout << "Step " << step << ": " << mainLoss.item() << std::endl; + if (step % displayStep == 0) std::cout << "Step " << step << ": " << mainLoss.item() << std::endl; model.optimizersStep(); model.schedulersStep(step); diff --git a/project_gaussians.cpp b/project_gaussians.cpp index f479919..d57e1d6 100644 --- a/project_gaussians.cpp +++ b/project_gaussians.cpp @@ -1,5 +1,6 @@ #include "project_gaussians.hpp" -#include "vendor/gsplat/bindings.h" + +#if defined(USE_HIP) || defined(USE_CUDA) variable_list ProjectGaussians::forward(AutogradContext *ctx, torch::Tensor means, @@ -38,8 +39,8 @@ variable_list ProjectGaussians::forward(AutogradContext *ctx, ctx->saved_data["fy"] = fy; ctx->saved_data["cx"] = cx; ctx->saved_data["cy"] = cy; - ctx->save_for_backward({ means, scales, quats, viewMat, projMat, cov3d, radii, conics }); + return { xys, depths, radii, conics, numTilesHit, cov3d }; } @@ -86,4 +87,37 @@ tensor_list ProjectGaussians::backward(AutogradContext *ctx, tensor_list grad_ou none, // tileBounds none // clipThresh }; +} + +#endif + +variable_list ProjectGaussiansCPU::apply( + torch::Tensor means, + torch::Tensor scales, + float globScale, + torch::Tensor quats, + torch::Tensor viewMat, + torch::Tensor projMat, + float fx, + float fy, + float cx, + float cy, + int imgHeight, + int imgWidth, + float clipThresh + ){ + + int numPoints = means.size(0); + + auto t = project_gaussians_forward_tensor_cpu(numPoints, means, scales, globScale, + quats, viewMat, projMat, fx, fy, + cx, cy, imgHeight, imgWidth, clipThresh); + + torch::Tensor xys = std::get<0>(t); + torch::Tensor radii = std::get<1>(t); + torch::Tensor conics = std::get<2>(t); + torch::Tensor cov2d = std::get<3>(t); + torch::Tensor camDepths = std::get<4>(t); + + return { xys, radii, conics, cov2d, camDepths }; } \ No newline at end of file diff --git a/project_gaussians.hpp b/project_gaussians.hpp index ece9399..8891d22 100644 --- a/project_gaussians.hpp +++ b/project_gaussians.hpp @@ -3,9 +3,12 @@ #include #include "tile_bounds.hpp" +#include "gsplat.hpp" using namespace torch::autograd; +#if defined(USE_HIP) || defined(USE_CUDA) + class ProjectGaussians : public Function{ public: static variable_list forward(AutogradContext *ctx, @@ -26,5 +29,25 @@ class ProjectGaussians : public Function{ static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs); }; +#endif + +class ProjectGaussiansCPU{ +public: + static variable_list apply( + torch::Tensor means, + torch::Tensor scales, + float globScale, + torch::Tensor quats, + torch::Tensor viewMat, + torch::Tensor projMat, + float fx, + float fy, + float cx, + float cy, + int imgHeight, + int imgWidth, + float clipThresh = 0.01); +}; + #endif \ No newline at end of file diff --git a/rasterize_gaussians.cpp b/rasterize_gaussians.cpp index f7cb655..9023b9b 100644 --- a/rasterize_gaussians.cpp +++ b/rasterize_gaussians.cpp @@ -1,6 +1,5 @@ #include "rasterize_gaussians.hpp" -#include "vendor/gsplat/bindings.h" -#include "vendor/gsplat/config.h" +#include "gsplat.hpp" std::tuple(t); - // Map of alpha-inverse (1 - finalTs = alpha) torch::Tensor finalTs = std::get<1>(t); // Map of tile bin IDs @@ -84,7 +84,6 @@ torch::Tensor RasterizeGaussians::forward(AutogradContext *ctx, ctx->saved_data["imgWidth"] = imgWidth; ctx->saved_data["imgHeight"] = imgHeight; - ctx->save_for_backward({ gaussianIdsSorted, tileBins, xys, conics, colors, opacity, background, finalTs, finalIdx }); return outImg; @@ -106,7 +105,6 @@ tensor_list RasterizeGaussians::backward(AutogradContext *ctx, tensor_list grad_ torch::Tensor finalTs = saved[7]; torch::Tensor finalIdx = saved[8]; - // torch::Tensor v_outAlpha = torch::zeros({imgHeight, imgWidth}, torch::TensorOptions().device(v_outImg.get_device()); torch::Tensor v_outAlpha = torch::zeros_like(v_outImg.index({"...", 0})); auto t = rasterize_backward_tensor(imgHeight, imgWidth, @@ -139,4 +137,99 @@ tensor_list RasterizeGaussians::backward(AutogradContext *ctx, tensor_list grad_ none, // imgWidth none // background }; -} \ No newline at end of file +} + +#endif + +torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx, + torch::Tensor xys, + torch::Tensor radii, + torch::Tensor conics, + torch::Tensor colors, + torch::Tensor opacity, + torch::Tensor cov2d, + torch::Tensor camDepths, + int imgHeight, + int imgWidth, + torch::Tensor background + ){ + + int numPoints = xys.size(0); + + auto t = rasterize_forward_tensor_cpu(imgWidth, imgHeight, + xys, + conics, + colors, + opacity, + background, + cov2d, + camDepths + ); + // Final image + torch::Tensor outImg = std::get<0>(t); + + torch::Tensor finalTs = std::get<1>(t); + std::vector *px2gid = std::get<2>(t); + + ctx->saved_data["imgWidth"] = imgWidth; + ctx->saved_data["imgHeight"] = imgHeight; + ctx->saved_data["px2gid"] = reinterpret_cast(px2gid); + ctx->save_for_backward({ xys, conics, colors, opacity, background, cov2d, camDepths, finalTs }); + + return outImg; +} + +tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list grad_outputs) { + torch::Tensor v_outImg = grad_outputs[0]; + int imgHeight = ctx->saved_data["imgHeight"].toInt(); + int imgWidth = ctx->saved_data["imgWidth"].toInt(); + const std::vector *px2gid = reinterpret_cast *>(ctx->saved_data["px2gid"].toInt()); + + variable_list saved = ctx->get_saved_variables(); + torch::Tensor xys = saved[0]; + torch::Tensor conics = saved[1]; + torch::Tensor colors = saved[2]; + torch::Tensor opacity = saved[3]; + torch::Tensor background = saved[4]; + torch::Tensor cov2d = saved[5]; + torch::Tensor camDepths = saved[6]; + torch::Tensor finalTs = saved[7]; + + torch::Tensor v_outAlpha = torch::zeros_like(v_outImg.index({"...", 0})); + + auto t = rasterize_backward_tensor_cpu(imgHeight, imgWidth, + xys, + conics, + colors, + opacity, + background, + cov2d, + camDepths, + finalTs, + px2gid, + v_outImg, + v_outAlpha); + + delete[] px2gid; + + + torch::Tensor v_xy = std::get<0>(t); + torch::Tensor v_conic = std::get<1>(t); + torch::Tensor v_colors = std::get<2>(t); + torch::Tensor v_opacity = std::get<3>(t); + torch::Tensor none; + + return { v_xy, + none, // radii + v_conic, + v_colors, + v_opacity, + none, // cov2d + none, // camDepths + none, // imgHeight + none, // imgWidth + none // background + }; +} + + diff --git a/rasterize_gaussians.hpp b/rasterize_gaussians.hpp index f73a85f..adb7692 100644 --- a/rasterize_gaussians.hpp +++ b/rasterize_gaussians.hpp @@ -17,6 +17,8 @@ std::tuple{ public: static torch::Tensor forward(AutogradContext *ctx, @@ -33,4 +35,22 @@ class RasterizeGaussians : public Function{ static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs); }; +#endif + +class RasterizeGaussiansCPU : public Function{ +public: + static torch::Tensor forward(AutogradContext *ctx, + torch::Tensor xys, + torch::Tensor radii, + torch::Tensor conics, + torch::Tensor colors, + torch::Tensor opacity, + torch::Tensor cov2d, + torch::Tensor camDepths, + int imgHeight, + int imgWidth, + torch::Tensor background); + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs); +}; + #endif \ No newline at end of file diff --git a/simple_trainer.cpp b/simple_trainer.cpp index b40e6d6..2da1704 100644 --- a/simple_trainer.cpp +++ b/simple_trainer.cpp @@ -1,40 +1,63 @@ #include #include +#include #include #ifdef USE_HIP #include -#else +#elif defined(USE_CUDA) #include #endif -#include - #include #include #include -#include "vendor/gsplat/config.h" #include "project_gaussians.hpp" #include "rasterize_gaussians.hpp" #include "constants.hpp" #include "cv_utils.hpp" +#include "vendor/cxxopts.hpp" using namespace torch::indexing; +namespace fs = std::filesystem; +int main(int argc, char **argv){ + cxxopts::Options options("simple_trainer", "Test program for gsplat execution"); + options.add_options() + ("cpu", "Force CPU execution") + ("width", "Test image width", cxxopts::value()->default_value("256")) + ("height", "Test image height", cxxopts::value()->default_value("256")) + ("iters", "Number of iterations", cxxopts::value()->default_value("1000")) + ("points", "Number of gaussians", cxxopts::value()->default_value("100000")) + ("lr", "Learning rate", cxxopts::value()->default_value("0.01")) + ("render", "Save rendered images to folder", cxxopts::value()->default_value("")) + ("h,help", "Print usage") + ; + cxxopts::ParseResult result; + try { + result = options.parse(argc, argv); + } + catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + std::cerr << options.help() << std::endl; + return EXIT_FAILURE; + } + if (result.count("help")) { + std::cout << options.help() << std::endl; + return EXIT_SUCCESS; + } - - - -int main(int argc, char **argv){ - int width = 256, - height = 256; - int numPoints = 100000; - int iterations = 1000; - float learningRate = 0.01; + int width = result["width"].as(), + height = result["height"].as(); + int numPoints = result["points"].as(); + int iterations = result["iters"].as(); + float learningRate = result["lr"].as(); + std::string render = result["render"].as(); + if (!render.empty() && !fs::exists(render)) fs::create_directories(render); torch::Device device = torch::kCPU; - if (torch::cuda::is_available()) { + if (torch::cuda::is_available() && result.count("cpu") == 0){ std::cout << "Using CUDA" << std::endl; device = torch::kCUDA; }else{ @@ -60,25 +83,32 @@ int main(int argc, char **argv){ (height + BLOCK_Y - 1) / BLOCK_Y, 1); - // torch::Tensor imgSize = torch::tensor({width, height, 1}, device); - // torch::Tensor block = torch::tensor({BLOCK_X, BLOCK_Y, 1}, device); - // Init gaussians -#ifdef USE_HIP -#else +#ifdef USE_CUDA torch::cuda::manual_seed_all(0); #endif + torch::manual_seed(0); // Random points, scales and colors - torch::Tensor means = 2.0 * (torch::rand({numPoints, 3}, device) - 0.5); // Positions [-1, 1] - torch::Tensor scales = torch::rand({numPoints, 3}, device); - torch::Tensor rgbs = torch::rand({numPoints, 3}, device); + torch::Tensor means = 2.0 * (torch::rand({numPoints, 3}, torch::kCPU) - 0.5); // Positions [-1, 1] + torch::Tensor scales = torch::rand({numPoints, 3}, torch::kCPU); + // torch::Tensor means = torch::tensor({{0.5f, 0.5f, -5.0f}, {0.5f, 0.5f, -6.0f}, {0.25f, 0.25f, -4.0f}}, torch::kCPU); + // torch::Tensor scales = torch::tensor({{0.5f, 0.5f, 0.5f}, {1.0f, 1.0f, 1.0f}, {1.0f, 1.0f, 1.0f}}, torch::kCPU); + torch::Tensor rgbs = torch::rand({numPoints, 3}, torch::kCPU); // Random rotations (quaternions) // quats = ( sqrt(1-u) sin(2πv), sqrt(1-u) cos(2πv), sqrt(u) sin(2πw), sqrt(u) cos(2πw)) - torch::Tensor u = torch::rand({numPoints, 1}, device); - torch::Tensor v = torch::rand({numPoints, 1}, device); - torch::Tensor w = torch::rand({numPoints, 1}, device); + torch::Tensor u = torch::rand({numPoints, 1}, torch::kCPU); + torch::Tensor v = torch::rand({numPoints, 1}, torch::kCPU); + torch::Tensor w = torch::rand({numPoints, 1}, torch::kCPU); + + means = means.to(device); + scales = scales.to(device); + rgbs = rgbs.to(device); + u = u.to(device); + v = v.to(device); + w = w.to(device); + torch::Tensor quats = torch::cat({ torch::sqrt(1.0 - u) * torch::sin(2.0 * PI * v), torch::sqrt(1.0 - u) * torch::cos(2.0 * PI * v), @@ -106,66 +136,68 @@ int main(int argc, char **argv){ torch::optim::Adam optimizer({rgbs, means, scales, opacities, quats}, learningRate); torch::nn::MSELoss mseLoss; + torch::Tensor outImg; for (size_t i = 0; i < iterations; i++){ - auto p = ProjectGaussians::apply(means, scales, 1, + if (device == torch::kCPU){ + auto p = ProjectGaussiansCPU::apply(means, scales, 1, quats, viewMat, viewMat, focal, focal, width / 2, height / 2, height, - width, - tileBounds); - -#ifdef USE_HIP - hipError_t err = hipDeviceSynchronize(); - if (err != hipSuccess) { - std::cerr << "hipDeviceSynchronize failed with error: " << hipGetErrorString(err) << std::endl; + width); + + outImg = RasterizeGaussiansCPU::apply( + p[0], // xys + p[1], // radii, + p[2], // conics + torch::sigmoid(rgbs), + torch::sigmoid(opacities), + p[3], // cov2d + p[4], // camDepths + height, + width, + background); + }else{ + #if defined(USE_HIP) || defined(USE_CUDA) + auto p = ProjectGaussians::apply(means, scales, 1, + quats, viewMat, viewMat, + focal, focal, + width / 2, + height / 2, + height, + width, + tileBounds); + + outImg = RasterizeGaussians::apply( + p[0], // xys + p[1], // depths + p[2], // radii, + p[3], // conics + p[4], // numTilesHit + torch::sigmoid(rgbs), + torch::sigmoid(opacities), + height, + width, + background); + #else + throw std::runtime_error("GPU support not built, use --cpu"); + #endif } -#else - torch::cuda::synchronize(); -#endif - - torch::Tensor outImg = RasterizeGaussians::apply( - p[0], // xys - p[1], // depths - p[2], // radii, - p[3], // conics - p[4], // numTilesHit - torch::sigmoid(rgbs), - torch::sigmoid(opacities), - height, - width, - background); - -#ifdef USE_HIP - err = hipDeviceSynchronize(); - if (err != hipSuccess) { - std::cerr << "hipDeviceSynchronize failed with error: " << hipGetErrorString(err) << std::endl; - } -#else - torch::cuda::synchronize(); -#endif outImg.requires_grad_(); torch::Tensor loss = mseLoss(outImg, gtImage); optimizer.zero_grad(); loss.backward(); - -#ifdef USE_HIP - err = hipDeviceSynchronize(); - if (err != hipSuccess) { - std::cerr << "hipDeviceSynchronize failed with error: " << hipGetErrorString(err) << std::endl; - } -#else - torch::cuda::synchronize(); -#endif optimizer.step(); std::cout << "Iteration " << std::to_string(i + 1) << "/" << std::to_string(iterations) << " Loss: " << loss.item() << std::endl; - - // cv::Mat image = tensorToImage(outImg.detach().cpu()); - // cv::cvtColor(image, image, cv::COLOR_RGB2BGR); - // cv::imwrite("render/" + std::to_string(i + 1) + ".png", image); + + if (!render.empty()){ + cv::Mat image = tensorToImage(outImg.detach().cpu()); + cv::cvtColor(image, image, cv::COLOR_RGB2BGR); + cv::imwrite((fs::path(render) / (std::to_string(i + 1) + ".png")).string(), image); + } } } \ No newline at end of file diff --git a/spherical_harmonics.cpp b/spherical_harmonics.cpp index c88b462..6581188 100644 --- a/spherical_harmonics.cpp +++ b/spherical_harmonics.cpp @@ -1,20 +1,4 @@ #include "spherical_harmonics.hpp" -#include "vendor/gsplat/bindings.h" - -int numShBases(int degree){ - switch(degree){ - case 0: - return 1; - case 1: - return 4; - case 2: - return 9; - case 3: - return 16; - default: - return 25; - } -} int degFromSh(int numBases){ switch(numBases){ @@ -37,6 +21,8 @@ torch::Tensor rgb2sh(const torch::Tensor &rgb){ return (rgb - 0.5) / C0; } +#if defined(USE_HIP) || defined(USE_CUDA) + torch::Tensor SphericalHarmonics::forward(AutogradContext *ctx, int degreesToUse, torch::Tensor viewDirs, @@ -67,5 +53,15 @@ tensor_list SphericalHarmonics::backward(AutogradContext *ctx, tensor_list grad_ none, compute_sh_backward_tensor(numPoints, degree, degreesToUse, viewDirs, v_colors) }; - +} + +#endif + +torch::Tensor SphericalHarmonicsCPU::apply(int degreesToUse, + torch::Tensor viewDirs, + torch::Tensor coeffs){ + long long numPoints = coeffs.size(0); + int degree = degFromSh(coeffs.size(-2)); + + return compute_sh_forward_tensor_cpu(numPoints, degree, degreesToUse, viewDirs, coeffs); } \ No newline at end of file diff --git a/spherical_harmonics.hpp b/spherical_harmonics.hpp index a84762d..2ebe1d1 100644 --- a/spherical_harmonics.hpp +++ b/spherical_harmonics.hpp @@ -2,13 +2,15 @@ #define SPHERICAL_HARMONICS_H #include +#include "gsplat.hpp" using namespace torch::autograd; -int numShBases(int degree); int degFromSh(int numBases); torch::Tensor rgb2sh(const torch::Tensor &rgb); +#if defined(USE_HIP) || defined(USE_CUDA) + class SphericalHarmonics : public Function{ public: static torch::Tensor forward(AutogradContext *ctx, @@ -18,4 +20,13 @@ class SphericalHarmonics : public Function{ static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs); }; +#endif + +class SphericalHarmonicsCPU{ +public: + static torch::Tensor apply(int degreesToUse, + torch::Tensor viewDirs, + torch::Tensor coeffs); +}; + #endif \ No newline at end of file diff --git a/tile_bounds.hpp b/tile_bounds.hpp index 7b29dc2..4c1330b 100644 --- a/tile_bounds.hpp +++ b/tile_bounds.hpp @@ -1,3 +1,8 @@ +#ifndef TILE_BOUNDS_H +#define TILE_BOUNDS_H + #include typedef std::tuple TileBounds; + +#endif \ No newline at end of file diff --git a/vendor/gsplat-cpu/bindings.h b/vendor/gsplat-cpu/bindings.h new file mode 100644 index 0000000..bfcce48 --- /dev/null +++ b/vendor/gsplat-cpu/bindings.h @@ -0,0 +1,83 @@ +// Originally based on https://github.com/nerfstudio-project/gsplat +// This implementation has been substantially changed and optimized +// Licensed under the AGPLv3 +// Piero Toffanin - 2024 + +#include +#include +#include +#include +#include +#include + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +project_gaussians_forward_tensor_cpu( + const int num_points, + torch::Tensor &means3d, + torch::Tensor &scales, + const float glob_scale, + torch::Tensor &quats, + torch::Tensor &viewmat, + torch::Tensor &projmat, + const float fx, + const float fy, + const float cx, + const float cy, + const unsigned img_height, + const unsigned img_width, + const float clip_thresh +); + +std::tuple< + torch::Tensor, + torch::Tensor, + std::vector * +> rasterize_forward_tensor_cpu( + const int width, + const int height, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background, + const torch::Tensor &cov2d, + const torch::Tensor &camDepths +); + +std:: + tuple< + torch::Tensor, // dL_dxy + torch::Tensor, // dL_dconic + torch::Tensor, // dL_dcolors + torch::Tensor // dL_dopacity + > + rasterize_backward_tensor_cpu( + const int height, + const int width, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background, + const torch::Tensor &cov2d, + const torch::Tensor &camDepths, + const torch::Tensor &final_Ts, + const std::vector *px2gid, + const torch::Tensor &v_output, // dL_dout_color + const torch::Tensor &v_output_alpha + ); + +int numShBases(int degree); + +torch::Tensor compute_sh_forward_tensor_cpu( + const int num_points, + const int degree, + const int degrees_to_use, + const torch::Tensor &viewdirs, + const torch::Tensor &coeffs +); \ No newline at end of file diff --git a/vendor/gsplat-cpu/gsplat_cpu.cpp b/vendor/gsplat-cpu/gsplat_cpu.cpp new file mode 100644 index 0000000..250b394 --- /dev/null +++ b/vendor/gsplat-cpu/gsplat_cpu.cpp @@ -0,0 +1,483 @@ +// Originally started from https://github.com/nerfstudio-project/gsplat +// This implementation has been substantially changed and optimized +// Licensed under the AGPLv3 +// Piero Toffanin - 2024 + +#include "bindings.h" +#include "../gsplat/config.h" + +#include +#include +#include +#include + +using namespace torch::indexing; + +torch::Tensor quatToRot(const torch::Tensor &quat){ + auto u = torch::unbind(torch::nn::functional::normalize(quat, torch::nn::functional::NormalizeFuncOptions().dim(-1)), -1); + torch::Tensor w = u[0]; + torch::Tensor x = u[1]; + torch::Tensor y = u[2]; + torch::Tensor z = u[3]; + return torch::stack({ + torch::stack({ + 1.0 - 2.0 * (y.pow(2) + z.pow(2)), + 2.0 * (x * y - w * z), + 2.0 * (x * z + w * y) + }, -1), + torch::stack({ + 2.0 * (x * y + w * z), + 1.0 - 2.0 * (x.pow(2) + z.pow(2)), + 2.0 * (y * z - w * x) + }, -1), + torch::stack({ + 2.0 * (x * z - w * y), + 2.0 * (y * z + w * x), + 1.0 - 2.0 * (x.pow(2) + y.pow(2)) + }, -1) + }, -2); + +} + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +project_gaussians_forward_tensor_cpu( + const int num_points, + torch::Tensor &means3d, + torch::Tensor &scales, + const float glob_scale, + torch::Tensor &quats, + torch::Tensor &viewmat, + torch::Tensor &projmat, + const float fx, + const float fy, + const float cx, + const float cy, + const unsigned img_height, + const unsigned img_width, + const float clip_thresh +){ + float fovx = 0.5f * static_cast(img_height) / fx; + float fovy = 0.5f * static_cast(img_width) / fy; + + // clip_near_plane + torch::Tensor Rclip = viewmat.index({"...", Slice(None, 3), Slice(None, 3)}); + torch::Tensor Tclip = viewmat.index({"...", Slice(None, 3), 3}); + torch::Tensor pView = torch::matmul(Rclip, means3d.index({"...", None})).index({"...", 0}) + Tclip; + // torch::Tensor isClose = pView.index({"...", 2}) < clip_thresh; + + // scale_rot_to_cov3d + torch::Tensor R = quatToRot(quats); + torch::Tensor M = R * glob_scale * scales.index({"...", None, Slice()}); + torch::Tensor cov3d = torch::matmul(M, M.transpose(-1, -2)); + + // project_cov3d_ewa + torch::Tensor limX = 1.3f * torch::tensor({fovx}, means3d.device()); + torch::Tensor limY = 1.3f * torch::tensor({fovy}, means3d.device()); + + torch::Tensor minLimX = pView.index({"...", 2}) * torch::min(limX, torch::max(-limX, pView.index({"...", 0}) / pView.index({"...", 2}))); + torch::Tensor minLimY = pView.index({"...", 2}) * torch::min(limY, torch::max(-limY, pView.index({"...", 1}) / pView.index({"...", 2}))); + + torch::Tensor t = torch::cat({minLimX.index({"...", None}), minLimY.index({"...", None}), pView.index({"...", 2, None})}, -1); + torch::Tensor rz = 1.0f / t.index({"...", 2}); + torch::Tensor rz2 = rz.pow(2); + + torch::Tensor J = torch::stack({ + torch::stack({fx * rz, torch::zeros_like(rz), -fx * t.index({"...", 0}) * rz2}, -1), + torch::stack({torch::zeros_like(rz), fy * rz, -fy * t.index({"...", 1}) * rz2}, -1) + }, -2); + + torch::Tensor T = torch::matmul(J, Rclip); + torch::Tensor cov2d = torch::matmul(T, torch::matmul(cov3d, T.transpose(-1, -2))); + + // Add blur along axes + cov2d.index_put_({"...", 0, 0}, cov2d.index({"...", 0, 0}) + 0.3f); + cov2d.index_put_({"...", 1, 1}, cov2d.index({"...", 1, 1}) + 0.3f); + + // compute_cov2d_bounds + float eps = 1e-6f; + torch::Tensor det = cov2d.index({"...", 0, 0}) * cov2d.index({"...", 1, 1}) - cov2d.index({"...", 0, 1}).pow(2); + det = torch::clamp_min(det, eps); + torch::Tensor conic = torch::stack({ + cov2d.index({"...", 1, 1}) / det, + -cov2d.index({"...", 0, 1}) / det, + cov2d.index({"...", 0, 0}) / det + }, -1); + + torch::Tensor b = (cov2d.index({"...", 0, 0}) + cov2d.index({"...", 1, 1})) / 2.0f; + torch::Tensor sq = torch::sqrt(torch::clamp_min(b.pow(2) - det, 0.1f)); + torch::Tensor v1 = b + sq; + torch::Tensor v2 = b - sq; + torch::Tensor radius = torch::ceil(3.0f * torch::sqrt(torch::max(v1, v2))); + // torch::Tensor detValid = det > eps; + + // project_pix + torch::Tensor pHom = torch::nn::functional::pad(means3d, torch::nn::functional::PadFuncOptions({0, 1}).mode(torch::kConstant).value(1.0f)); + pHom = torch::einsum("...ij,...j->...i", {projmat, pHom}); + torch::Tensor rw = 1.0f / torch::clamp_min(pHom.index({"...", 3}), eps); + torch::Tensor pProj = pHom.index({"...", Slice(None, 3)}) * rw.index({"...", None}); + torch::Tensor u = 0.5f * ((pProj.index({"...", 0}) + 1.0f) * static_cast(img_width) - 1.0f); + torch::Tensor v = 0.5f * ((pProj.index({"...", 1}) + 1.0f) * static_cast(img_height) - 1.0f); + torch::Tensor xys = torch::stack({u, v}, -1); // center + + torch::Tensor radii = radius.to(torch::kInt32); + torch::Tensor camDepths = pProj.index({"...", 2}); + + return std::make_tuple(xys, radii, conic, cov2d, camDepths); +} + +std::tuple< + torch::Tensor, + torch::Tensor, + std::vector * +> rasterize_forward_tensor_cpu( + const int width, + const int height, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background, + const torch::Tensor &cov2d, + const torch::Tensor &camDepths +){ + torch::NoGradGuard noGrad; + + int channels = colors.size(1); + int numPoints = xys.size(0); + float *pDepths = static_cast(camDepths.data_ptr()); + std::vector *px2gid = new std::vector[width * height]; + + std::vector< size_t > gIndices( numPoints ); + std::iota( gIndices.begin(), gIndices.end(), 0 ); + std::sort(gIndices.begin(), gIndices.end(), [&pDepths](int a, int b){ + return pDepths[a] < pDepths[b]; + }); + + torch::Device device = xys.device(); + + torch::Tensor outImg = torch::zeros({height, width, channels}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); + torch::Tensor finalTs = torch::ones({height, width}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); + torch::Tensor done = torch::zeros({height, width}, torch::TensorOptions().dtype(torch::kBool).device(device)); + + torch::Tensor sqCov2dX = 3.0f * torch::sqrt(cov2d.index({"...", 0, 0})); + torch::Tensor sqCov2dY = 3.0f * torch::sqrt(cov2d.index({"...", 1, 1})); + + float *pConics = static_cast(conics.data_ptr()); + float *pCenters = static_cast(xys.data_ptr()); + float *pSqCov2dX = static_cast(sqCov2dX.data_ptr()); + float *pSqCov2dY = static_cast(sqCov2dY.data_ptr()); + float *pOpacities = static_cast(opacities.data_ptr()); + + float *pOutImg = static_cast(outImg.data_ptr()); + float *pFinalTs = static_cast(finalTs.data_ptr()); + bool *pDone = static_cast(done.data_ptr()); + + float *pColors = static_cast(colors.data_ptr()); + + float bgX = background[0].item(); + float bgY = background[1].item(); + float bgZ = background[2].item(); + + const float alphaThresh = 1.0f / 255.0f; + + for (int idx = 0; idx < numPoints; idx++){ + int32_t gaussianId = gIndices[idx]; + + float A = pConics[gaussianId * 3 + 0]; + float B = pConics[gaussianId * 3 + 1]; + float C = pConics[gaussianId * 3 + 2]; + + float gX = pCenters[gaussianId * 2 + 0]; + float gY = pCenters[gaussianId * 2 + 1]; + + float sqx = pSqCov2dX[gaussianId]; + float sqy = pSqCov2dY[gaussianId]; + + int minx = (std::max)(0, static_cast(std::floor(gY - sqy)) - 2); + int maxx = (std::min)(height, static_cast(std::ceil(gY + sqy)) + 2); + int miny = (std::max)(0, static_cast(std::floor(gX - sqx)) - 2); + int maxy = (std::min)(width, static_cast(std::ceil(gX + sqx)) + 2); + + for (int i = minx; i < maxx; i++){ + for (int j = miny; j < maxy; j++){ + size_t pixIdx = (i * width + j); + if (pDone[pixIdx]) continue; + + float xCam = gX - j; + float yCam = gY - i; + float sigma = ( + 0.5f + * (A * xCam * xCam + C * yCam * yCam) + + B * xCam * yCam + ); + + if (sigma < 0.0f) continue; + float alpha = (std::min)(0.999f, (pOpacities[gaussianId] * std::exp(-sigma))); + if (alpha < alphaThresh) continue; + + float T = pFinalTs[pixIdx]; + float nextT = T * (1.0f - alpha); + if (nextT <= 1e-4f) { // this pixel is done + pDone[pixIdx] = true; + continue; + } + + float vis = alpha * T; + + pOutImg[pixIdx * 3 + 0] += vis * pColors[gaussianId * 3 + 0]; + pOutImg[pixIdx * 3 + 1] += vis * pColors[gaussianId * 3 + 1]; + pOutImg[pixIdx * 3 + 2] += vis * pColors[gaussianId * 3 + 2]; + + pFinalTs[pixIdx] = nextT; + px2gid[pixIdx].push_back(gaussianId); + } + } + } + + // Background + for (int i = 0; i < height; i++){ + for (int j = 0; j < width; j++){ + size_t pixIdx = (i * width + j); + float T = pFinalTs[pixIdx]; + + pOutImg[pixIdx * 3 + 0] += T * bgX; + pOutImg[pixIdx * 3 + 1] += T * bgY; + pOutImg[pixIdx * 3 + 2] += T * bgZ; + + std::reverse(px2gid[pixIdx].begin(), px2gid[pixIdx].end()); + } + } + + return std::make_tuple(outImg, finalTs, px2gid); +} + + +std:: + tuple< + torch::Tensor, // dL_dxy + torch::Tensor, // dL_dconic + torch::Tensor, // dL_dcolors + torch::Tensor // dL_dopacity + > + rasterize_backward_tensor_cpu( + const int height, + const int width, + const torch::Tensor &xys, + const torch::Tensor &conics, + const torch::Tensor &colors, + const torch::Tensor &opacities, + const torch::Tensor &background, + const torch::Tensor &cov2d, + const torch::Tensor &camDepths, + const torch::Tensor &final_Ts, + const std::vector *px2gid, + const torch::Tensor &v_output, // dL_dout_color + const torch::Tensor &v_output_alpha + ){ + torch::NoGradGuard noGrad; + + int numPoints = xys.size(0); + int channels = colors.size(1); + torch::Device device = xys.device(); + + torch::Tensor v_xy = torch::zeros({numPoints, 2}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); + torch::Tensor v_conic = torch::zeros({numPoints, 3}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); + torch::Tensor v_colors = torch::zeros({numPoints, channels}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); + torch::Tensor v_opacity = torch::zeros({numPoints, 1}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); + + float *pv_xy = static_cast(v_xy.data_ptr()); + float *pv_conic = static_cast(v_conic.data_ptr()); + float *pv_colors = static_cast(v_colors.data_ptr()); + float *pv_opacity = static_cast(v_opacity.data_ptr()); + + float *pColors = static_cast(colors.data_ptr()); + float *pv_output = static_cast(v_output.data_ptr()); + float *pv_outputAlpha = static_cast(v_output_alpha.data_ptr()); + float *pConics = static_cast(conics.data_ptr()); + float *pCenters = static_cast(xys.data_ptr()); + float *pOpacities = static_cast(opacities.data_ptr()); + + float bgX = background[0].item(); + float bgY = background[1].item(); + float bgZ = background[2].item(); + + float *pFinalTs = static_cast(final_Ts.data_ptr()); + + const float alphaThresh = 1.0f / 255.0f; + + for (int i = 0; i < height; i++){ + for (int j = 0; j < width; j++){ + size_t pixIdx = (i * width + j); + float Tfinal = pFinalTs[pixIdx]; + float T = Tfinal; + float buffer[3] = {0.0f, 0.0f, 0.0f}; + + for (const int32_t &gaussianId : px2gid[pixIdx]){ + float A = pConics[gaussianId * 3 + 0]; + float B = pConics[gaussianId * 3 + 1]; + float C = pConics[gaussianId * 3 + 2]; + + float gX = pCenters[gaussianId * 2 + 0]; + float gY = pCenters[gaussianId * 2 + 1]; + + float xCam = gX - j; + float yCam = gY - i; + float sigma = ( + 0.5f + * (A * xCam * xCam + C * yCam * yCam) + + B * xCam * yCam + ); + + if (sigma < 0.0f) continue; + float vis = std::exp(-sigma); + float alpha = (std::min)(0.99f, pOpacities[gaussianId] * vis); + if (alpha < alphaThresh) continue; + + float ra = 1.0f / (1.0f - alpha); + T *= ra; + float fac = alpha * T; + + pv_colors[gaussianId * 3 + 0] += fac * pv_output[pixIdx * 3 + 0]; + pv_colors[gaussianId * 3 + 1] += fac * pv_output[pixIdx * 3 + 1]; + pv_colors[gaussianId * 3 + 2] += fac * pv_output[pixIdx * 3 + 2]; + + float v_alpha = ((pColors[gaussianId * 3 + 0] * T - buffer[0] * ra) * pv_output[pixIdx * 3 + 0]) + + ((pColors[gaussianId * 3 + 1] * T - buffer[1] * ra) * pv_output[pixIdx * 3 + 1]) + + ((pColors[gaussianId * 3 + 2] * T - buffer[2] * ra) * pv_output[pixIdx * 3 + 2]) + + (Tfinal * ra * pv_outputAlpha[pixIdx]) + + + (-Tfinal * ra * bgX * pv_output[pixIdx * 3 + 0]) + + (-Tfinal * ra * bgY * pv_output[pixIdx * 3 + 1]) + + (-Tfinal * ra * bgZ * pv_output[pixIdx * 3 + 2]); + + buffer[0] += pColors[gaussianId * 3 + 0] * fac; + buffer[1] += pColors[gaussianId * 3 + 1] * fac; + buffer[2] += pColors[gaussianId * 3 + 2] * fac; + + float v_sigma = -pOpacities[gaussianId] * vis * v_alpha; + pv_conic[gaussianId * 3 + 0] += 0.5f * v_sigma * xCam * xCam; + pv_conic[gaussianId * 3 + 1] += 0.5f * v_sigma * xCam * yCam; + pv_conic[gaussianId * 3 + 2] += 0.5f * v_sigma * yCam * yCam; + + pv_xy[gaussianId * 2 + 0] += v_sigma * (A * xCam + B * yCam); + pv_xy[gaussianId * 2 + 1] += v_sigma * (B * xCam + C * yCam); + + pv_opacity[gaussianId] += vis * v_alpha; + } + } + } + + return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); +} + + +const float SH_C0 = 0.28209479177387814f; +const float SH_C1 = 0.4886025119029199f; +const float SH_C2[] = { + 1.0925484305920792f, + -1.0925484305920792f, + 0.31539156525252005f, + -1.0925484305920792f, + 0.5462742152960396f +}; +const float SH_C3[] = { + -0.5900435899266435f, + 2.890611442640554f, + -0.4570457994644658f, + 0.3731763325901154f, + -0.4570457994644658f, + 1.445305721320277f, + -0.5900435899266435f +}; +const float SH_C4[] = { + 2.5033429417967046f, + -1.7701307697799304f, + 0.9461746957575601f, + -0.6690465435572892f, + 0.10578554691520431f, + -0.6690465435572892f, + 0.47308734787878004f, + -1.7701307697799304f, + 0.6258357354491761f +}; + +int numShBases(int degree){ + switch(degree){ + case 0: + return 1; + case 1: + return 4; + case 2: + return 9; + case 3: + return 16; + default: + return 25; + } +} + +torch::Tensor compute_sh_forward_tensor_cpu( + const int num_points, + const int degree, + const int degrees_to_use, + const torch::Tensor &viewdirs, + const torch::Tensor &coeffs +) { + const int numChannels = 3; + unsigned numBases = numShBases(degrees_to_use); + + torch::Tensor result = torch::zeros({viewdirs.size(0), numBases}, torch::TensorOptions().dtype(torch::kFloat32).device(viewdirs.device())); + + result.index_put_({"...", 0}, SH_C0); + if (numBases > 1){ + std::vector xyz = viewdirs.unbind(-1); + torch::Tensor x = xyz[0]; + torch::Tensor y = xyz[1]; + torch::Tensor z = xyz[2]; + + if (numBases > 4){ + torch::Tensor xx = x * x; + torch::Tensor yy = y * y; + torch::Tensor zz = z * z; + torch::Tensor xy = x * y; + torch::Tensor yz = y * z; + torch::Tensor xz = x * z; + + result.index_put_({"...", 4}, SH_C2[0] * xy); + result.index_put_({"...", 5}, SH_C2[1] * yz); + result.index_put_({"...", 6}, SH_C2[2] * (2.0f * zz - xx - yy)); + result.index_put_({"...", 7}, SH_C2[3] * xz); + result.index_put_({"...", 8}, SH_C2[4] * (xx - yy)); + + if (numBases > 9){ + result.index_put_({"...", 9}, SH_C3[0] * y * (3 * xx - yy)); + result.index_put_({"...", 10}, SH_C3[1] * xy * z); + result.index_put_({"...", 11}, SH_C3[2] * y * (4 * zz - xx - yy)); + result.index_put_({"...", 12}, SH_C3[3] * z * (2 * zz - 3 * xx - 3 * yy)); + result.index_put_({"...", 13}, SH_C3[4] * x * (4 * zz - xx - yy) ); + result.index_put_({"...", 14}, SH_C3[5] * z * (xx - yy)); + result.index_put_({"...", 15}, SH_C3[6] * x * (xx - 3 * yy)); + + if (numBases > 16){ + result.index_put_({"...", 16}, SH_C4[0] * xy * (xx - yy)); + result.index_put_({"...", 17}, SH_C4[1] * yz * (3 * xx - yy)); + result.index_put_({"...", 18}, SH_C4[2] * xy * (7 * zz - 1)); + result.index_put_({"...", 19}, SH_C4[3] * yz * (7 * zz - 3)); + result.index_put_({"...", 20}, SH_C4[4] * (zz * (35 * zz - 30) + 3)); + result.index_put_({"...", 21}, SH_C4[5] * xz * (7 * zz - 3)); + result.index_put_({"...", 22}, SH_C4[6] * (xx - yy) * (7 * zz - 1)); + result.index_put_({"...", 23}, SH_C4[7] * xz * (xx - 3 * yy)); + result.index_put_({"...", 24}, SH_C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))); + + } + } + } + } + + return (result.index({"...", None}) * coeffs).sum(-2); +} \ No newline at end of file