diff --git a/cuda_rasterizer/auxiliary.h b/cuda_rasterizer/auxiliary.h index 4d4b9b78..12dd95ce 100644 --- a/cuda_rasterizer/auxiliary.h +++ b/cuda_rasterizer/auxiliary.h @@ -146,10 +146,10 @@ __forceinline__ __device__ bool in_frustum(int idx, float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; // Bring points to screen space - float4 p_hom = transformPoint4x4(p_orig, projmatrix); + p_view = transformPoint4x3(p_orig, viewmatrix); + float4 p_hom = transformPoint4x4(p_view, projmatrix); float p_w = 1.0f / (p_hom.w + 0.0000001f); float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; - p_view = transformPoint4x3(p_orig, viewmatrix); if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3))) { diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 4aa41e1c..24e6db56 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -13,6 +13,7 @@ #include "auxiliary.h" #include #include +#include namespace cg = cooperative_groups; // Backward pass for conversion of spherical harmonics to RGB for @@ -145,9 +146,9 @@ __global__ void computeCov2DCUDA(int P, const float3* means, const int* radii, const float* cov3Ds, - const float h_x, float h_y, - const float tan_fovx, float tan_fovy, + const float width, float height, const float* view_matrix, + const float* projmatrix, const float* dL_dconics, float3* dL_dmeans, float* dL_dcov) @@ -156,6 +157,9 @@ __global__ void computeCov2DCUDA(int P, if (idx >= P || !(radii[idx] > 0)) return; + const float h_x = projmatrix[0] * width / 2.0f; + const float h_y = projmatrix[5] * height / 2.0f; + // Reading location of 3D covariance for this Gaussian const float* cov3D = cov3Ds + 6 * idx; @@ -164,16 +168,19 @@ __global__ void computeCov2DCUDA(int P, float3 mean = means[idx]; float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] }; float3 t = transformPoint4x3(mean, view_matrix); - - const float limx = 1.3f * tan_fovx; - const float limy = 1.3f * tan_fovy; + + float fx = projmatrix[0], fy = projmatrix[5], cx = projmatrix[8], cy = projmatrix[9]; + const float xmin = (-1.3f - cx) / fx; + const float xmax = (1.3f - cx) / fx; + const float ymin = (-1.3f - cy) / fy; + const float ymax = (1.3f - cy) / fy; const float txtz = t.x / t.z; const float tytz = t.y / t.z; - t.x = min(limx, max(-limx, txtz)) * t.z; - t.y = min(limy, max(-limy, tytz)) * t.z; + t.x = min(xmax, max(xmin, txtz)) * t.z; + t.y = min(ymax, max(ymin, tytz)) * t.z; - const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1; - const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1; + const float x_grad_mul = txtz < xmin || txtz > xmax ? 0 : 1; + const float y_grad_mul = tytz < ymin || tytz > ymax ? 0 : 1; glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z), 0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z), @@ -353,7 +360,8 @@ __global__ void preprocessCUDA( const glm::vec3* scales, const glm::vec4* rotations, const float scale_modifier, - const float* proj, + const float* view_matrix, + const float* proj_matrix, const glm::vec3* campos, const float3* dL_dmean2D, glm::vec3* dL_dmeans, @@ -370,9 +378,13 @@ __global__ void preprocessCUDA( float3 m = means[idx]; // Taking care of gradients from the screenspace points - float4 m_hom = transformPoint4x4(m, proj); + float3 view_point = transformPoint4x3(m, view_matrix); + float4 m_hom = transformPoint4x4(view_point, proj_matrix); float m_w = 1.0f / (m_hom.w + 0.0000001f); + glm::mat4x4 full_proj_matrix = glm::make_mat4(proj_matrix) * glm::make_mat4(view_matrix); + float* proj = (float*)glm::value_ptr(full_proj_matrix); + // Compute loss gradient w.r.t. 3D means due to gradients of 2D means // from rendering procedure glm::vec3 dL_dmean; @@ -568,8 +580,7 @@ void BACKWARD::preprocess( const float* cov3Ds, const float* viewmatrix, const float* projmatrix, - const float focal_x, float focal_y, - const float tan_fovx, float tan_fovy, + const float width, float height, const glm::vec3* campos, const float3* dL_dmean2D, const float* dL_dconic, @@ -589,11 +600,9 @@ void BACKWARD::preprocess( means3D, radii, cov3Ds, - focal_x, - focal_y, - tan_fovx, - tan_fovy, + width, height, viewmatrix, + projmatrix, dL_dconic, (float3*)dL_dmean3D, dL_dcov3D); @@ -610,6 +619,7 @@ void BACKWARD::preprocess( (glm::vec3*)scales, (glm::vec4*)rotations, scale_modifier, + viewmatrix, projmatrix, campos, (float3*)dL_dmean2D, diff --git a/cuda_rasterizer/backward.h b/cuda_rasterizer/backward.h index 93dd2e4b..edd685f4 100644 --- a/cuda_rasterizer/backward.h +++ b/cuda_rasterizer/backward.h @@ -49,8 +49,7 @@ namespace BACKWARD const float* cov3Ds, const float* view, const float* proj, - const float focal_x, float focal_y, - const float tan_fovx, float tan_fovy, + const float width, float height, const glm::vec3* campos, const float3* dL_dmean2D, const float* dL_dconics, diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index c419a328..c6019fd7 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -71,7 +71,7 @@ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const } // Forward version of 2D covariance matrix computation -__device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix) +__device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, const float* cov3D, const float* viewmatrix, const float* projmatrix) { // The following models the steps outlined by equations 29 // and 31 in "EWA Splatting" (Zwicker et al., 2002). @@ -79,12 +79,15 @@ __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, // Transposes used to account for row-/column-major conventions. float3 t = transformPoint4x3(mean, viewmatrix); - const float limx = 1.3f * tan_fovx; - const float limy = 1.3f * tan_fovy; + float fx = projmatrix[0], fy = projmatrix[5], cx = projmatrix[8], cy = projmatrix[9]; + const float xmin = (-1.3f - cx) / fx; + const float xmax = (1.3f - cx) / fx; + const float ymin = (-1.3f - cy) / fy; + const float ymax = (1.3f - cy) / fy; const float txtz = t.x / t.z; const float tytz = t.y / t.z; - t.x = min(limx, max(-limx, txtz)) * t.z; - t.y = min(limy, max(-limy, tytz)) * t.z; + t.x = min(xmax, max(xmin, txtz)) * t.z; + t.y = min(ymax, max(ymin, tytz)) * t.z; glm::mat3 J = glm::mat3( focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z), @@ -167,8 +170,6 @@ __global__ void preprocessCUDA(int P, int D, int M, const float* projmatrix, const glm::vec3* cam_pos, const int W, int H, - const float tan_fovx, float tan_fovy, - const float focal_x, float focal_y, int* radii, float2* points_xy_image, float* depths, @@ -183,6 +184,9 @@ __global__ void preprocessCUDA(int P, int D, int M, if (idx >= P) return; + const float focal_x = projmatrix[0] * W / 2.0f; + const float focal_y = projmatrix[5] * H / 2.0f; + // Initialize radius and touched tiles to 0. If this isn't changed, // this Gaussian will not be processed further. radii[idx] = 0; @@ -195,7 +199,7 @@ __global__ void preprocessCUDA(int P, int D, int M, // Transform point by projecting float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] }; - float4 p_hom = transformPoint4x4(p_orig, projmatrix); + float4 p_hom = transformPoint4x4(p_view, projmatrix); float p_w = 1.0f / (p_hom.w + 0.0000001f); float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; @@ -213,7 +217,7 @@ __global__ void preprocessCUDA(int P, int D, int M, } // Compute 2D screen-space covariance matrix - float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix); + float3 cov = computeCov2D(p_orig, focal_x, focal_y, cov3D, viewmatrix, projmatrix); // Invert covariance (EWA algorithm) float det = (cov.x * cov.z - cov.y * cov.y); @@ -413,8 +417,6 @@ void FORWARD::preprocess(int P, int D, int M, const float* projmatrix, const glm::vec3* cam_pos, const int W, int H, - const float focal_x, float focal_y, - const float tan_fovx, float tan_fovy, int* radii, float2* means2D, float* depths, @@ -440,8 +442,6 @@ void FORWARD::preprocess(int P, int D, int M, projmatrix, cam_pos, W, H, - tan_fovx, tan_fovy, - focal_x, focal_y, radii, means2D, depths, diff --git a/cuda_rasterizer/forward.h b/cuda_rasterizer/forward.h index 3c11cb91..0330955b 100644 --- a/cuda_rasterizer/forward.h +++ b/cuda_rasterizer/forward.h @@ -35,8 +35,6 @@ namespace FORWARD const float* projmatrix, const glm::vec3* cam_pos, const int W, int H, - const float focal_x, float focal_y, - const float tan_fovx, float tan_fovy, int* radii, float2* points_xy_image, float* depths, diff --git a/cuda_rasterizer/rasterizer.h b/cuda_rasterizer/rasterizer.h index 81544ef6..c33f8897 100644 --- a/cuda_rasterizer/rasterizer.h +++ b/cuda_rasterizer/rasterizer.h @@ -46,7 +46,6 @@ namespace CudaRasterizer const float* viewmatrix, const float* projmatrix, const float* cam_pos, - const float tan_fovx, float tan_fovy, const bool prefiltered, float* out_color, int* radii = nullptr, @@ -66,7 +65,6 @@ namespace CudaRasterizer const float* viewmatrix, const float* projmatrix, const float* campos, - const float tan_fovx, float tan_fovy, const int* radii, char* geom_buffer, char* binning_buffer, diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu index f8782ac4..7f0f7a37 100644 --- a/cuda_rasterizer/rasterizer_impl.cu +++ b/cuda_rasterizer/rasterizer_impl.cu @@ -213,15 +213,11 @@ int CudaRasterizer::Rasterizer::forward( const float* viewmatrix, const float* projmatrix, const float* cam_pos, - const float tan_fovx, float tan_fovy, const bool prefiltered, float* out_color, int* radii, bool debug) { - const float focal_y = height / (2.0f * tan_fovy); - const float focal_x = width / (2.0f * tan_fovx); - size_t chunk_size = required(P); char* chunkptr = geometryBuffer(chunk_size); GeometryState geomState = GeometryState::fromChunk(chunkptr, P); @@ -259,8 +255,6 @@ int CudaRasterizer::Rasterizer::forward( viewmatrix, projmatrix, (glm::vec3*)cam_pos, width, height, - focal_x, focal_y, - tan_fovx, tan_fovy, radii, geomState.means2D, geomState.depths, @@ -351,7 +345,6 @@ void CudaRasterizer::Rasterizer::backward( const float* viewmatrix, const float* projmatrix, const float* campos, - const float tan_fovx, float tan_fovy, const int* radii, char* geom_buffer, char* binning_buffer, @@ -377,8 +370,8 @@ void CudaRasterizer::Rasterizer::backward( radii = geomState.internal_radii; } - const float focal_y = height / (2.0f * tan_fovy); - const float focal_x = width / (2.0f * tan_fovx); + const float focal_y = 0.f; + const float focal_x = 0.f; const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); const dim3 block(BLOCK_X, BLOCK_Y, 1); @@ -420,8 +413,7 @@ void CudaRasterizer::Rasterizer::backward( cov3D_ptr, viewmatrix, projmatrix, - focal_x, focal_y, - tan_fovx, tan_fovy, + width, height, (glm::vec3*)campos, (float3*)dL_dmean2D, dL_dconic, diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index bbef37d1..295a74ab 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -68,8 +68,6 @@ def forward( cov3Ds_precomp, raster_settings.viewmatrix, raster_settings.projmatrix, - raster_settings.tanfovx, - raster_settings.tanfovy, raster_settings.image_height, raster_settings.image_width, sh, @@ -116,8 +114,6 @@ def backward(ctx, grad_out_color, _): cov3Ds_precomp, raster_settings.viewmatrix, raster_settings.projmatrix, - raster_settings.tanfovx, - raster_settings.tanfovy, grad_out_color, sh, raster_settings.sh_degree, @@ -157,8 +153,6 @@ def backward(ctx, grad_out_color, _): class GaussianRasterizationSettings(NamedTuple): image_height: int image_width: int - tanfovx : float - tanfovy : float bg : torch.Tensor scale_modifier : float viewmatrix : torch.Tensor diff --git a/rasterize_points.cu b/rasterize_points.cu index ddc5cf8b..cb52b89a 100644 --- a/rasterize_points.cu +++ b/rasterize_points.cu @@ -44,8 +44,6 @@ RasterizeGaussiansCUDA( const torch::Tensor& cov3D_precomp, const torch::Tensor& viewmatrix, const torch::Tensor& projmatrix, - const float tan_fovx, - const float tan_fovy, const int image_height, const int image_width, const torch::Tensor& sh, @@ -104,8 +102,6 @@ RasterizeGaussiansCUDA( viewmatrix.contiguous().data(), projmatrix.contiguous().data(), campos.contiguous().data(), - tan_fovx, - tan_fovy, prefiltered, out_color.contiguous().data(), radii.contiguous().data(), @@ -126,8 +122,6 @@ std::tuple(), projmatrix.contiguous().data(), campos.contiguous().data(), - tan_fovx, - tan_fovy, radii.contiguous().data(), reinterpret_cast(geomBuffer.contiguous().data_ptr()), reinterpret_cast(binningBuffer.contiguous().data_ptr()), diff --git a/rasterize_points.h b/rasterize_points.h index 9023d994..df5deeea 100644 --- a/rasterize_points.h +++ b/rasterize_points.h @@ -27,8 +27,6 @@ RasterizeGaussiansCUDA( const torch::Tensor& cov3D_precomp, const torch::Tensor& viewmatrix, const torch::Tensor& projmatrix, - const float tan_fovx, - const float tan_fovy, const int image_height, const int image_width, const torch::Tensor& sh, @@ -49,8 +47,6 @@ std::tuple