prepare_data.py
import os
import json
import difflib
from collections import defaultdict

import h5py
import numpy as np
import pandas as pd
import trimesh
from tqdm import tqdm

# Required files:
#   put the model data (.glb/.gltf) in WORKING_DIR
#   model.csv   - contains all model_ids
#   split.txt   - contains the train/valid/test splits
#   part_index  - contains (model_id, part) pairs
#   parts.json  - list of all part names
# Output: one <split>.hdf5 file per split, written to the current directory.
meta_dir = "../../metadata/"
WORKING_DIR = "../BPNet/dataset/processed_models_v5/"  # TODO: change to the glb folders
# number of points sampled from each model's combined mesh
N_POINTS_PER_PART = 5000
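# Each split.txt line is expected to be "<model_id>,<split_label>", matching
# the parsing below, e.g. (hypothetical id):
#   1a2b3c4d,train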
def pc_normalize(pc):
    # Center the bounding box at the origin, then rescale (in place) so the
    # farthest point lies on the unit sphere (radius 1).
    pmin = np.min(pc, axis=0)
    pmax = np.max(pc, axis=0)
    pc -= (pmin + pmax) / 2
    scale = np.max(np.linalg.norm(pc, axis=1))
    pc *= 1.0 / scale
    return pc
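# Quick illustration of pc_normalize (values worked out by hand):
#   pc_normalize(np.array([[0., 0., 0.], [2., 0., 0.]]))
#   -> [[-1., 0., 0.], [1., 0., 0.]]  (centered, max norm == 1)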
# read all model ids
models = pd.read_csv(os.path.join(meta_dir, 'model.csv'))
model_ids = defaultdict(list)
all_ids = []
# read split file: each line is "<model_id>,<split_label>"
with open(os.path.join('./', "split.txt"), "r") as f:
    for line in f:
        ids, label = line.rstrip().split(',')
        model_ids[label].append(ids)
        all_ids.append(ids)
data_paths_train = list(set(model_ids['train']))
data_paths_valid = list(set(model_ids['valid']))
data_paths_test = list(set(model_ids['test']))
# all data
dd = data_paths_train + data_paths_valid + data_paths_test
with open(os.path.join(meta_dir, 'parts.json')) as f:
    _ALL_PARTS = json.load(f)
# part name -> index mapping
classes = dict(zip(_ALL_PARTS, range(len(_ALL_PARTS))))
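# e.g. if parts.json holds ["arm", "back", "seat"] (illustrative names),
# then classes == {"arm": 0, "back": 1, "seat": 2}.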
def save_points(split):
    sample_xyzs = []
    sample_colorss = []
    sample_segments = []
    valid_ids = []
    data_paths = list(set(model_ids[split]))
    for model_id in tqdm(data_paths):
        # locate the model file (.glb preferred, .gltf as fallback)
        gltf_path = '{}/{}.glb'.format(WORKING_DIR, model_id)
        if not os.path.exists(gltf_path):
            gltf_path = '{}/{}.gltf'.format(WORKING_DIR, model_id)
            if not os.path.exists(gltf_path):
                print("model not found: {}".format(model_id))
                continue
        try:
            mesh = trimesh.load(gltf_path)
        except Exception as e:
            print("error reading {}: {}".format(model_id, e))
            continue
        if len(mesh.geometry.items()) <= 1:
            # a single geometry carries no per-part segmentation
            print("skipping {}: not enough part geometries".format(model_id))
            continue
        v = []
        segment = []
        for g_name, g_mesh in mesh.geometry.items():
            g_name = g_name.lower()
            if g_name in classes:
                # geometry name matches a defined part exactly
                part_name = g_name
            else:
                # otherwise try the prefix before the first underscore,
                # then fall back to fuzzy matching
                part_name = g_name.split('_')[0]
                if part_name not in classes:
                    matches = difflib.get_close_matches(g_name, list(classes.keys()))
                    if not matches:
                        print("no part match for '{}' in {}".format(g_name, model_id))
                        continue
                    part_name = matches[0]
            # collect the part mesh
            v.append(g_mesh)
            # one segmentation label per face
            segment.append(np.full(g_mesh.faces.shape[0], classes[part_name]))
        combined = trimesh.util.concatenate(v)
        sample_xyz, sample_id = trimesh.sample.sample_surface(combined, count=N_POINTS_PER_PART)
        sample_xyz = pc_normalize(sample_xyz)
        # no styled/textured models here, so color info is set to zero
        sample_colors = np.zeros_like(sample_xyz)
        # sample_id holds the face index of each sampled point, so the
        # per-face labels carry over to the sampled points
        sample_segment = np.concatenate(segment)[sample_id]
        sample_xyzs.append(sample_xyz)
        sample_colorss.append(sample_colors)
        sample_segments.append(sample_segment)
        valid_ids.append(model_id)
    x = np.stack(sample_xyzs)
    y = np.stack(sample_colorss)
    z = np.stack(sample_segments)
    asciiList = [(n.split('/')[-1]).encode("ascii", "ignore") for n in valid_ids]
    print(x.shape)
    print(y.shape)
    print(z.shape)
    print(len(asciiList))
    num_data = len(x)  # number of models that survived all checks
    with h5py.File('{}.hdf5'.format(split), 'w') as hf:
        hf.create_dataset('pc', data=x, shape=(num_data, N_POINTS_PER_PART, 3), compression='gzip', chunks=True)
        hf.create_dataset('color', data=y, shape=(num_data, N_POINTS_PER_PART, 3), compression='gzip', chunks=True)
        hf.create_dataset('seg', data=z, shape=(num_data, N_POINTS_PER_PART), compression='gzip', chunks=True)
        hf.create_dataset('id', data=asciiList, shape=(num_data, 1), compression='gzip', chunks=True)
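# Note on the fuzzy fallback above: difflib.get_close_matches compares by
# similarity ratio (default cutoff 0.6), so e.g. (illustrative names)
#   difflib.get_close_matches("seat_1", ["seat", "back", "arm"]) -> ["seat"]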
save_points('train')
save_points('valid')
# save_points('test')
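# Reading a generated file back (minimal sketch; dataset names match those
# written in save_points):
#   with h5py.File('train.hdf5', 'r') as hf:
#       pc = hf['pc'][:]        # (num_data, N_POINTS_PER_PART, 3) coordinates
#       color = hf['color'][:]  # per-point colors (all zeros here)
#       seg = hf['seg'][:]      # per-point part-index labels
#       ids = hf['id'][:]       # model ids as ascii bytes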