-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_funcs.py
110 lines (97 loc) · 4 KB
/
test_funcs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import copy
from tqdm import tqdm
import numpy as np
import os
from configure.cfgs import cfg
from utils_SH import *
def unnormal(input, mean, std):
    """Undo normalization on all dim-1 entries except the last one.

    Every entry along dim 1 except the final one is mapped through
    ``x * std + mean``. The final entry is passed through unchanged and
    re-attached at the end (presumably a dummy/padding node -- confirm
    against the dataset).

    Args:
        input: Tensor indexed with three axes; the split happens on dim 1.
        mean: Offset added after scaling; must broadcast against the slice.
        std: Scale applied before the offset; must broadcast against the slice.

    Returns:
        Tensor with the same dim-1 length as ``input``.
    """
    head = input[:, :-1, :]
    tail = input[:, -1:, :]
    restored = head * std + mean
    return torch.cat((restored, tail), 1)
def normal(input, mean, std):
    """Normalize all dim-1 entries except the last one.

    Every entry along dim 1 except the final one is mapped through
    ``(x - mean) / std``. The final entry is kept as-is and concatenated
    back on (presumably a dummy/padding node -- confirm against the
    dataset).

    Args:
        input: Tensor indexed with three axes; the split happens on dim 1.
        mean: Value subtracted from the slice; must broadcast against it.
        std: Divisor applied after centering; must broadcast against the slice.

    Returns:
        Tensor with the same dim-1 length as ``input``.
    """
    head = input[:, :-1, :]
    tail = input[:, -1:, :]
    scaled = (head - mean) / std
    return torch.cat((scaled, tail), 1)
def test_autoencoder_dataloader(device, model, dataloader_test, shapedata, J_regressor, mm_constant = 1000, unnormal_flag = False):
    """Run ``model`` over ``dataloader_test`` and gather outputs and errors.

    Each batch's ``'verts'`` tensor is moved to ``device`` and passed through
    the model, which must return ``(prediction, z)``. Predictions, latent
    codes and inputs from all batches are concatenated along dim 0.

    Args:
        device: Device the input batches are moved to.
        model: Module called as ``model(tx)``; put into eval mode here.
        dataloader_test: Iterable of dicts with a ``'verts'`` tensor. Its
            ``dataset`` must support ``len()`` and expose ``dummy_node``.
        shapedata: Unused; kept for interface compatibility.
        J_regressor: Unused; kept for interface compatibility.
        mm_constant: Factor applied to reconstruction and input before the
            Euclidean error is computed (the name suggests a metres to
            millimetres conversion -- confirm against the data units).
        unnormal_flag: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(predictions, z_s, tx_s, l1_loss, l2_loss)``. The first three
        are NumPy arrays; the losses are Python floats. ``l1_loss`` is the
        mean absolute error and ``l2_loss`` the mean Euclidean distance over
        dim 2 after scaling, each weighted by batch size over dataset length.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    model.eval()

    # Loop-invariant dataset properties, looked up once.
    dataset = dataloader_test.dataset
    num_samples = float(len(dataset))
    has_dummy_node = dataset.dummy_node

    # Collect per-batch tensors and concatenate once at the end; growing a
    # tensor with torch.cat every iteration recopies all earlier batches.
    prediction_batches = []
    latent_batches = []
    input_batches = []
    l1_loss = 0.0
    l2_loss = 0.0

    with torch.no_grad():
        for sample_dict in tqdm(dataloader_test):
            tx = sample_dict['verts'].to(device)
            prediction, z = model(tx)

            prediction_batches.append(prediction)
            latent_batches.append(z)
            input_batches.append(tx)

            # When the dataset appends a dummy node, drop the last dim-1
            # entry so it does not contribute to the error metrics.
            if has_dummy_node:
                x_recon = prediction[:, :-1]
                x = tx[:, :-1]
            else:
                x_recon = prediction
                x = tx

            # NOTE: batches are weighted by batch_size / len(dataset); if the
            # loader drops samples the weights will not sum to one.
            batch_size = x.shape[0]
            l1_loss += torch.mean(torch.abs(x_recon - x)) * batch_size / num_samples

            x_recon = x_recon * mm_constant
            x = x * mm_constant
            distances = torch.sqrt(torch.sum((x_recon - x) ** 2, dim=2))
            l2_loss += torch.mean(distances) * batch_size / num_samples

    if not prediction_batches:
        raise ValueError("dataloader_test yielded no batches; nothing to evaluate")

    predictions = torch.cat(prediction_batches, 0).cpu().numpy()
    z_s = torch.cat(latent_batches, 0).cpu().numpy()
    tx_s = torch.cat(input_batches, 0).cpu().numpy()
    return predictions, z_s, tx_s, l1_loss.item(), l2_loss.item()
def test_autoencoder_dataloader_nonormal(device, model, dataloader_test, shapedata, J_regressor, mm_constant = 1000, unnormal_flag = False):
    """Run a keypoint-conditioned ``model`` over ``dataloader_test``.

    For each batch, keypoints are regressed from the vertices (all but the
    last dim-1 entry) with ``J_regressor``. A subset of them is passed to the
    model together with the vertices. The model must return
    ``(prediction, z, z_kps)``. Outputs from all batches are concatenated
    along dim 0.

    Args:
        device: Device the inputs and the regressor are moved to.
        model: Module called as ``model(tx, kps)``; put into eval mode here.
        dataloader_test: Iterable of dicts with a ``'verts'`` tensor. Its
            ``dataset`` must support ``len()`` and expose ``dummy_node``.
        shapedata: Unused; kept for interface compatibility.
        J_regressor: NumPy array, cast to float32 and left-multiplied onto
            the vertices to obtain keypoints.
        mm_constant: Factor applied to reconstruction and input before the
            Euclidean error is computed (the name suggests a metres to
            millimetres conversion -- confirm against the data units).
        unnormal_flag: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(predictions, z_s, z_kps_s, tx_s, l1_loss, l2_loss)``. The
        first four are NumPy arrays; the losses are Python floats weighted by
        batch size over dataset length.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    # Keypoint indices fed to the model. With kpskeep_flag set, indices 3, 13
    # and 14 are excluded (their meaning is defined by the project config).
    num_kps = len(cfg.CONSTANTS.newskl_list) + 4
    if cfg.TRAIN.kpskeep_flag:
        dropped = {3, 13, 14}
        kps_keep = [k for k in range(num_kps) if k not in dropped]
    else:
        kps_keep = list(range(num_kps))

    model.eval()
    J_regressor = torch.from_numpy(J_regressor.astype(np.float32)).to(device)

    # Loop-invariant dataset properties, looked up once.
    dataset = dataloader_test.dataset
    num_samples = float(len(dataset))
    has_dummy_node = dataset.dummy_node

    # Collect per-batch tensors and concatenate once at the end; growing a
    # tensor with torch.cat every iteration recopies all earlier batches.
    prediction_batches = []
    latent_batches = []
    kps_latent_batches = []
    input_batches = []
    l1_loss = 0.0
    l2_loss = 0.0

    with torch.no_grad():
        for sample_dict in tqdm(dataloader_test):
            tx = sample_dict['verts'].to(device)
            # The last dim-1 entry is always excluded before regressing
            # keypoints, regardless of dataset.dummy_node.
            kps_GT = torch.matmul(J_regressor, tx[:, :-1, :]).float()
            prediction, z, z_kps = model(tx, kps_GT[:, kps_keep])

            prediction_batches.append(prediction)
            latent_batches.append(z)
            kps_latent_batches.append(z_kps)
            input_batches.append(tx)

            # When the dataset appends a dummy node, drop the last dim-1
            # entry so it does not contribute to the error metrics.
            if has_dummy_node:
                x_recon = prediction[:, :-1]
                x = tx[:, :-1]
            else:
                x_recon = prediction
                x = tx

            # NOTE: batches are weighted by batch_size / len(dataset); if the
            # loader drops samples the weights will not sum to one.
            batch_size = x.shape[0]
            l1_loss += torch.mean(torch.abs(x_recon - x)) * batch_size / num_samples

            x_recon = x_recon * mm_constant
            x = x * mm_constant
            distances = torch.sqrt(torch.sum((x_recon - x) ** 2, dim=2))
            l2_loss += torch.mean(distances) * batch_size / num_samples

    if not prediction_batches:
        raise ValueError("dataloader_test yielded no batches; nothing to evaluate")

    predictions = torch.cat(prediction_batches, 0).cpu().numpy()
    z_s = torch.cat(latent_batches, 0).cpu().numpy()
    z_kps_s = torch.cat(kps_latent_batches, 0).cpu().numpy()
    tx_s = torch.cat(input_batches, 0).cpu().numpy()
    return predictions, z_s, z_kps_s, tx_s, l1_loss.item(), l2_loss.item()