-
Notifications
You must be signed in to change notification settings - Fork 3
/
samplers.py
115 lines (102 loc) · 4.95 KB
/
samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
from typing import Optional, Tuple, List
def sample_stratified(
    rays_o: torch.Tensor,  # [n_theta, n_phi, 3]
    rays_d: torch.Tensor,  # [n_theta, n_phi, 3]
    near: float,
    far: float,
    n_samples: int,
    perturb: Optional[bool] = True,
    inverse_depth: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Place `n_samples` points on every ray at depths spread between `near` and `far`.
    Borrowed from: Mason McGough
    Source: https://towardsdatascience.com/its-nerf-from-nothing-build-a-vanilla-nerf-with-pytorch-7846e4c45666

    A single 1-D set of depths is built and then broadcast to all rays, so
    every ray shares the same depths (including the same random jitter).

    Args:
        rays_o: Ray origins; the last axis holds the coordinates.
        rays_d: Ray directions, same layout as `rays_o`.
        near: Depth of the first regular sample.
        far: Depth of the last regular sample.
        n_samples: Number of depths per ray.
        perturb: If truthy, each depth is drawn uniformly inside its bin
            instead of sitting on the regular grid.
        inverse_depth: If True, spacing is linear in 1/depth rather than depth.

    Returns:
        pts: Sample positions, shape `rays_o.shape[:-1] + [n_samples, 3]`
            for 3-component rays.
        z_vals: Depths, shape `rays_o.shape[:-1] + [n_samples]` (an expanded
            view of one shared 1-D tensor).
    """
    # Normalised positions in [0, 1] that are mapped onto [near, far].
    steps = torch.linspace(0., 1., n_samples, device=rays_o.device)
    if inverse_depth:
        # Even spacing in disparity (1 / depth).
        depths = 1./(1./near * (1.-steps) + 1./far * (steps))
    else:
        # Even spacing in depth.
        depths = near * (1.-steps) + far * (steps)

    if perturb:
        # Bin edges are the midpoints between neighbouring depths; the first
        # and last bins are closed off by the end depths themselves.
        centers = .5 * (depths[1:] + depths[:-1])
        bin_hi = torch.cat([centers, depths[-1:]], dim=-1)
        bin_lo = torch.cat([depths[:1], centers], dim=-1)
        # One random offset per bin, shared by all rays.
        jitter = torch.rand([n_samples], device=depths.device)
        depths = bin_lo + (bin_hi - bin_lo) * jitter

    # The first sample is always pinned (almost) onto the ray origin.
    depths[0] = 1e-10

    batch_shape = list(rays_o.shape[:-1])
    z_vals = depths.expand(batch_shape + [n_samples])  # [n_theta, n_phi, n_samples]

    # point = origin + direction * depth, broadcast over the sample axis.
    origins = rays_o.unsqueeze(-2)
    directions = rays_d.unsqueeze(-2)
    pts = origins + directions * z_vals.unsqueeze(-1)
    return pts, z_vals
def sample_pdf(
    bins: torch.Tensor,
    weights: torch.Tensor,
    n_samples: int,
    perturb: bool = False
) -> torch.Tensor:
    r"""
    Apply inverse transform sampling to a weighted set of points.
    Borrowed from: Mason McGough
    Source: https://towardsdatascience.com/its-nerf-from-nothing-build-a-vanilla-nerf-with-pytorch-7846e4c45666

    The weights are turned into a piecewise-constant PDF over the intervals
    between consecutive `bins` entries; samples are obtained by inverting the
    resulting piecewise-linear CDF.

    Args:
        bins: Interval edges along the last axis. Must have exactly one more
            entry on the last axis than `weights`, with matching leading axes.
        weights: Non-normalised weight of each interval, last axis = intervals.
        n_samples: Number of samples to draw per leading-axis entry.
        perturb: If False, CDF positions are evenly spaced in [0, 1]
            (deterministic); if True, they are drawn uniformly at random.

    Returns:
        Tensor of shape `weights.shape[:-1] + [n_samples]` with the sampled
        positions, expressed in the units of `bins`.
    """
    # Normalize weights to get PDF; the epsilon keeps all-zero weights finite.
    pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, -1, keepdim=True)  # [n_rays, weights.shape[-1]]

    # Convert PDF to CDF, with a leading 0 so it has one value per bin edge.
    cdf = torch.cumsum(pdf, dim=-1)  # [n_rays, weights.shape[-1]]
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # [n_rays, weights.shape[-1] + 1]

    # Positions at which the CDF is inverted. They are created in the CDF's
    # dtype because torch.searchsorted needs both tensors to share a dtype
    # (otherwise e.g. float64 weights fail).
    batch_shape = list(cdf.shape[:-1])
    if not perturb:
        u = torch.linspace(0., 1., n_samples, dtype=cdf.dtype, device=cdf.device)
        u = u.expand(batch_shape + [n_samples])  # [n_rays, n_samples]
    else:
        u = torch.rand(batch_shape + [n_samples], dtype=cdf.dtype, device=cdf.device)  # [n_rays, n_samples]

    # Find indices along CDF where values in u would be placed.
    u = u.contiguous()  # `expand` yields a non-contiguous view.
    inds = torch.searchsorted(cdf, u, right=True)  # [n_rays, n_samples]

    # Edges bracketing each u, clamped so they stay inside the CDF.
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], dim=-1)  # [n_rays, n_samples, 2]

    # Look up the CDF values and bin edges at the bracketing indices.
    matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), dim=-1,
                         index=inds_g)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), dim=-1,
                          index=inds_g)

    # Linear interpolation inside the bracketing interval. Near-empty
    # intervals get a denominator of 1 to avoid dividing by ~0.
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples  # [n_rays, n_samples]
def sample_hierarchical(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    z_vals: torch.Tensor,
    weights: torch.Tensor,
    n_new_samples: int,
    perturb: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Apply hierarchical sampling to the rays.
    Borrowed from: Mason McGough
    Source: https://towardsdatascience.com/its-nerf-from-nothing-build-a-vanilla-nerf-with-pytorch-7846e4c45666

    New depths are drawn with `sample_pdf`, using the midpoints of `z_vals`
    as bin edges and the interior `weights` as bin weights. They are then
    merged with the existing depths and sorted.

    Args:
        rays_o: Ray origins; the last axis holds the coordinates.
        rays_d: Ray directions, same layout as `rays_o`.
        z_vals: Existing depths along each ray (last axis = samples).
        weights: One weight per entry of `z_vals`; the first and last are
            dropped so the count matches the midpoint bins.
        n_new_samples: Number of additional depths to draw per ray.
        perturb: Forwarded to `sample_pdf`; False gives deterministic samples.

    Returns:
        A 4-tuple `(pts, z_vals_combined, rays_o, rays_d)`:
        pts is the positions at the merged depths, z_vals_combined is the
        sorted union of old and new depths, and rays_o / rays_d are the
        inputs passed back unchanged.
    """
    # Draw samples from PDF using z_vals as bins and weights as probabilities.
    z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])  # [N_rays, N_samples_stratified-1] --> N_bins = N_samples_stratified-1
    new_z_samples = sample_pdf(z_vals_mid,
                               weights[..., 1:-1],  # [N_rays, N_samples_stratified-2]
                               n_new_samples,
                               perturb=perturb)
    # The sampled depths are treated as constants: no gradient flows
    # back through the sampling step.
    new_z_samples = new_z_samples.detach()

    # Resample points from ray based on PDF.
    z_vals_combined, _ = torch.sort(torch.cat([z_vals, new_z_samples], dim=-1), dim=-1)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals_combined[..., :, None]  # [N_rays, N_samples + n_samples, 3]
    return pts, z_vals_combined, rays_o, rays_d