-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_zf.py
131 lines (115 loc) · 5.75 KB
/
test_zf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import numpy as np
from simple_3dviz import Mesh,Scene
# from simple_3dviz.window import show
from simple_3dviz.utils import render
from simple_3dviz.behaviours.io import SaveFrames
from simple_3dviz.behaviours.movements import CameraTrajectory
from simple_3dviz.behaviours.trajectory import Circle
from simple_3dviz.utils import save_frame
from simple_3dviz import Lines
import argparse
from zf_utils import query_arrays,make_init_dict, get_networks, voxel_save
import clip
from networks import autoencoder, latent_flows
import torch
from train_post_clip import get_clip_model
from utils import visualization
import tqdm
import os
from test_post_clip import id_to_sub_category
def save_voxels(voxels, out_path, query_array=None, thresh=None):
    """Render each voxel grid in *voxels* to a PNG image under *out_path*.

    Args:
        voxels: array of shape (N, D, H, W); cast to boolean occupancy before rendering.
        out_path: output directory; created (including parents) if missing.
        query_array: optional list of N prompt strings; when given (together with
            *thresh*) each image is named "<prompt>_<thresh>.png" instead of "<i>.png".
            Added because the call site passes these extra arguments.
        thresh: optional threshold value used in the filename (see above).
    """
    # BUG FIX: np.bool8 is deprecated (removed in NumPy 2.0); plain `bool` is the
    # supported spelling.
    voxels = voxels.astype(bool)
    # exist_ok avoids the check-then-create race of the original exists()/makedirs pair.
    os.makedirs(out_path, exist_ok=True)
    print('image out path:', out_path)
    for i in range(voxels.shape[0]):
        # BUG FIX: the original passed the literal template "%i.png" without ever
        # formatting in `i`, so every frame overwrote the same file.
        if query_array is not None and thresh is not None:
            fname = "%s_%s.png" % (query_array[i], thresh)
        else:
            fname = "%i.png" % i
        voxel_save(voxels[i], None, out_file=os.path.join(out_path, fname))
def get_networks(checkpoint_dir, init, zero_conv, args, iter=15000):
    """Construct the autoencoder / latent-flow / CLIP triple and load checkpoints.

    NOTE(review): this definition shadows the ``get_networks`` imported from
    ``zf_utils`` at the top of the file — presumably intentional; confirm.

    Args:
        checkpoint_dir: directory containing "flow_model_<iter>.pt" and
            "aencoder_<iter>.pt".
        init: key into ``make_init_dict()`` selecting architecture hyperparameters.
        zero_conv: kept for interface compatibility; the zero-conv decoder wrapping
            was explicitly disabled in the original (`if False and zero_conv:`).
        args: namespace that is mutated in place with the architecture settings.
        iter: checkpoint iteration number embedded in the file names.

    Returns:
        (net, latent_flow_network, clip_model), with net in eval mode.
    """
    init_dict = make_init_dict()[init]
    # BUG FIX: args.emb_dims was assigned twice with the same value; once is enough.
    args.emb_dims = init_dict["emb_dim"]
    args.num_blocks = init_dict["num_blocks"]
    args.num_hidden = init_dict["num_hidden"]
    args.encoder_type = "Voxel_Encoder_BN"
    args.decoder_type = "Occ_Simple_Decoder"
    args.input_type = "Voxel"
    args.output_type = "Implicit"
    args.cond_emb_dim, args.device = (512, 'cuda:0')
    args.flow_type = "realnvp_half"

    net = autoencoder.EncoderWrapper(args).to(args.device)
    latent_flow_network = latent_flows.get_generator(
        args.emb_dims, args.cond_emb_dim, args.device,
        flow_type=args.flow_type, num_blocks=args.num_blocks,
        num_hidden=args.num_hidden)
    # The zero_conv decoder wrapping was dead code (`if False and ...`) and has been
    # removed; re-enable deliberately if needed.
    net = net.to(args.device)

    # os.path.join tolerates checkpoint_dir with or without a trailing slash,
    # unlike the original string concatenation.
    checkpoint_nf_path = os.path.join(checkpoint_dir, "flow_model_%s.pt" % iter)
    checkpoint = torch.load(checkpoint_nf_path, map_location=args.device)
    latent_flow_network.load_state_dict(checkpoint)

    checkpoint_path = os.path.join(checkpoint_dir, "aencoder_%s.pt" % iter)
    checkpoint = torch.load(checkpoint_path, map_location=args.device)
    net.load_state_dict(checkpoint)
    net.eval()

    args.clip_model_type = 'B-32'
    args, clip_model = get_clip_model(args)
    return net, latent_flow_network, clip_model
def gen_voxels(query_array, net, latent_flow_model, clip_model, output_dir):
    """Generate one thresholded 100^3 voxel grid per text prompt.

    Relies on the module-level ``args`` namespace for ``device``, ``emb_dims``
    and ``threshold`` (set in ``__main__``) — NOTE(review): consider passing
    these explicitly.

    Args:
        query_array: list of prompt strings.
        net: autoencoder wrapper; ``net.encoder.decoding`` maps (embedding,
            query points) -> occupancy values.
        latent_flow_model: normalizing flow sampled conditioned on CLIP features.
        clip_model: CLIP text encoder.
        output_dir: accepted for interface compatibility; unused in this function.

    Returns:
        (voxels, raw_voxels): boolean occupancy grids stacked into an
        (N, 100, 100, 100) array, and the list of raw pre-threshold grids.
    """
    net.eval()
    latent_flow_model.eval()
    clip_model.eval()
    num_figs = 1
    voxels = []
    raw_voxels = []
    with torch.no_grad():
        voxel_size = 100
        grid_shape = (voxel_size,) * 3
        p = visualization.make_3d_grid(
            [-0.5] * 3, [+0.5] * 3, grid_shape).type(torch.FloatTensor).to(args.device)
        query_points = p.expand(num_figs, *p.size())
        for text_in in query_array:
            # Encode and L2-normalize the prompt's CLIP text features.
            text = clip.tokenize([text_in]).to(args.device)
            text_features = clip_model.encode_text(text)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Fixed seed per prompt so samples are reproducible across prompts/runs.
            torch.manual_seed(5)
            mean_shape = torch.zeros(1, args.emb_dims).to(args.device)
            noise = torch.Tensor(num_figs - 1, args.emb_dims).normal_().to(args.device)
            noise = torch.clip(noise, min=-1, max=1)
            noise = torch.cat([mean_shape, noise], dim=0)
            # BUG FIX: sample on args.device instead of the hard-coded "cuda:0",
            # so the function works on any device args selects.
            decoder_embs = latent_flow_model.sample(
                args.device, num_samples=num_figs, noise=noise,
                cond_inputs=text_features.repeat(num_figs, 1))
            out = net.encoder.decoding(decoder_embs, query_points)
            grid = out.view(num_figs, voxel_size, voxel_size, voxel_size)
            raw_voxels.append(grid.detach().cpu().numpy().squeeze())
            voxels.append((grid > args.threshold).detach().cpu().numpy().squeeze())
    voxels = np.stack(voxels, axis=0)
    return voxels, raw_voxels
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_dir", type=str)
    parser.add_argument("--output_dir", type=str)
    # BUG FIX: --init was never registered, yet args.init is used below when
    # calling get_networks -> guaranteed AttributeError.
    parser.add_argument("--init", type=str,
                        help="key into make_init_dict() selecting the architecture")
    parser.add_argument("--threshold", type=float, default=0.05)
    args = parser.parse_args()

    # Derive the query-array name from the checkpoint directory name, e.g.
    # "q=object_5_with_original_lr=1e-05_beta=200" -> "object_5_with_original".
    query_array_name = args.checkpoint_dir.split("=")[1][:-3]
    query_array = query_arrays[query_array_name]
    zero_conv = "zero_conv" in args.checkpoint_dir

    iter = 5000
    net, flow, clip_model = get_networks(args.checkpoint_dir, args.init, zero_conv,
                                         args, iter=iter)
    for thresh in [0.05]:
        args.threshold = thresh  # gen_voxels reads the threshold from args
        voxels, raw_voxels = gen_voxels(query_array, net, flow, clip_model,
                                        args.checkpoint_dir)
        # Hoist the repeated output-root expression; exist_ok avoids the
        # check-then-create race of the original exists()/makedirs pair.
        out_root = "voxels_%s/" % iter + args.output_dir
        os.makedirs(out_root, exist_ok=True)
        for i in range(voxels.shape[0]):
            print('voxel out path:', args.output_dir)
            np.save(out_root + "/%s_%s.npy" % (query_array[i], thresh), voxels[i])
        # BUG FIX: the original passed 4 arguments to the 2-parameter
        # save_voxels -> TypeError; call with the arguments it accepts.
        save_voxels(voxels, args.output_dir)