-
Notifications
You must be signed in to change notification settings - Fork 3
/
train_gui_utils.py
executable file
·96 lines (75 loc) · 3.68 KB
/
train_gui_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
import torch
class DeformKeypoints:
    """Bookkeeping for groups of 3D keypoints and their per-keypoint offsets.

    Three parallel lists store, for every tracked keypoint, its external
    index, its original coordinate and an additive delta. A "group" is a list
    of positions into those parallel lists. ``idx2group`` maps an external
    keypoint index to the group list it was most recently added with, and
    ``selective_keypoints_idx_list`` is the group the delta setters act on.
    """

    def __init__(self) -> None:
        self.keypoints3d_list = []  # original coordinate of every keypoint
        self.keypoints_idx_list = []  # external keypoint index, parallel to the above
        self.keypoints3d_delta_list = []  # additive offset, parallel to the above
        self.selective_keypoints_idx_list = []  # positions of the selected group
        self.idx2group = {}  # external keypoint index -> group (list of positions)
        self.selective_rotation_keypoints_idx_list = []  # positions of the rotation group

    def get_kpt_idx(self):
        """Return the list of external keypoint indices (the live list, not a copy)."""
        return self.keypoints_idx_list

    def get_kpt(self):
        """Return the list of original keypoint coordinates (the live list, not a copy)."""
        return self.keypoints3d_list

    def get_kpt_delta(self):
        """Return the list of per-keypoint deltas (the live list, not a copy)."""
        return self.keypoints3d_delta_list

    def get_deformed_kpt_np(self, rate=1.):
        """Return ``coordinates + deltas * rate`` for all keypoints as one array."""
        return np.array(self.keypoints3d_list) + np.array(self.keypoints3d_delta_list) * rate

    def add_kpts(self, keypoints_coord, keypoints_idx, expand=False):
        """Track new keypoints and make them the selected group.

        Args:
            keypoints_coord: coordinates, documented by the original author
                as a ``[N, 3]`` torch tensor.
            keypoints_idx: external indices, documented by the original
                author as a ``[N,]`` torch tensor.
            expand: when true, newly tracked positions are appended to the
                currently selected group instead of starting a fresh group.
        """
        group = self.selective_keypoints_idx_list if expand else []
        for row, kpt_idx in enumerate(keypoints_idx):
            key = kpt_idx.item()
            if self.contain_kpt(key):
                continue  # already tracked: do not store a second copy
            group.append(len(self.keypoints_idx_list))
            self.keypoints_idx_list.append(key)
            self.keypoints3d_list.append(keypoints_coord[row].cpu().numpy())
            self.keypoints3d_delta_list.append(np.zeros_like(self.keypoints3d_list[-1]))
        # NOTE(review): every index in keypoints_idx, including ones that were
        # already tracked, is re-pointed at this group even though the
        # positions of already-tracked keypoints are not added to it.
        # Behaviour kept as-is; confirm it is intended.
        for kpt_idx in keypoints_idx:
            self.idx2group[kpt_idx.item()] = group
        self.selective_keypoints_idx_list = group

    def contain_kpt(self, idx) -> bool:
        """Return whether the external keypoint index ``idx`` (an int) is tracked."""
        return idx in self.keypoints_idx_list

    def select_kpt(self, idx):
        """Select the group registered for external index ``idx``; no-op if untracked."""
        if idx in self.keypoints_idx_list:
            self.selective_keypoints_idx_list = self.idx2group[idx]

    def select_rotation_kpt(self, idx):
        """Select the rotation group registered for ``idx``; no-op if untracked."""
        if idx in self.keypoints_idx_list:
            self.selective_rotation_keypoints_idx_list = self.idx2group[idx]

    def get_rotation_center(self):
        """Return the mean deformed position of the rotation group."""
        deformed = self.get_deformed_kpt_np()
        return deformed[self.selective_rotation_keypoints_idx_list].mean(axis=0)

    def get_selective_center(self):
        """Return the mean deformed position of the selected group."""
        deformed = self.get_deformed_kpt_np()
        return deformed[self.selective_keypoints_idx_list].mean(axis=0)

    def delete_kpt(self, idx):
        """Remove every keypoint of the currently selected group.

        Args:
            idx: unused; kept so the call signature stays unchanged.

        The keypoints are dropped from the three parallel lists and from
        ``idx2group``. All remaining group lists are re-numbered in place,
        and the selection is left empty afterwards.
        """
        total = len(self.keypoints_idx_list)
        doomed = {pos for pos in self.selective_keypoints_idx_list if 0 <= pos < total}
        if not doomed:
            return

        # idx2group is keyed by external keypoint index, not by list position.
        for pos in doomed:
            self.idx2group.pop(self.keypoints_idx_list[pos], None)

        # Removing entries shifts every later position; build old -> new.
        remap = {}
        for old_pos in range(total):
            if old_pos not in doomed:
                remap[old_pos] = len(remap)

        # Group lists are shared objects (several keys and the selections may
        # alias one list), so rewrite each distinct list once, in place.
        groups = {id(group): group for group in self.idx2group.values()}
        for group in (self.selective_keypoints_idx_list,
                      self.selective_rotation_keypoints_idx_list):
            groups[id(group)] = group
        for group in groups.values():
            group[:] = [remap[pos] for pos in group if pos in remap]

        # Pop from the back so earlier positions stay valid while deleting.
        for pos in sorted(doomed, reverse=True):
            self.keypoints3d_delta_list.pop(pos)
            self.keypoints3d_list.pop(pos)
            self.keypoints_idx_list.pop(pos)

        self.selective_keypoints_idx_list = []

    def delete_batch_ktps(self, batch_idx):
        """Placeholder: currently does nothing."""
        pass

    def update_delta(self, delta):
        """Add ``delta`` (documented as a ``[3,]`` array) to each selected keypoint's delta."""
        for pos in self.selective_keypoints_idx_list:
            self.keypoints3d_delta_list[pos] += delta

    def set_delta(self, delta):
        """Overwrite the selected keypoints' deltas with the rows of ``delta``.

        ``delta`` is documented as a ``[N, 3]`` array; row ``i`` is assigned
        to the ``i``-th position of the selected group.
        """
        for row, pos in enumerate(self.selective_keypoints_idx_list):
            self.keypoints3d_delta_list[pos] = delta[row]

    def set_rotation_delta(self, rot_mat):
        """Rotate the selected group about its deformed centroid by ``rot_mat``.

        The rotated positions are stored as deltas relative to the original
        (undeformed) coordinates of the selected keypoints.
        """
        selected = self.selective_keypoints_idx_list
        deformed = self.get_deformed_kpt_np()[selected]
        centroid = deformed.mean(axis=0)
        rotated = (deformed - centroid) @ rot_mat.T + centroid
        new_delta = rotated - np.array(self.keypoints3d_list)[selected]
        for row, pos in enumerate(selected):
            self.keypoints3d_delta_list[pos] = new_delta[row]