Backpropagating through raymarching #119

Open

alex3dfan opened this issue Oct 24, 2022 · 5 comments

@alex3dfan
Hi @ashawkey,
I wish to backpropagate gradients back to the camera poses, but this fails with an error in _march_rays_train(), which has no backward function. To make this work, do I need to write the backward pass in raymarching.py and the corresponding backward kernel in raymarching.cu, or is there an easier way to achieve this?

@maturk

maturk commented Oct 26, 2022

I am also interested in learning how to add gradient flow so that rays_o and rays_d can have requires_grad = True, allowing the backward pass to optimize camera poses, as in iNeRF or NVIDIA's more recent Parallel Inversion of NeRFs paper.

@ashawkey
Owner

@alex3dfan Hi, yes, you have to implement the backward pass for raymarching so that the gradients from xyzs and dirs can propagate to rays_o and rays_d, and finally to your trainable camera poses.
I haven't had the time to implement and test it recently, but you may check ngp_pl, where it is implemented.
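For reference, the relation such a backward pass has to invert is linear: each sample at distance t along a ray is xyzs = rays_o + t * rays_d, with dirs = rays_d, so the ray gradients are just sums over that ray's samples. A minimal per-ray sketch (backward_one_ray is a hypothetical helper; unit-norm rays_d is an assumption, and real code would batch over rays):

import torch

# chain rule for one ray, assuming xyzs[i] = rays_o + ts[i] * rays_d and dirs[i] = rays_d
def backward_one_ray(dL_dxyzs, dL_ddirs, ts):
    # dL_dxyzs, dL_ddirs: [S, 3] gradients w.r.t. this ray's S samples; ts: [S]
    dL_drays_o = dL_dxyzs.sum(0)                             # d xyzs / d rays_o = I
    dL_drays_d = (ts[:, None] * dL_dxyzs + dL_ddirs).sum(0)  # d xyzs / d rays_d = t * I
    return dL_drays_o, dL_drays_d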

@ShengyuH

ShengyuH commented Sep 20, 2023

Hi @ashawkey,

I am writing the backward pass for raymarching following ngp_pl. I am wondering whether deltas[:, 1] (from the line deltas = deltas[:m] in your repo) corresponds to ts (https://github.com/kwea123/ngp_pl/blob/1b49af1856a276b236e0f17539814134ed329860/models/custom_functions.py#L96) in the ngp_pl repo. I quickly compared the two CUDA implementations and they seem to be the same thing; can you help confirm this? If so, I think it's pretty straightforward to implement the backward pass.

Thanks for this open-source contribution!

Best,
Shengyu
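
One way to check this correspondence empirically, under the assumption that deltas[:, 1] stores the absolute sample distance t: for unit-norm ray directions, t can be recovered by projecting each sample position back onto its ray. A sketch using the forward pass's outputs (check_ts is a hypothetical helper, not part of either repo):

import torch

def check_ts(xyzs, deltas, rays, rays_o, rays_d, atol=1e-4):
    # t = (xyz - ray_o) . ray_d for unit-norm ray_d; compare against deltas[:, 1]
    for i in range(rays.shape[0]):
        idx, offset, count = rays[i].tolist()
        if count == 0:
            continue
        pts = xyzs[offset:offset + count]
        t_rec = ((pts - rays_o[idx]) * rays_d[idx]).sum(-1)
        if not torch.allclose(t_rec, deltas[offset:offset + count, 1], atol=atol):
            return False
    return True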

@ShengyuH

This is my current implementation:

import torch
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd
from torch_scatter import segment_csr
from einops import rearrange

# `_backend` below is torch-ngp's compiled raymarching CUDA extension.

class _march_rays_train(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1,
                perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
        ''' march rays to generate points (forward only)
        Args:
            rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8, [C * H * H * H // 8]
            C: int
            H: int
            nears/fars: float, [N]
            step_counter: int32, (2), used to count the actual number of generated points.
            mean_count: int32, estimated mean steps to accelerate training. (but rays will be randomly dropped if the actual point count exceeds this threshold.)
            perturb: bool
            align: int, pad output so its size is divisible by align; set to -1 to disable.
            force_all_rays: bool, ignore step_counter and mean_count and always march all rays. Useful when rendering the whole image instead of a subset of rays.
            dt_gamma: float, called cone_angle in instant-ngp; exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
            max_steps: int, max number of sampled points along each ray, also affects min_stepsize.
        Returns:
            xyzs: float, [M, 3], all generated points' coords. (all rays concatenated; use `rays` to extract the points belonging to each ray)
            dirs: float, [M, 3], all generated points' view dirs.
            deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for depth)
            rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 1] + rays[i, 2]] --> points belonging to rays[i, 0]
        '''

        if not rays_o.is_cuda:
            rays_o = rays_o.cuda()
        if not rays_d.is_cuda:
            rays_d = rays_d.cuda()
        if not density_bitfield.is_cuda:
            density_bitfield = density_bitfield.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)
        density_bitfield = density_bitfield.contiguous()

        N = rays_o.shape[0]  # num rays
        M = N * max_steps  # init max points number in total

        # running average based on the previous epoch (mimics `measured_batch_size_before_compaction` in instant-ngp)
        # It estimates the max number of points to enable faster training, but rays will be randomly ignored if it is underestimated.
        if not force_all_rays and mean_count > 0:
            if align > 0:
                mean_count += align - mean_count % align
            M = mean_count

        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
        rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device)  # id, offset, num_steps

        if step_counter is None:
            step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device)  # point counter, ray counter

        if perturb:
            noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
        else:
            noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)

        _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars,
                                  xyzs, dirs, deltas, rays, step_counter,
                                  noises)  # step_counter[0] accumulates the number of points actually generated

        # print(step_counter, M)

        # only used at the first (few) epochs.
        if force_all_rays or mean_count <= 0:
            m = step_counter[0].item()  # D2H copy
            if align > 0:
                m += align - m % align
            xyzs = xyzs[:m]
            dirs = dirs[:m]
            deltas = deltas[:m]

            # torch.cuda.empty_cache()
        # save ray packing (id, offset, count) and per-sample t values for the backward pass
        ctx.save_for_backward(rays.long(), deltas[:, 1])
        
        return xyzs, dirs, deltas, rays
    
    # backward pass following ngp_pl: each sample is xyzs = rays_o + t * rays_d with
    # dirs = rays_d, so the per-sample gradients reduce over each ray's segment.
    @staticmethod
    @custom_bwd
    def backward(ctx, dL_dxyzs, dL_ddirs, dL_ddeltas, dL_drays_a):
        rays_a, ts = ctx.saved_tensors
        # CSR index pointer built from per-ray (offset, count); segment_csr
        # requires these offsets to be non-decreasing in ray order.
        segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1] + rays_a[-1:, 2]])
        dL_drays_o = segment_csr(dL_dxyzs, segments)  # d xyzs / d rays_o = I
        dL_drays_d = segment_csr(dL_dxyzs * rearrange(ts, 'n -> n 1') + dL_ddirs, segments)  # d xyzs / d rays_d = t
        # one gradient (or None) per forward input
        return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None, None, None, None, None, None, None
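
Hypothetical usage of the Function above, assuming the usual torch-ngp training setup (bound, density_bitfield, C, H, nears, fars already computed); get_rays_differentiable and render_loss are assumed stand-ins for a differentiable ray generator and the downstream compositing plus photometric loss:

import torch

pose = torch.nn.Parameter(init_pose)            # assumed [4, 4] camera-to-world matrix
rays_o, rays_d = get_rays_differentiable(pose)  # hypothetical differentiable ray generator

xyzs, dirs, deltas, rays = _march_rays_train.apply(
    rays_o, rays_d, bound, density_bitfield, C, H, nears, fars)

loss = render_loss(xyzs, dirs, deltas, rays)  # assumed compositing + photometric loss
loss.backward()                               # gradients now reach pose.grad via the rays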

@QitaoZhao

@ShengyuH Hi Shengyu, does the gradient work properly for camera optimization with the code you provided above? I am trying to do something similar but still encounter: NotImplementedError: You must implement either the backward or vjp method for your custom autograd.Function to use it with backward mode AD. I'm not sure if the issue is with this function. Looking forward to your reply!
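
For what it's worth, that error is what PyTorch raises whenever the autograd.Function actually reached in the graph defines no backward, so it is worth confirming that the patched class is the one being imported at runtime. A minimal reproduction of the error itself (hypothetical class name):

import torch
from torch.autograd import Function

class NoBackward(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2  # forward only; no backward defined

x = torch.ones(3, requires_grad=True)
y = NoBackward.apply(x)
y.sum().backward()  # raises the NotImplementedError quoted above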
