diff --git a/README.md b/README.md index 9c41dec..97658dd 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,47 @@ # Rethinking Inductive Biases for Surface Normal Estimation +


+ Official implementation of the paper > **Rethinking Inductive Biases for Surface Normal Estimation** > > CVPR 2024 (to appear) > -> [Gwangbin Bae](https://baegwangbin.com) and [Andrew J. Davison](https://www.doc.ic.ac.uk/~ajd/) +> Gwangbin Bae and > Andrew J. Davison > -> [[arXiv]]() [[project page]]() \ No newline at end of file +> [paper.pdf] +[arXiv (coming soon)] +[project page] + +## Abstract + +Despite the growing demand for accurate surface normal estimation models, existing methods use general-purpose dense prediction models, adopting the same inductive biases as other tasks. In this paper, we discuss the **inductive biases** needed for surface normal estimation and propose to **(1) utilize the per-pixel ray direction** and **(2) encode the relationship between neighboring surface normals by learning their relative rotation**. The proposed method can generate **crisp — yet, piecewise smooth — predictions** for challenging in-the-wild images of arbitrary resolution and aspect ratio. Compared to a recent ViT-based state-of-the-art model, our method shows a stronger generalization ability, despite being trained on an orders of magnitude smaller dataset. + +
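The first of these inductive biases, conditioning on the per-pixel ray direction, is easy to illustrate. The sketch below is a hypothetical helper written for this README (it is not part of the repository; the model's own `get_ray` method in `models/dsine.py` further down follows the same idea): each pixel center is back-projected through the pinhole intrinsics and normalized to a unit viewing ray. The intrinsics values and the 720x480 image size are illustrative assumptions.

```
import torch
import torch.nn.functional as F

def pixel_ray_directions(intrins: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Unit viewing-ray direction for every pixel of an H x W image.

    intrins: (3, 3) pinhole intrinsics [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
    Returns a (3, H, W) tensor of unit-norm camera-frame ray directions.
    """
    fx, fy = intrins[0, 0], intrins[1, 1]
    cx, cy = intrins[0, 2], intrins[1, 2]
    # pixel centers: (u + 0.5, v + 0.5), with the top-left pixel at (0, 0)
    u = torch.arange(W, dtype=torch.float32) + 0.5
    v = torch.arange(H, dtype=torch.float32) + 0.5
    vv, uu = torch.meshgrid(v, u, indexing='ij')  # each (H, W)
    ray = torch.stack([(uu - cx) / fx, (vv - cy) / fy, torch.ones_like(uu)], dim=0)
    return F.normalize(ray, dim=0)

# illustrative values only: intrinsics from samples/img/office_01.txt, assumed 720x480 input
K = torch.tensor([[559.62, 0.0, 361.87],
                  [0.0, 558.14, 241.99],
                  [0.0, 0.0, 1.0]])
rays = pixel_ray_directions(K, H=480, W=720)  # (3, 480, 720)
```

This is also why `test.py` shifts `cx` and `cy` by the padding offsets before running the model: the rays must stay aligned with the original pixels after zero-padding.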


+ +## Getting Started + +Start by installing the dependencies. + +``` +conda create --name DSINE python=3.10 +conda activate DSINE + +conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia +conda install opencv +python -m pip install geffnet +python -m pip install glob2 +``` + +Then, download the model weights from this link and save it under `./checkpoints/`. + +## Test on images + +* Run `python test.py` to generate predictions for the images under `./samples/img/`. The result will be saved under `./samples/output/`. +* Our model assumes known camera intrinsics, but providing approximate intrinsics still gives good results. For some images in `./samples/img/`, the corresponding camera intrinsics (fx, fy, cx, cy - assuming perspective camera with no distortion) is provided as a `.txt` file. If such a file does not exist, the intrinsics will be approximated, by assuming $60^\circ$ field-of-view. \ No newline at end of file diff --git a/docs/index.html b/docs/index.html index f28ecf9..b0e0f3b 100644 --- a/docs/index.html +++ b/docs/index.html @@ -12,12 +12,16 @@ + - + MathJax.Hub.Config({ + tex2jax: { + inlineMath: [ ['$','$'], ["\\(","\\)"] ], + processEscapes: true + } + }); + + @@ -110,7 +114,7 @@

Demo

-

The input videos are from DAVIS. The predictions are made per-frame.

+

The input videos are from DAVIS. The predictions are made per-frame (we recommend watching in 4K).

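As the `Test on images` section of the README above notes, when no `.txt` intrinsics file is found the camera is approximated by a perspective model with a $60^\circ$ field-of-view. A minimal sketch of that approximation, mirroring the `utils.get_intrins_from_fov` helper that appears in `utils/utils.py` later in this diff, looks like this:

```
import numpy as np
import torch

def intrins_from_fov(fov_deg: float, H: int, W: int) -> torch.Tensor:
    """Approximate pinhole intrinsics from a field-of-view guess.

    The focal length is chosen so that the longer image side spans fov_deg
    degrees, and the principal point is placed at the image center.
    """
    f = (max(H, W) / 2.0) / np.tan(np.deg2rad(fov_deg / 2.0))
    cx, cy = (W / 2.0) - 0.5, (H / 2.0) - 0.5
    return torch.tensor([[f, 0.0, cx],
                         [0.0, f, cy],
                         [0.0, 0.0, 1.0]], dtype=torch.float32)

# e.g. a 640x480 image with an assumed 60-degree FOV gives f = 320 / tan(30°) ≈ 554.3
K = intrins_from_fov(60.0, H=480, W=640)
```

Only a rough FOV guess is needed here; as stated above, approximate intrinsics still give good results.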
diff --git a/models/dsine.py b/models/dsine.py new file mode 100644 index 0000000..2ace18d --- /dev/null +++ b/models/dsine.py @@ -0,0 +1,233 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.submodules import Encoder, ConvGRU, UpSampleBN, UpSampleGN, RayReLU, \ + convex_upsampling, get_unfold, get_prediction_head, \ + INPUT_CHANNELS_DICT +from utils.rotation import axis_angle_to_matrix + + +class Decoder(nn.Module): + def __init__(self, output_dims, B=5, NF=2048, BN=False, downsample_ratio=8): + super(Decoder, self).__init__() + input_channels = INPUT_CHANNELS_DICT[B] + output_dim, feature_dim, hidden_dim = output_dims + features = bottleneck_features = NF + self.downsample_ratio = downsample_ratio + + UpSample = UpSampleBN if BN else UpSampleGN + self.conv2 = nn.Conv2d(bottleneck_features + 2, features, kernel_size=1, stride=1, padding=0) + self.up1 = UpSample(skip_input=features // 1 + input_channels[1] + 2, output_features=features // 2, align_corners=False) + self.up2 = UpSample(skip_input=features // 2 + input_channels[2] + 2, output_features=features // 4, align_corners=False) + + # prediction heads + i_dim = features // 4 + h_dim = 128 + self.normal_head = get_prediction_head(i_dim+2, h_dim, output_dim) + self.feature_head = get_prediction_head(i_dim+2, h_dim, feature_dim) + self.hidden_head = get_prediction_head(i_dim+2, h_dim, hidden_dim) + + def forward(self, features, uvs): + _, _, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11] + uv_32, uv_16, uv_8 = uvs + + x_d0 = self.conv2(torch.cat([x_block4, uv_32], dim=1)) + x_d1 = self.up1(x_d0, torch.cat([x_block3, uv_16], dim=1)) + x_feat = self.up2(x_d1, torch.cat([x_block2, uv_8], dim=1)) + x_feat = torch.cat([x_feat, uv_8], dim=1) + + normal = self.normal_head(x_feat) + normal = F.normalize(normal, dim=1) + f = self.feature_head(x_feat) + h = self.hidden_head(x_feat) + return normal, f, h + + +class DSINE(nn.Module): + def __init__(self): + super(DSINE, self).__init__() + self.downsample_ratio = 8 + self.ps = 5 # patch size + self.num_iter = 5 # num iterations + + # define encoder + self.encoder = Encoder(B=5, pretrained=True) + + # define decoder + self.output_dim = output_dim = 3 + self.feature_dim = feature_dim = 64 + self.hidden_dim = hidden_dim = 64 + self.decoder = Decoder([output_dim, feature_dim, hidden_dim], B=5, NF=2048, BN=False) + + # ray direction-based ReLU + self.ray_relu = RayReLU(eps=1e-2) + + # pixel_coords (1, 3, H, W) + # NOTE: this is set to some arbitrarily high number, + # if your input is 2000+ pixels wide/tall, increase these values + h = 2000 + w = 2000 + pixel_coords = np.ones((3, h, w)).astype(np.float32) + x_range = np.concatenate([np.arange(w).reshape(1, w)] * h, axis=0) + y_range = np.concatenate([np.arange(h).reshape(h, 1)] * w, axis=1) + pixel_coords[0, :, :] = x_range + 0.5 + pixel_coords[1, :, :] = y_range + 0.5 + self.pixel_coords = torch.from_numpy(pixel_coords).unsqueeze(0) + + # define ConvGRU cell + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=feature_dim+2, ks=self.ps) + + # padding used during NRN + self.pad = (self.ps - 1) // 2 + + # prediction heads + self.prob_head = get_prediction_head(self.hidden_dim+2, 64, self.ps*self.ps) # weights assigned for each nghbr pixel + self.xy_head = get_prediction_head(self.hidden_dim+2, 64, self.ps*self.ps*2) # rotation axis for each nghbr pixel + self.angle_head = get_prediction_head(self.hidden_dim+2, 64, self.ps*self.ps) # rotation angle 
for each nghbr pixel + + # prediction heads - weights used for upsampling the coarse resolution output + self.up_prob_head = get_prediction_head(self.hidden_dim+2, 64, 9 * self.downsample_ratio * self.downsample_ratio) + + def get_ray(self, intrins, H, W, orig_H, orig_W, return_uv=False): + B, _, _ = intrins.shape + fu = intrins[:, 0, 0][:,None,None] * (W / orig_W) + cu = intrins[:, 0, 2][:,None,None] * (W / orig_W) + fv = intrins[:, 1, 1][:,None,None] * (H / orig_H) + cv = intrins[:, 1, 2][:,None,None] * (H / orig_H) + + # (B, 2, H, W) + ray = self.pixel_coords[:, :, :H, :W].repeat(B, 1, 1, 1) + ray[:, 0, :, :] = (ray[:, 0, :, :] - cu) / fu + ray[:, 1, :, :] = (ray[:, 1, :, :] - cv) / fv + + if return_uv: + return ray[:, :2, :, :] + else: + return F.normalize(ray, dim=1) + + def upsample(self, h, pred_norm, uv_8): + up_mask = self.up_prob_head(torch.cat([h, uv_8], dim=1)) + up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio) + up_pred_norm = F.normalize(up_pred_norm, dim=1) + return up_pred_norm + + def refine(self, h, feat_map, pred_norm, intrins, orig_H, orig_W, uv_8, ray_8): + B, C, H, W = pred_norm.shape + fu = intrins[:, 0, 0][:,None,None,None] * (W / orig_W) # (B, 1, 1, 1) + cu = intrins[:, 0, 2][:,None,None,None] * (W / orig_W) + fv = intrins[:, 1, 1][:,None,None,None] * (H / orig_H) + cv = intrins[:, 1, 2][:,None,None,None] * (H / orig_H) + + h_new = self.gru(h, feat_map) + + # get nghbr prob (B, 1, ps*ps, h, w) + nghbr_prob = self.prob_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1) + nghbr_prob = torch.sigmoid(nghbr_prob) + + # get nghbr normals (B, 3, ps*ps, h, w) + nghbr_normals = get_unfold(pred_norm, ps=self.ps, pad=self.pad) + + # get nghbr xy (B, 2, ps*ps, h, w) + nghbr_xys = self.xy_head(torch.cat([h_new, uv_8], dim=1)) + nghbr_xs, nghbr_ys = torch.split(nghbr_xys, [self.ps*self.ps, self.ps*self.ps], dim=1) + nghbr_xys = torch.cat([nghbr_xs.unsqueeze(1), nghbr_ys.unsqueeze(1)], dim=1) + nghbr_xys = F.normalize(nghbr_xys, dim=1) + + # get nghbr theta (B, 1, ps*ps, h, w) + nghbr_angle = self.angle_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1) + nghbr_angle = torch.sigmoid(nghbr_angle) * np.pi + + # get nghbr pixel coord (1, 3, ps*ps, h, w) + nghbr_pixel_coord = get_unfold(self.pixel_coords[:, :, :H, :W], ps=self.ps, pad=self.pad) + + # nghbr axes (B, 3, ps*ps, h, w) + nghbr_axes = torch.zeros_like(nghbr_normals) + + du_over_fu = nghbr_xys[:, 0, ...] / fu # (B, ps*ps, h, w) + dv_over_fv = nghbr_xys[:, 1, ...] / fv # (B, ps*ps, h, w) + + term_u = (nghbr_pixel_coord[:, 0, ...] + nghbr_xys[:, 0, ...] - cu) / fu # (B, ps*ps, h, w) + term_v = (nghbr_pixel_coord[:, 1, ...] + nghbr_xys[:, 1, ...] - cv) / fv # (B, ps*ps, h, w) + + nx = nghbr_normals[:, 0, ...] # (B, ps*ps, h, w) + ny = nghbr_normals[:, 1, ...] # (B, ps*ps, h, w) + nz = nghbr_normals[:, 2, ...] # (B, ps*ps, h, w) + + nghbr_delta_z_num = - (du_over_fu * nx + dv_over_fv * ny) + nghbr_delta_z_denom = (term_u * nx + term_v * ny + nz) + nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8] = 1e-8 * torch.sign(nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8]) + nghbr_delta_z = nghbr_delta_z_num / nghbr_delta_z_denom + + nghbr_axes[:, 0, ...] = du_over_fu + nghbr_delta_z * term_u + nghbr_axes[:, 1, ...] = dv_over_fv + nghbr_delta_z * term_v + nghbr_axes[:, 2, ...] 
= nghbr_delta_z + nghbr_axes = F.normalize(nghbr_axes, dim=1) # (B, 3, ps*ps, h, w) + + # make sure axes are all valid + invalid = torch.sum(torch.logical_or(torch.isnan(nghbr_axes), torch.isinf(nghbr_axes)).float(), dim=1) > 0.5 # (B, ps*ps, h, w) + nghbr_axes[:, 0, ...][invalid] = 0.0 + nghbr_axes[:, 1, ...][invalid] = 0.0 + nghbr_axes[:, 2, ...][invalid] = 0.0 + + # nghbr_axes_angle (B, 3, ps*ps, h, w) + nghbr_axes_angle = nghbr_axes * nghbr_angle + nghbr_axes_angle = nghbr_axes_angle.permute(0, 2, 3, 4, 1) # (B, ps*ps, h, w, 3) + nghbr_R = axis_angle_to_matrix(nghbr_axes_angle) # (B, ps*ps, h, w, 3, 3) + + # (B, 3, ps*ps, h, w) + nghbr_normals_rot = torch.bmm( + nghbr_R.reshape(B * self.ps * self.ps * H * W, 3, 3), + nghbr_normals.permute(0, 2, 3, 4, 1).reshape(B * self.ps * self.ps * H * W, 3).unsqueeze(-1) + ).reshape(B, self.ps*self.ps, H, W, 3, 1).squeeze(-1).permute(0, 4, 1, 2, 3) # (B, 3, ps*ps, h, w) + nghbr_normals_rot = F.normalize(nghbr_normals_rot, dim=1) + + # ray ReLU + nghbr_normals_rot = torch.cat([ + self.ray_relu(nghbr_normals_rot[:, :, i, :, :], ray_8).unsqueeze(2) + for i in range(nghbr_normals_rot.size(2)) + ], dim=2) + + # (B, 1, ps*ps, h, w) * (B, 3, ps*ps, h, w) + pred_norm = torch.sum(nghbr_prob * nghbr_normals_rot, dim=2) # (B, C, H, W) + pred_norm = F.normalize(pred_norm, dim=1) + + up_mask = self.up_prob_head(torch.cat([h_new, uv_8], dim=1)) + up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio) + up_pred_norm = F.normalize(up_pred_norm, dim=1) + + return h_new, pred_norm, up_pred_norm + + + def forward(self, img, intrins=None): + # Step 1. encoder + features = self.encoder(img) + + # Step 2. get uv encoding + B, _, orig_H, orig_W = img.shape + intrins[:, 0, 2] += 0.5 + intrins[:, 1, 2] += 0.5 + uv_32 = self.get_ray(intrins, orig_H//32, orig_W//32, orig_H, orig_W, return_uv=True) + uv_16 = self.get_ray(intrins, orig_H//16, orig_W//16, orig_H, orig_W, return_uv=True) + uv_8 = self.get_ray(intrins, orig_H//8, orig_W//8, orig_H, orig_W, return_uv=True) + ray_8 = self.get_ray(intrins, orig_H//8, orig_W//8, orig_H, orig_W) + + # Step 3. decoder - initial prediction + pred_norm, feat_map, h = self.decoder(features, uvs=(uv_32, uv_16, uv_8)) + pred_norm = self.ray_relu(pred_norm, ray_8) + + # Step 4. add ray direction encoding + feat_map = torch.cat([feat_map, uv_8], dim=1) + + # iterative refinement + up_pred_norm = self.upsample(h, pred_norm, uv_8) + pred_list = [up_pred_norm] + for i in range(self.num_iter): + h, pred_norm, up_pred_norm = self.refine(h, feat_map, + pred_norm.detach(), + intrins, orig_H, orig_W, uv_8, ray_8) + pred_list.append(up_pred_norm) + return pred_list + + diff --git a/models/submodules.py b/models/submodules.py new file mode 100644 index 0000000..4e3afba --- /dev/null +++ b/models/submodules.py @@ -0,0 +1,194 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import geffnet + + +INPUT_CHANNELS_DICT = { + 0: [1280, 112, 40, 24, 16], + 1: [1280, 112, 40, 24, 16], + 2: [1408, 120, 48, 24, 16], + 3: [1536, 136, 48, 32, 24], + 4: [1792, 160, 56, 32, 24], + 5: [2048, 176, 64, 40, 24], + 6: [2304, 200, 72, 40, 32], + 7: [2560, 224, 80, 48, 32] +} + + +class Encoder(nn.Module): + def __init__(self, B=5, pretrained=True): + """ e.g. 
B=5 will return EfficientNet-B5 + """ + super(Encoder, self).__init__() + basemodel = geffnet.create_model('tf_efficientnet_b%s_ap' % B, pretrained=pretrained) + # Remove last layer + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + self.original_model = basemodel + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if (k == 'blocks'): + for ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim, input_dim, ks=3): + super(ConvGRU, self).__init__() + p = (ks - 1) // 2 + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, ks, padding=p) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, ks, padding=p) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, ks, padding=p) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + return h + + +class RayReLU(nn.Module): + def __init__(self, eps=1e-2): + super(RayReLU, self).__init__() + self.eps = eps + + def forward(self, pred_norm, ray): + # angle between the predicted normal and ray direction + cos = torch.cosine_similarity(pred_norm, ray, dim=1).unsqueeze(1) # (B, 1, H, W) + + # component of pred_norm along view + norm_along_view = ray * cos + + # cos should be bigger than eps + norm_along_view_relu = ray * (torch.relu(cos - self.eps) + self.eps) + + # difference + diff = norm_along_view_relu - norm_along_view + + # updated pred_norm + new_pred_norm = pred_norm + diff + new_pred_norm = F.normalize(new_pred_norm, dim=1) + + return new_pred_norm + + +class UpSampleBN(nn.Module): + def __init__(self, skip_input, output_features, align_corners=True): + super(UpSampleBN, self).__init__() + self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU(), + nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU()) + self.align_corners = align_corners + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=self.align_corners) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +class Conv2d_WS(nn.Conv2d): + """ weight standardization + """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2d_WS, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, x): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, + keepdim=True).mean(dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class UpSampleGN(nn.Module): + """ UpSample with GroupNorm + """ + def __init__(self, skip_input, output_features, align_corners=True): + super(UpSampleGN, self).__init__() + self._net = nn.Sequential(Conv2d_WS(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, output_features), + nn.LeakyReLU(), + 
Conv2d_WS(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, output_features), + nn.LeakyReLU()) + self.align_corners = align_corners + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=self.align_corners) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +def upsample_via_bilinear(out, up_mask, downsample_ratio): + """ bilinear upsampling (up_mask is a dummy variable) + """ + return F.interpolate(out, scale_factor=downsample_ratio, mode='bilinear', align_corners=True) + + +def upsample_via_mask(out, up_mask, downsample_ratio): + """ convex upsampling + """ + # out: low-resolution output (B, o_dim, H, W) + # up_mask: (B, 9*k*k, H, W) + k = downsample_ratio + + N, o_dim, H, W = out.shape + up_mask = up_mask.view(N, 1, 9, k, k, H, W) + up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W) + + up_out = F.unfold(out, [3, 3], padding=1) # (B, 2, H, W) -> (B, 2 X 3*3, H*W) + up_out = up_out.view(N, o_dim, 9, 1, 1, H, W) # (B, 2, 3*3, 1, 1, H, W) + up_out = torch.sum(up_mask * up_out, dim=2) # (B, 2, k, k, H, W) + + up_out = up_out.permute(0, 1, 4, 2, 5, 3) # (B, 2, H, k, W, k) + return up_out.reshape(N, o_dim, k*H, k*W) # (B, 2, kH, kW) + + +def convex_upsampling(out, up_mask, k): + # out: low-resolution output (B, C, H, W) + # up_mask: (B, 9*k*k, H, W) + B, C, H, W = out.shape + up_mask = up_mask.view(B, 1, 9, k, k, H, W) + up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W) + + out = F.pad(out, pad=(1,1,1,1), mode='replicate') + up_out = F.unfold(out, [3, 3], padding=0) # (B, C, H, W) -> (B, C X 3*3, H*W) + up_out = up_out.view(B, C, 9, 1, 1, H, W) # (B, C, 9, 1, 1, H, W) + + up_out = torch.sum(up_mask * up_out, dim=2) # (B, C, k, k, H, W) + up_out = up_out.permute(0, 1, 4, 2, 5, 3) # (B, C, H, k, W, k) + return up_out.reshape(B, C, k*H, k*W) # (B, C, kH, kW) + + +def get_unfold(pred_norm, ps, pad): + B, C, H, W = pred_norm.shape + pred_norm = F.pad(pred_norm, pad=(pad,pad,pad,pad), mode='replicate') # (B, C, h, w) + pred_norm_unfold = F.unfold(pred_norm, [ps, ps], padding=0) # (B, C X ps*ps, h*w) + pred_norm_unfold = pred_norm_unfold.view(B, C, ps*ps, H, W) # (B, C, ps*ps, h, w) + return pred_norm_unfold + + +def get_prediction_head(input_dim, hidden_dim, output_dim): + return nn.Sequential( + nn.Conv2d(input_dim, hidden_dim, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 1), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, output_dim, 1), + ) + diff --git a/samples/img/146391_DT.png b/samples/img/146391_DT.png new file mode 100755 index 0000000..c4b3e41 Binary files /dev/null and b/samples/img/146391_DT.png differ diff --git a/samples/img/169332_DT.png b/samples/img/169332_DT.png new file mode 100755 index 0000000..48a3759 Binary files /dev/null and b/samples/img/169332_DT.png differ diff --git a/samples/img/197990_DT.png b/samples/img/197990_DT.png new file mode 100755 index 0000000..142a959 Binary files /dev/null and b/samples/img/197990_DT.png differ diff --git a/samples/img/224596_DT.png b/samples/img/224596_DT.png new file mode 100755 index 0000000..1af3426 Binary files /dev/null and b/samples/img/224596_DT.png differ diff --git a/samples/img/228364_ALI.png b/samples/img/228364_ALI.png new file mode 100755 index 0000000..711a501 Binary files /dev/null and b/samples/img/228364_ALI.png differ diff --git a/samples/img/228466_ALI.png b/samples/img/228466_ALI.png new file mode 100755 index 
0000000..29cc26c Binary files /dev/null and b/samples/img/228466_ALI.png differ diff --git a/samples/img/232334_ALI.png b/samples/img/232334_ALI.png new file mode 100755 index 0000000..5fd24e5 Binary files /dev/null and b/samples/img/232334_ALI.png differ diff --git a/samples/img/office_01.png b/samples/img/office_01.png new file mode 100755 index 0000000..03ae674 Binary files /dev/null and b/samples/img/office_01.png differ diff --git a/samples/img/office_01.txt b/samples/img/office_01.txt new file mode 100755 index 0000000..41bae58 --- /dev/null +++ b/samples/img/office_01.txt @@ -0,0 +1 @@ +559.62,558.14,361.87,241.99 diff --git a/samples/img/office_02.png b/samples/img/office_02.png new file mode 100755 index 0000000..00d50e3 Binary files /dev/null and b/samples/img/office_02.png differ diff --git a/samples/img/office_02.txt b/samples/img/office_02.txt new file mode 100755 index 0000000..41bae58 --- /dev/null +++ b/samples/img/office_02.txt @@ -0,0 +1 @@ +559.62,558.14,361.87,241.99 diff --git a/samples/img/office_03.png b/samples/img/office_03.png new file mode 100755 index 0000000..39efb50 Binary files /dev/null and b/samples/img/office_03.png differ diff --git a/samples/img/office_03.txt b/samples/img/office_03.txt new file mode 100755 index 0000000..41bae58 --- /dev/null +++ b/samples/img/office_03.txt @@ -0,0 +1 @@ +559.62,558.14,361.87,241.99 diff --git a/samples/img/office_04.png b/samples/img/office_04.png new file mode 100755 index 0000000..60c6824 Binary files /dev/null and b/samples/img/office_04.png differ diff --git a/samples/img/office_04.txt b/samples/img/office_04.txt new file mode 100755 index 0000000..41bae58 --- /dev/null +++ b/samples/img/office_04.txt @@ -0,0 +1 @@ +559.62,558.14,361.87,241.99 diff --git a/samples/img/office_05.png b/samples/img/office_05.png new file mode 100755 index 0000000..2852f8b Binary files /dev/null and b/samples/img/office_05.png differ diff --git a/samples/img/office_05.txt b/samples/img/office_05.txt new file mode 100755 index 0000000..41bae58 --- /dev/null +++ b/samples/img/office_05.txt @@ -0,0 +1 @@ +559.62,558.14,361.87,241.99 diff --git a/samples/img/office_06.png b/samples/img/office_06.png new file mode 100755 index 0000000..65a97d0 Binary files /dev/null and b/samples/img/office_06.png differ diff --git a/samples/img/office_06.txt b/samples/img/office_06.txt new file mode 100755 index 0000000..41bae58 --- /dev/null +++ b/samples/img/office_06.txt @@ -0,0 +1 @@ +559.62,558.14,361.87,241.99 diff --git a/samples/img/office_07.png b/samples/img/office_07.png new file mode 100755 index 0000000..b605555 Binary files /dev/null and b/samples/img/office_07.png differ diff --git a/samples/img/office_07.txt b/samples/img/office_07.txt new file mode 100755 index 0000000..41bae58 --- /dev/null +++ b/samples/img/office_07.txt @@ -0,0 +1 @@ +559.62,558.14,361.87,241.99 diff --git a/samples/img/office_08.png b/samples/img/office_08.png new file mode 100755 index 0000000..3a55946 Binary files /dev/null and b/samples/img/office_08.png differ diff --git a/samples/img/office_08.txt b/samples/img/office_08.txt new file mode 100755 index 0000000..41bae58 --- /dev/null +++ b/samples/img/office_08.txt @@ -0,0 +1 @@ +559.62,558.14,361.87,241.99 diff --git a/samples/output/146391_DT.png b/samples/output/146391_DT.png new file mode 100644 index 0000000..e397d9a Binary files /dev/null and b/samples/output/146391_DT.png differ diff --git a/samples/output/169332_DT.png b/samples/output/169332_DT.png new file mode 100644 index 0000000..8e1d11b Binary files 
/dev/null and b/samples/output/169332_DT.png differ diff --git a/samples/output/197990_DT.png b/samples/output/197990_DT.png new file mode 100644 index 0000000..20d54af Binary files /dev/null and b/samples/output/197990_DT.png differ diff --git a/samples/output/224596_DT.png b/samples/output/224596_DT.png new file mode 100644 index 0000000..81c6863 Binary files /dev/null and b/samples/output/224596_DT.png differ diff --git a/samples/output/228364_ALI.png b/samples/output/228364_ALI.png new file mode 100644 index 0000000..cfd31a1 Binary files /dev/null and b/samples/output/228364_ALI.png differ diff --git a/samples/output/228466_ALI.png b/samples/output/228466_ALI.png new file mode 100644 index 0000000..a0449eb Binary files /dev/null and b/samples/output/228466_ALI.png differ diff --git a/samples/output/232334_ALI.png b/samples/output/232334_ALI.png new file mode 100644 index 0000000..2d100ae Binary files /dev/null and b/samples/output/232334_ALI.png differ diff --git a/samples/output/office_01.png b/samples/output/office_01.png new file mode 100644 index 0000000..279b366 Binary files /dev/null and b/samples/output/office_01.png differ diff --git a/samples/output/office_02.png b/samples/output/office_02.png new file mode 100644 index 0000000..6dfc4b5 Binary files /dev/null and b/samples/output/office_02.png differ diff --git a/samples/output/office_03.png b/samples/output/office_03.png new file mode 100644 index 0000000..271fb2a Binary files /dev/null and b/samples/output/office_03.png differ diff --git a/samples/output/office_04.png b/samples/output/office_04.png new file mode 100644 index 0000000..7937dba Binary files /dev/null and b/samples/output/office_04.png differ diff --git a/samples/output/office_05.png b/samples/output/office_05.png new file mode 100644 index 0000000..0ed20bf Binary files /dev/null and b/samples/output/office_05.png differ diff --git a/samples/output/office_06.png b/samples/output/office_06.png new file mode 100644 index 0000000..db90125 Binary files /dev/null and b/samples/output/office_06.png differ diff --git a/samples/output/office_07.png b/samples/output/office_07.png new file mode 100644 index 0000000..a4d1261 Binary files /dev/null and b/samples/output/office_07.png differ diff --git a/samples/output/office_08.png b/samples/output/office_08.png new file mode 100644 index 0000000..15f623c Binary files /dev/null and b/samples/output/office_08.png differ diff --git a/test.py b/test.py new file mode 100644 index 0000000..9e0e383 --- /dev/null +++ b/test.py @@ -0,0 +1,77 @@ +import os +import sys +import glob +import argparse +import numpy as np + +import torch +import torch.nn.functional as F +from torchvision import transforms +from PIL import Image +import utils.utils as utils + + +def test_samples(args, model, intrins=None, device='cpu'): + img_paths = glob.glob('./samples/img/*.png') + glob.glob('./samples/img/*.jpg') + img_paths.sort() + + # normalize + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + with torch.no_grad(): + for img_path in img_paths: + print(img_path) + ext = os.path.splitext(img_path)[1] + img = Image.open(img_path).convert('RGB') + img = np.array(img).astype(np.float32) / 255.0 + img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device) + _, _, orig_H, orig_W = img.shape + + # zero-pad the input image so that both the width and height are multiples of 32 + l, r, t, b = utils.pad_input(orig_H, orig_W) + img = F.pad(img, (l, r, t, b), mode="constant", value=0.0) + img = normalize(img) 
+ + intrins_path = img_path.replace(ext, '.txt') + if os.path.exists(intrins_path): + # NOTE: camera intrinsics should be given as a txt file + # it should contain the values of fx, fy, cx, cy + intrins = utils.get_intrins_from_txt(intrins_path, device=device).unsqueeze(0) + else: + # NOTE: if intrins is not given, we just assume that the principal point is at the center + # and that the field-of-view is 60 degrees (feel free to modify this assumption) + intrins = utils.get_intrins_from_fov(new_fov=60.0, H=orig_H, W=orig_W, device=device).unsqueeze(0) + + intrins[:, 0, 2] += l + intrins[:, 1, 2] += t + + pred_norm = model(img, intrins=intrins)[-1] + pred_norm = pred_norm[:, :, t:t+orig_H, l:l+orig_W] + + # save to output folder + # NOTE: by saving the prediction as uint8 png format, you lose a lot of precision + # if you want to use the predicted normals for downstream tasks, we recommend saving them as float32 NPY files + pred_norm_np = pred_norm.cpu().detach().numpy()[0,:,:,:].transpose(1, 2, 0) # (H, W, 3) + pred_norm_np = ((pred_norm_np + 1.0) / 2.0 * 255.0).astype(np.uint8) + target_path = img_path.replace('/img/', '/output/').replace(ext, '.png') + im = Image.fromarray(pred_norm_np) + im.save(target_path) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt', default='dsine', type=str, help='model checkpoint') + parser.add_argument('--mode', default='samples', type=str, help='{samples}') + args = parser.parse_args() + + # define model + device = torch.device('cuda') + + from models.dsine import DSINE + model = DSINE().to(device) + model.pixel_coords = model.pixel_coords.to(device) + model = utils.load_checkpoint('./checkpoints/%s.pt' % args.ckpt, model) + model.eval() + + if args.mode == 'samples': + test_samples(args, model, intrins=None, device=device) diff --git a/utils/rotation.py b/utils/rotation.py new file mode 100644 index 0000000..8123f90 --- /dev/null +++ b/utils/rotation.py @@ -0,0 +1,85 @@ +import torch +import numpy as np + + +# NOTE: from PyTorch3D +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +# NOTE: from PyTorch3D +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
+ """ + r, i, j, k = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +# NOTE: from PyTorch3D +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..61cd4cc --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,105 @@ +""" utils +""" +import os +import torch +import numpy as np + + +def load_checkpoint(fpath, model): + print('loading checkpoint... {}'.format(fpath)) + + ckpt = torch.load(fpath, map_location='cpu')['model'] + + load_dict = {} + for k, v in ckpt.items(): + if k.startswith('module.'): + k_ = k.replace('module.', '') + load_dict[k_] = v + else: + load_dict[k] = v + + model.load_state_dict(load_dict) + print('loading checkpoint... / done') + return model + + +def compute_normal_error(pred_norm, gt_norm): + pred_error = torch.cosine_similarity(pred_norm, gt_norm, dim=1) + pred_error = torch.clamp(pred_error, min=-1.0, max=1.0) + pred_error = torch.acos(pred_error) * 180.0 / np.pi + pred_error = pred_error.unsqueeze(1) # (B, 1, H, W) + return pred_error + + +def compute_normal_metrics(total_normal_errors): + total_normal_errors = total_normal_errors.detach().cpu().numpy() + num_pixels = total_normal_errors.shape[0] + + metrics = { + 'mean': np.average(total_normal_errors), + 'median': np.median(total_normal_errors), + 'rmse': np.sqrt(np.sum(total_normal_errors * total_normal_errors) / num_pixels), + 'a1': 100.0 * (np.sum(total_normal_errors < 5) / num_pixels), + 'a2': 100.0 * (np.sum(total_normal_errors < 7.5) / num_pixels), + 'a3': 100.0 * (np.sum(total_normal_errors < 11.25) / num_pixels), + 'a4': 100.0 * (np.sum(total_normal_errors < 22.5) / num_pixels), + 'a5': 100.0 * (np.sum(total_normal_errors < 30) / num_pixels) + } + + return metrics + + +def pad_input(orig_H, orig_W): + if orig_W % 32 == 0: + l = 0 + r = 0 + else: + new_W = 32 * ((orig_W // 32) + 1) + l = (new_W - orig_W) // 2 + r = (new_W - orig_W) - l + + if orig_H % 32 == 0: + t = 0 + b = 0 + else: + new_H = 32 * ((orig_H // 32) + 1) + t = (new_H - orig_H) // 2 + b = (new_H - orig_H) - t + return l, r, t, b + + +def get_intrins_from_fov(new_fov, H, W, device): + # NOTE: top-left pixel should be (0,0) + if W >= H: + new_fu = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) + new_fv = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) + else: + new_fu = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) + new_fv = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) + + new_cu = (W / 2.0) - 0.5 + new_cv = (H / 2.0) - 0.5 + + new_intrins = torch.tensor([ + [new_fu, 0, new_cu ], + [0, new_fv, new_cv ], + [0, 0, 1 ] + ], dtype=torch.float32, device=device) + + return 
new_intrins + + +def get_intrins_from_txt(intrins_path, device): + # NOTE: top-left pixel should be (0,0) + with open(intrins_path, 'r') as f: + intrins_ = f.readlines()[0].split()[0].split(',') + intrins_ = [float(i) for i in intrins_] + fx, fy, cx, cy = intrins_ + + intrins = torch.tensor([ + [fx, 0,cx], + [ 0,fy,cy], + [ 0, 0, 1] + ], dtype=torch.float32, device=device) + + return intrins \ No newline at end of file
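A closing usage note: `test.py` warns that saving predictions as uint8 PNGs loses precision and recommends float32 NPY files for downstream tasks. If you nevertheless work from the PNGs under `./samples/output/`, the encoding can be inverted as in the sketch below (a hypothetical helper, not part of the repository), which maps each channel from [0, 255] back to [-1, 1] and re-normalizes:

```
import numpy as np
from PIL import Image

def decode_normal_png(path: str) -> np.ndarray:
    """Recover approximate unit normals from a PNG written by test.py.

    test.py stores each normal component n in [-1, 1] as uint8 via
    (n + 1) / 2 * 255, so we invert that mapping and re-normalize.
    """
    rgb = np.asarray(Image.open(path), dtype=np.float32)  # (H, W, 3) in [0, 255]
    n = rgb / 255.0 * 2.0 - 1.0
    n /= np.maximum(np.linalg.norm(n, axis=-1, keepdims=True), 1e-8)
    return n

normals = decode_normal_png('./samples/output/office_01.png')  # (H, W, 3), unit vectors
```

The re-normalization matters because uint8 quantization leaves the decoded vectors slightly off unit length.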