Skip to content

Commit

Permalink
Fix rendering bug (#112)
Browse files Browse the repository at this point in the history
* fix get_tile_bin_edges

* ndc2pix -0.5 to +0.5

* format
  • Loading branch information
liruilong940607 authored Jan 31, 2024
1 parent 2e6ae0e commit 25be7a3
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 13 deletions.
9 changes: 6 additions & 3 deletions gsplat/_torch_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Pure PyTorch implementations of various functions"""
import struct

import torch
import torch.nn.functional as F
import struct
from jaxtyping import Float
from torch import Tensor

Expand Down Expand Up @@ -325,9 +326,11 @@ def map_gaussian_to_intersects(
return isect_ids, gaussian_ids


def get_tile_bin_edges(num_intersects, isect_ids_sorted):
def get_tile_bin_edges(num_intersects, isect_ids_sorted, tile_bounds):
tile_bins = torch.zeros(
(num_intersects, 2), dtype=torch.int32, device=isect_ids_sorted.device
(tile_bounds[0] * tile_bounds[1], 2),
dtype=torch.int32,
device=isect_ids_sorted.device,
)

for idx in range(num_intersects):
Expand Down
6 changes: 4 additions & 2 deletions gsplat/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,13 @@ std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
}

torch::Tensor get_tile_bin_edges_tensor(
int num_intersects, const torch::Tensor &isect_ids_sorted
int num_intersects, const torch::Tensor &isect_ids_sorted,
const std::tuple<int, int, int> tile_bounds
) {
CHECK_INPUT(isect_ids_sorted);
int num_tiles = std::get<0>(tile_bounds) * std::get<1>(tile_bounds);
torch::Tensor tile_bins = torch::zeros(
{num_intersects, 2}, isect_ids_sorted.options().dtype(torch::kInt32)
{num_tiles, 2}, isect_ids_sorted.options().dtype(torch::kInt32)
);
get_tile_bin_edges<<<
(num_intersects + N_THREADS - 1) / N_THREADS,
Expand Down
3 changes: 2 additions & 1 deletion gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(

torch::Tensor get_tile_bin_edges_tensor(
int num_intersects,
const torch::Tensor &isect_ids_sorted
const torch::Tensor &isect_ids_sorted,
const std::tuple<int, int, int> tile_bounds
);

std::tuple<
Expand Down
2 changes: 1 addition & 1 deletion gsplat/cuda/csrc/helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <iostream>

inline __device__ float ndc2pix(const float x, const float W, const float cx) {
return 0.5f * W * x + cx - 0.5;
return 0.5f * W * x + cx + 0.5f;
}

inline __device__ void get_bbox(
Expand Down
13 changes: 9 additions & 4 deletions gsplat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from typing import Tuple

import torch
from jaxtyping import Float, Int
from torch import Tensor
import torch

import gsplat.cuda as _C

Expand Down Expand Up @@ -51,7 +51,9 @@ def map_gaussian_to_intersects(


def get_tile_bin_edges(
num_intersects: int, isect_ids_sorted: Int[Tensor, "num_intersects 1"]
num_intersects: int,
isect_ids_sorted: Int[Tensor, "num_intersects 1"],
tile_bounds: Tuple[int, int, int],
) -> Int[Tensor, "num_intersects 2"]:
"""Map sorted intersection IDs to tile bins which give the range of unique gaussian IDs belonging to each tile.
Expand All @@ -65,13 +67,16 @@ def get_tile_bin_edges(
Args:
num_intersects (int): total number of gaussian intersects.
isect_ids_sorted (Tensor): sorted unique IDs for each gaussian in the form (tile | depth id).
tile_bounds (Tuple): tile dimensions as a len 3 tuple (tiles.x , tiles.y, 1).
Returns:
A Tensor:
- **tile_bins** (Tensor): range of gaussians IDs hit per tile.
"""
return _C.get_tile_bin_edges(num_intersects, isect_ids_sorted.contiguous())
return _C.get_tile_bin_edges(
num_intersects, isect_ids_sorted.contiguous(), tile_bounds
)


def compute_cov2d_bounds(
Expand Down Expand Up @@ -163,5 +168,5 @@ def bin_and_sort_gaussians(
)
isect_ids_sorted, sorted_indices = torch.sort(isect_ids)
gaussian_ids_sorted = torch.gather(gaussian_ids, 0, sorted_indices)
tile_bins = get_tile_bin_edges(num_intersects, isect_ids_sorted)
tile_bins = get_tile_bin_edges(num_intersects, isect_ids_sorted, tile_bounds)
return isect_ids, gaussian_ids, isect_ids_sorted, gaussian_ids_sorted, tile_bins
6 changes: 4 additions & 2 deletions tests/test_get_tile_bin_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ def test_get_tile_bin_edges():
_isect_ids_sorted = sorted_values
_gaussian_ids_sorted = torch.gather(_gaussian_ids_unsorted, 0, sorted_indices)

_tile_bins = _torch_impl.get_tile_bin_edges(_num_intersects, _isect_ids_sorted)
tile_bins = get_tile_bin_edges(_num_intersects, _isect_ids_sorted)
_tile_bins = _torch_impl.get_tile_bin_edges(
_num_intersects, _isect_ids_sorted, tile_bounds
)
tile_bins = get_tile_bin_edges(_num_intersects, _isect_ids_sorted, tile_bounds)

torch.testing.assert_close(_tile_bins, tile_bins)

Expand Down

0 comments on commit 25be7a3

Please sign in to comment.