Skip to content

Commit 4c6d250

Browse files
committed
Add depth backward pass
1 parent 6787328 commit 4c6d250

File tree

7 files changed

+35
-1
lines changed

7 files changed

+35
-1
lines changed

Diff for: cuda_rasterizer/backward.cu

+23
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,11 @@ renderCUDA(
406406
const float2* __restrict__ points_xy_image,
407407
const float4* __restrict__ conic_opacity,
408408
const float* __restrict__ colors,
409+
const float* __restrict__ depths,
409410
const float* __restrict__ final_Ts,
410411
const uint32_t* __restrict__ n_contrib,
411412
const float* __restrict__ dL_dpixels,
413+
const float* __restrict__ dL_depths,
412414
float3* __restrict__ dL_dmean2D,
413415
float4* __restrict__ dL_dconic2D,
414416
float* __restrict__ dL_dopacity,
@@ -435,6 +437,7 @@ renderCUDA(
435437
__shared__ float2 collected_xy[BLOCK_SIZE];
436438
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
437439
__shared__ float collected_colors[C * BLOCK_SIZE];
440+
__shared__ float collected_depths[BLOCK_SIZE];
438441

439442
// In the forward, we stored the final value for T, the
440443
// product of all (1 - alpha) factors.
@@ -448,12 +451,16 @@ renderCUDA(
448451

449452
float accum_rec[C] = { 0 };
450453
float dL_dpixel[C];
454+
float dL_depth;
455+
float accum_depth_rec = 0;
451456
if (inside)
452457
for (int i = 0; i < C; i++)
453458
dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
459+
dL_depth = dL_depths[pix_id];
454460

455461
float last_alpha = 0;
456462
float last_color[C] = { 0 };
463+
float last_depth = 0;
457464

458465
// Gradient of pixel coordinate w.r.t. normalized
459466
// screen-space viewport coordinates (-1 to 1)
@@ -475,6 +482,7 @@ renderCUDA(
475482
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
476483
for (int i = 0; i < C; i++)
477484
collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
485+
collected_depths[block.thread_rank()] = depths[coll_id];
478486
}
479487
block.sync();
480488

@@ -522,6 +530,17 @@ renderCUDA(
522530
// many that were affected by this Gaussian.
523531
atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
524532
}
533+
534+
// Propagate the depth-loss gradient into dL_dalpha (no per-Gaussian depth gradient is written here)
535+
const float c_d = collected_depths[j];
536+
accum_depth_rec = last_alpha * last_depth + (1.f - last_alpha) * accum_depth_rec;
537+
last_depth = c_d;
538+
dL_dalpha += (c_d - accum_depth_rec) * dL_depth;
539+
// for (int ch = 0; ch < C; ch++)
540+
// {
541+
// atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_depth);
542+
// }
543+
525544
dL_dalpha *= T;
526545
// Update last alpha (to be used in the next iteration)
527546
last_alpha = alpha;
@@ -630,9 +649,11 @@ void BACKWARD::render(
630649
const float2* means2D,
631650
const float4* conic_opacity,
632651
const float* colors,
652+
const float* depths,
633653
const float* final_Ts,
634654
const uint32_t* n_contrib,
635655
const float* dL_dpixels,
656+
const float* dL_depths,
636657
float3* dL_dmean2D,
637658
float4* dL_dconic2D,
638659
float* dL_dopacity,
@@ -646,9 +667,11 @@ void BACKWARD::render(
646667
means2D,
647668
conic_opacity,
648669
colors,
670+
depths,
649671
final_Ts,
650672
n_contrib,
651673
dL_dpixels,
674+
dL_depths,
652675
dL_dmean2D,
653676
dL_dconic2D,
654677
dL_dopacity,

Diff for: cuda_rasterizer/backward.h

+2
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ namespace BACKWARD
2929
const float2* means2D,
3030
const float4* conic_opacity,
3131
const float* colors,
32+
const float* depths,
3233
const float* final_Ts,
3334
const uint32_t* n_contrib,
3435
const float* dL_dpixels,
36+
const float* dL_depths,
3537
float3* dL_dmean2D,
3638
float4* dL_dconic2D,
3739
float* dL_dopacity,

Diff for: cuda_rasterizer/rasterizer.h

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ namespace CudaRasterizer
7272
char* binning_buffer,
7373
char* image_buffer,
7474
const float* dL_dpix,
75+
const float* dL_depths,
7576
float* dL_dmean2D,
7677
float* dL_dconic,
7778
float* dL_dopacity,

Diff for: cuda_rasterizer/rasterizer_impl.cu

+4
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ void CudaRasterizer::Rasterizer::backward(
360360
char* binning_buffer,
361361
char* img_buffer,
362362
const float* dL_dpix,
363+
const float* dL_depths,
363364
float* dL_dmean2D,
364365
float* dL_dconic,
365366
float* dL_dopacity,
@@ -389,6 +390,7 @@ void CudaRasterizer::Rasterizer::backward(
389390
// opacity and RGB of Gaussians from per-pixel loss gradients.
390391
// If we were given precomputed colors and not SHs, use them.
391392
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
393+
const float* depth_ptr = geomState.depths;
392394
BACKWARD::render(
393395
tile_grid,
394396
block,
@@ -399,9 +401,11 @@ void CudaRasterizer::Rasterizer::backward(
399401
geomState.means2D,
400402
geomState.conic_opacity,
401403
color_ptr,
404+
depth_ptr,
402405
imgState.accum_alpha,
403406
imgState.n_contrib,
404407
dL_dpix,
408+
dL_depths,
405409
(float3*)dL_dmean2D,
406410
(float4*)dL_dconic,
407411
dL_dopacity,

Diff for: diff_gaussian_rasterization/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ def backward(ctx, grad_out_color, grad_radii, grad_depth):
104104
raster_settings.projmatrix,
105105
raster_settings.tanfovx,
106106
raster_settings.tanfovy,
107-
grad_out_color,
107+
grad_out_color,
108+
grad_depth,
108109
sh,
109110
raster_settings.sh_degree,
110111
raster_settings.campos,

Diff for: rasterize_points.cu

+2
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
129129
const float tan_fovx,
130130
const float tan_fovy,
131131
const torch::Tensor& dL_dout_color,
132+
const torch::Tensor& dL_dout_depth,
132133
const torch::Tensor& sh,
133134
const int degree,
134135
const torch::Tensor& campos,
@@ -179,6 +180,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
179180
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
180181
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
181182
dL_dout_color.contiguous().data<float>(),
183+
dL_dout_depth.contiguous().data<float>(),
182184
dL_dmeans2D.contiguous().data<float>(),
183185
dL_dconic.contiguous().data<float>(),
184186
dL_dopacity.contiguous().data<float>(),

Diff for: rasterize_points.h

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
5151
const float tan_fovx,
5252
const float tan_fovy,
5353
const torch::Tensor& dL_dout_color,
54+
const torch::Tensor& dL_dout_depth,
5455
const torch::Tensor& sh,
5556
const int degree,
5657
const torch::Tensor& campos,

0 commit comments

Comments
 (0)