From 25be7a323b4cac063278ff24e176fa808336a2b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ruilong=20Li=28=E6=9D=8E=E7=91=9E=E9=BE=99=29?= Date: Wed, 31 Jan 2024 13:16:16 -0800 Subject: [PATCH] Fix rendering bug (#112) * fix get_tile_bin_edges * ndc2pix -0.5 to +0.5 * format --- gsplat/_torch_impl.py | 9 ++++++--- gsplat/cuda/csrc/bindings.cu | 6 ++++-- gsplat/cuda/csrc/bindings.h | 3 ++- gsplat/cuda/csrc/helpers.cuh | 2 +- gsplat/utils.py | 13 +++++++++---- tests/test_get_tile_bin_edges.py | 6 ++++-- 6 files changed, 26 insertions(+), 13 deletions(-) diff --git a/gsplat/_torch_impl.py b/gsplat/_torch_impl.py index 08c2f4701..513116a15 100644 --- a/gsplat/_torch_impl.py +++ b/gsplat/_torch_impl.py @@ -1,7 +1,8 @@ """Pure PyTorch implementations of various functions""" +import struct + import torch import torch.nn.functional as F -import struct from jaxtyping import Float from torch import Tensor @@ -325,9 +326,11 @@ def map_gaussian_to_intersects( return isect_ids, gaussian_ids -def get_tile_bin_edges(num_intersects, isect_ids_sorted): +def get_tile_bin_edges(num_intersects, isect_ids_sorted, tile_bounds): tile_bins = torch.zeros( - (num_intersects, 2), dtype=torch.int32, device=isect_ids_sorted.device + (tile_bounds[0] * tile_bounds[1], 2), + dtype=torch.int32, + device=isect_ids_sorted.device, ) for idx in range(num_intersects): diff --git a/gsplat/cuda/csrc/bindings.cu b/gsplat/cuda/csrc/bindings.cu index 237a55f8c..6ade94cf3 100644 --- a/gsplat/cuda/csrc/bindings.cu +++ b/gsplat/cuda/csrc/bindings.cu @@ -311,11 +311,13 @@ std::tuple map_gaussian_to_intersects_tensor( } torch::Tensor get_tile_bin_edges_tensor( - int num_intersects, const torch::Tensor &isect_ids_sorted + int num_intersects, const torch::Tensor &isect_ids_sorted, + const std::tuple tile_bounds ) { CHECK_INPUT(isect_ids_sorted); + int num_tiles = std::get<0>(tile_bounds) * std::get<1>(tile_bounds); torch::Tensor tile_bins = torch::zeros( - {num_intersects, 2}, isect_ids_sorted.options().dtype(torch::kInt32) + {num_tiles, 2}, isect_ids_sorted.options().dtype(torch::kInt32) ); get_tile_bin_edges<<< (num_intersects + N_THREADS - 1) / N_THREADS, diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index 60a917028..d29a0b0b1 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -100,7 +100,8 @@ std::tuple map_gaussian_to_intersects_tensor( torch::Tensor get_tile_bin_edges_tensor( int num_intersects, - const torch::Tensor &isect_ids_sorted + const torch::Tensor &isect_ids_sorted, + const std::tuple tile_bounds ); std::tuple< diff --git a/gsplat/cuda/csrc/helpers.cuh b/gsplat/cuda/csrc/helpers.cuh index 423733395..eb748b535 100644 --- a/gsplat/cuda/csrc/helpers.cuh +++ b/gsplat/cuda/csrc/helpers.cuh @@ -5,7 +5,7 @@ #include inline __device__ float ndc2pix(const float x, const float W, const float cx) { - return 0.5f * W * x + cx - 0.5; + return 0.5f * W * x + cx + 0.5f; } inline __device__ void get_bbox( diff --git a/gsplat/utils.py b/gsplat/utils.py index d7d782ff8..543848a5b 100644 --- a/gsplat/utils.py +++ b/gsplat/utils.py @@ -2,9 +2,9 @@ from typing import Tuple +import torch from jaxtyping import Float, Int from torch import Tensor -import torch import gsplat.cuda as _C @@ -51,7 +51,9 @@ def map_gaussian_to_intersects( def get_tile_bin_edges( - num_intersects: int, isect_ids_sorted: Int[Tensor, "num_intersects 1"] + num_intersects: int, + isect_ids_sorted: Int[Tensor, "num_intersects 1"], + tile_bounds: Tuple[int, int, int], ) -> Int[Tensor, "num_intersects 2"]: """Map sorted intersection IDs to tile bins which give the range of unique gaussian IDs belonging to each tile. @@ -65,13 +67,16 @@ def get_tile_bin_edges( Args: num_intersects (int): total number of gaussian intersects. isect_ids_sorted (Tensor): sorted unique IDs for each gaussian in the form (tile | depth id). + tile_bounds (Tuple): tile dimensions as a len 3 tuple (tiles.x , tiles.y, 1). Returns: A Tensor: - **tile_bins** (Tensor): range of gaussians IDs hit per tile. """ - return _C.get_tile_bin_edges(num_intersects, isect_ids_sorted.contiguous()) + return _C.get_tile_bin_edges( + num_intersects, isect_ids_sorted.contiguous(), tile_bounds + ) def compute_cov2d_bounds( @@ -163,5 +168,5 @@ def bin_and_sort_gaussians( ) isect_ids_sorted, sorted_indices = torch.sort(isect_ids) gaussian_ids_sorted = torch.gather(gaussian_ids, 0, sorted_indices) - tile_bins = get_tile_bin_edges(num_intersects, isect_ids_sorted) + tile_bins = get_tile_bin_edges(num_intersects, isect_ids_sorted, tile_bounds) return isect_ids, gaussian_ids, isect_ids_sorted, gaussian_ids_sorted, tile_bins diff --git a/tests/test_get_tile_bin_edges.py b/tests/test_get_tile_bin_edges.py index ffd995d7c..d72eddb73 100644 --- a/tests/test_get_tile_bin_edges.py +++ b/tests/test_get_tile_bin_edges.py @@ -75,8 +75,10 @@ def test_get_tile_bin_edges(): _isect_ids_sorted = sorted_values _gaussian_ids_sorted = torch.gather(_gaussian_ids_unsorted, 0, sorted_indices) - _tile_bins = _torch_impl.get_tile_bin_edges(_num_intersects, _isect_ids_sorted) - tile_bins = get_tile_bin_edges(_num_intersects, _isect_ids_sorted) + _tile_bins = _torch_impl.get_tile_bin_edges( + _num_intersects, _isect_ids_sorted, tile_bounds + ) + tile_bins = get_tile_bin_edges(_num_intersects, _isect_ids_sorted, tile_bounds) torch.testing.assert_close(_tile_bins, tile_bins)