-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdprapi.py
116 lines (94 loc) · 3.68 KB
/
dprapi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
'''
this is a simple test file
'''
import sys
sys.path.append('model')
sys.path.append('utils')
from utils_SH import *
# other modules
import os
import numpy as np
from torch.autograd import Variable
from torchvision.utils import make_grid
import torch
import time
import cv2
from defineHourglass_1024_gray_skip_matchFeature import *
class DPRAPI:
    """Estimate spherical-harmonics (SH) lighting for a folder of images using
    the pretrained ``HourglassNet_1024`` network.

    Typical use::

        api = DPRAPI('path/to/images', 'path/to/output')
        api.predict_light()   # writes <output_dir>/lights.npy

    Requires a CUDA device: the network and all of its inputs are moved to the
    GPU with ``.cuda()``.
    """

    # Default checkpoint and target-light files (relative to the CWD), kept
    # identical to the paths this class has always used.
    DEFAULT_MODEL_PATH = 'trained_model/trained_model_1024_03.t7'
    DEFAULT_LIGHT_PATH = 'data/example_light/rotate_light_00.txt'

    # Side length of the rendered SH preview sphere.
    SPHERE_SIZE = 256
    # Images are resized to INPUT_SIZE x INPUT_SIZE before inference.
    INPUT_SIZE = 1024

    def __init__(
        self,
        images_dir,
        output_dir,
        model_path=DEFAULT_MODEL_PATH,
        light_path=DEFAULT_LIGHT_PATH,
    ):
        """Precompute half-sphere normals, load the network and create
        ``output_dir``.

        Args:
            images_dir: Directory whose images ``predict_light`` processes.
            output_dir: Directory that receives ``lights.npy`` and the
                ``light_00.png`` preview. It is created if missing.
            model_path: State-dict checkpoint loaded into ``HourglassNet_1024``.
            light_path: Text file of SH coefficients. The first 9 are used,
                scaled by 0.7, as the target lighting passed to the network
                and rendered as the preview sphere.
        """
        self.images_dir = images_dir
        self.output_dir = output_dir
        self.light_path = light_path

        # ---- normals of a half sphere, used to render SH preview images ----
        size = self.SPHERE_SIZE
        x = np.linspace(-1, 1, size)
        z = np.linspace(1, -1, size)
        x, z = np.meshgrid(x, z)
        # Boolean mask of pixels inside the unit disc; everything outside is
        # zeroed so the sqrt below never sees a negative argument.
        self.valid = np.sqrt(x**2 + z**2) <= 1
        y = -np.sqrt(1 - (x * self.valid)**2 - (z * self.valid)**2)
        x = x * self.valid
        y = y * self.valid
        z = z * self.valid
        # (size, size, 3) -> (size*size, 3): one normal per pixel.
        self.normal = np.stack((x, y, z), axis=2).reshape(-1, 3)

        # ---- network ----
        my_network_512 = HourglassNet(16)
        my_network = HourglassNet_1024(my_network_512, 16)
        my_network.load_state_dict(torch.load(model_path))
        my_network.cuda()
        my_network.train(False)  # inference mode
        self.my_network = my_network

        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.output_dir, exist_ok=True)

    def predict_light(self):
        """Predict SH lighting for every image in ``images_dir`` and save the
        stacked coefficients to ``<output_dir>/lights.npy``.

        Only regular files are processed, and names starting with '.' or '_'
        are skipped. Files are handled in sorted name order, so row ``k`` of
        the saved array belongs to the ``k``-th processed file name. The saved
        array has shape ``(N, 1, 9, 1, 1)``.
        """
        images = sorted(
            f for f in os.listdir(self.images_dir)
            if f[0] not in '._'
            and os.path.isfile(os.path.join(self.images_dir, f))
        )
        lights = np.zeros([len(images), 1, 9, 1, 1])
        for i, img_name in enumerate(images):
            raw_img_path = os.path.join(self.images_dir, img_name)
            print('predicting light for image: ', raw_img_path)
            sh = self.predict_light_on_image(raw_img_path)
            # NOTE(review): assumes outputSH fits a (1, 9, 1, 1) slot.
            lights[i] = sh.detach().cpu().numpy()
        np.save(os.path.join(self.output_dir, 'lights.npy'), lights)

    def predict_light_on_image(self, image_url):
        """Estimate the SH lighting of a single image.

        Side effect: the target light read from ``self.light_path`` is
        rendered onto a half sphere and written to
        ``<output_dir>/light_00.png``. Every call writes the same file.

        Args:
            image_url: Local path of an image readable by ``cv2.imread``.

        Returns:
            The network's ``outputSH`` tensor, which lives on the GPU. It is
            computed under ``torch.no_grad()``, so it carries no autograd graph.

        Raises:
            OSError: if ``cv2.imread`` cannot read ``image_url``.
        """
        input_l = self._load_luminance(image_url)

        # Target lighting fed to the network: first 9 SH coefficients * 0.7.
        sh = np.squeeze(np.loadtxt(self.light_path)[0:9] * 0.7)
        self._render_light_sphere(
            sh, os.path.join(self.output_dir, 'light_{:02d}.png'.format(0)))

        target_sh = torch.from_numpy(
            np.reshape(sh, (1, 9, 1, 1)).astype(np.float32)).cuda()
        # Pure inference: no_grad avoids building and holding an autograd graph
        # for every image. The trailing 0 is passed through unchanged.
        # NOTE(review): presumably a skip count; confirm in HourglassNet_1024.
        with torch.no_grad():
            _, _, output_sh, _ = self.my_network(input_l, target_sh, 0)
        return output_sh

    def _load_luminance(self, image_url):
        """Read an image and return its Lab L channel, scaled to [0, 1], as a
        ``(1, 1, INPUT_SIZE, INPUT_SIZE)`` float32 CUDA tensor."""
        img = cv2.imread(image_url)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError('could not read image: {}'.format(image_url))
        img = cv2.resize(img, (self.INPUT_SIZE, self.INPUT_SIZE))
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        luminance = lab[:, :, 0].astype(np.float32) / 255.0
        # Add batch and channel axes: (H, W) -> (1, 1, H, W).
        return torch.from_numpy(luminance[None, None, ...]).cuda()

    def _render_light_sphere(self, sh, path):
        """Shade the precomputed half-sphere normals with SH coefficients
        ``sh`` and write the 8-bit preview image to ``path``."""
        shading = get_shading(self.normal, sh)
        # Clip everything above the 95th percentile so a few highlights do
        # not dominate the normalisation below.
        value = np.percentile(shading, 95)
        shading[shading > value] = value
        low = np.min(shading)
        span = np.max(shading) - low
        # Guard against a constant shading, which would otherwise divide by 0.
        shading = (shading - low) / span if span > 0 else np.zeros_like(shading)
        shading = (shading * 255.0).astype(np.uint8)
        shading = np.reshape(shading, (self.SPHERE_SIZE, self.SPHERE_SIZE))
        shading = shading * self.valid  # black outside the disc
        cv2.imwrite(path, shading)