load_data.py
import os
import cv2
import json
import torch
import numpy
import imageio

# Helper matrices for building spherical camera-to-world poses:
# translate along z, rotate about the x axis by phi, rotate about the y axis by theta.
trans_t = lambda t : torch.Tensor([
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,t],
    [0,0,0,1]]).float()

rot_phi = lambda phi : torch.Tensor([
    [1,0,0,0],
    [0,numpy.cos(phi),-numpy.sin(phi),0],
    [0,numpy.sin(phi),numpy.cos(phi),0],
    [0,0,0,1]]).float()

rot_theta = lambda th : torch.Tensor([
    [numpy.cos(th),0,-numpy.sin(th),0],
    [0,1,0,0],
    [numpy.sin(th),0,numpy.cos(th),0],
    [0,0,0,1]]).float()


def pose_spherical(theta, phi, radius):
    # Camera-to-world pose on a sphere of the given radius, looking back at the origin.
    c2w = trans_t(radius)
    c2w = rot_phi(phi/180.*numpy.pi) @ c2w
    c2w = rot_theta(theta/180.*numpy.pi) @ c2w
    # Fixed axis permutation/flip so the pose matches the dataset's world convention.
    c2w = torch.Tensor(numpy.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
    return c2w
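
# Example (illustrative, not part of the original file): pose_spherical(0., -30., 4.)
# places the camera on a sphere of radius 4, tilted 30 degrees out of the equatorial
# plane and looking at the origin; sweeping theta over [-180, 180) traces the circular
# render path used for render_poses in the loaders below.
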

def load_blender_data(basedir, half_res=False, testskip=1):
    # Load the Blender-synthetic split files transforms_{train,val,test}.json.
    splits = ['train','val','test']
    metas = {}
    for s in splits:
        with open(os.path.join(basedir,'transforms_{}.json'.format(s)),'r') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        # Subsample val/test frames by `testskip`; always keep every training frame.
        if s=='train' or testskip==0:
            skip = 1
        else:
            skip = testskip

        for frame in meta['frames'][::skip]:
            fname = os.path.join(basedir,frame['file_path'] + '.png')
            imgs.append(imageio.imread(fname))
            poses.append(numpy.array(frame['transform_matrix']))
        imgs = (numpy.array(imgs) / 255.0).astype(numpy.float32)  # keep all 4 channels (RGBA)
        poses = numpy.array(poses).astype(numpy.float32)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    # Index ranges of the train/val/test frames within the concatenated arrays.
    i_split = [numpy.arange(counts[i],counts[i+1]) for i in range(3)]

    imgs = numpy.concatenate(all_imgs,0)
    poses = numpy.concatenate(all_poses,0)

    # Focal length from the horizontal field of view.
    H,W = imgs[0].shape[:2]
    camera_angle_x = float(meta['camera_angle_x'])
    focal = 0.5 * W / numpy.tan(0.5 * camera_angle_x)

    # 40 evenly spaced poses on a circle for rendering a turntable video.
    render_poses = torch.stack([pose_spherical(angle,-30.0,4.0) for angle in numpy.linspace(-180,180,40+1)[:-1]],0)

    if half_res:
        # Downsample the images and the intrinsics by a factor of two.
        H = H//2
        W = W//2
        focal = focal/2.0
        imgs_half_res = numpy.zeros((imgs.shape[0],H,W,4))
        for i,img in enumerate(imgs):
            imgs_half_res[i] = cv2.resize(img,(W,H),interpolation=cv2.INTER_AREA)
        imgs = imgs_half_res

    return imgs, poses, render_poses, [H,W,focal], i_split
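
# Usage sketch (the path is illustrative; it assumes a standard Blender-synthetic layout
# with transforms_*.json and PNG frames under basedir):
#   imgs, poses, render_poses, hwf, i_split = load_blender_data(
#       './data/nerf_synthetic/lego', half_res=True, testskip=8)
#   imgs: (N, H, W, 4) float32 in [0, 1]; poses: (N, 4, 4); hwf = [H, W, focal].
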

def load_linemod_data(basedir, half_res=False, testskip=1):
    # Load the LINEMOD-style split files transforms_{train,val,test}.json.
    splits = ['train','val','test']
    metas = {}
    for s in splits:
        with open(os.path.join(basedir,'transforms_{}.json'.format(s)),'r') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        # Subsample val/test frames by `testskip`; always keep every training frame.
        if s=='train' or testskip==0:
            skip = 1
        else:
            skip = testskip

        for idx_test,frame in enumerate(meta['frames'][::skip]):
            # Frame paths are stored as-is in the metadata (no extension is appended here).
            fname = frame['file_path']
            if s == 'test':
                print(f"{idx_test}th test frame: {fname}")
            imgs.append(imageio.imread(fname))
            poses.append(numpy.array(frame['transform_matrix']))
        imgs = (numpy.array(imgs) / 255.0).astype(numpy.float32)  # normalize to [0, 1]; frames are expected to be RGB here (see the 3-channel half-res buffer below)
        poses = numpy.array(poses).astype(numpy.float32)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    # Index ranges of the train/val/test frames within the concatenated arrays.
    i_split = [numpy.arange(counts[i],counts[i+1]) for i in range(3)]

    imgs = numpy.concatenate(all_imgs,0)
    poses = numpy.concatenate(all_poses,0)

    # Focal length and intrinsics come straight from the per-frame intrinsic matrix.
    H,W = imgs[0].shape[:2]
    focal = float(meta['frames'][0]['intrinsic_matrix'][0][0])
    K = meta['frames'][0]['intrinsic_matrix']
    print(f"Focal: {focal}")

    # 40 evenly spaced poses on a circle for rendering a turntable video.
    render_poses = torch.stack([pose_spherical(angle,-30.0,4.0) for angle in numpy.linspace(-180,180,40+1)[:-1]],0)

    if half_res:
        # Downsample the images and the intrinsics by a factor of two.
        H = H//2
        W = W//2
        focal = focal/2.0
        imgs_half_res = numpy.zeros((imgs.shape[0],H,W,3))
        for i,img in enumerate(imgs):
            imgs_half_res[i] = cv2.resize(img,(W,H),interpolation=cv2.INTER_AREA)
        imgs = imgs_half_res

    # Scene bounds taken from the train/test metadata, rounded outward to integers.
    near = numpy.floor(min(metas['train']['near'], metas['test']['near']))
    far = numpy.ceil(max(metas['train']['far'], metas['test']['far']))
    return imgs, poses, render_poses, [H,W,focal], K, i_split, near, far
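

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original loaders). The directory below
    # is a placeholder; point it at a LINEMOD-style dataset whose transforms_*.json files
    # carry 'near'/'far' keys and per-frame 'intrinsic_matrix' entries before running.
    basedir = './data/LINEMOD/benchvise'  # hypothetical path
    imgs, poses, render_poses, hwf, K, i_split, near, far = load_linemod_data(
        basedir, half_res=True, testskip=8)
    print('imgs:', imgs.shape, 'poses:', poses.shape)
    print('H, W, focal:', hwf, 'near/far:', near, far)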