forked from zouxiaochuan/code_ogblsc2022
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon_utils.py
64 lines (50 loc) · 1.6 KB
/
common_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pickle
import numpy as np
def save_obj(obj, name):
    """Pickle ``obj`` and write it to the file at ``name``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten. The highest pickle protocol supported by the
    running interpreter is used.

    Args:
        obj: The object to serialize; it must be picklable.
        name: Destination path, passed unchanged to ``open``.
    """
    with open(name, 'wb') as sink:
        pickle.dump(obj, sink, protocol=pickle.HIGHEST_PROTOCOL)
def load_obj(name):
    """Open the file at ``name`` and return the object unpickled from it.

    Args:
        name: Path of a file containing a pickle stream, passed unchanged
            to ``open``.

    Returns:
        The object that ``pickle.load`` reads from the file.

    Note:
        Unpickling can execute arbitrary code, so only load files that
        come from a trusted source.
    """
    with open(name, 'rb') as source:
        restored = pickle.load(source)
    return restored
def collate_seq(feat_list):
    """Zero-pad variable-length 2-D feature arrays into a single batch.

    Args:
        feat_list: Non-empty sequence of arrays shaped ``(length_i, dim)``.
            The feature width and the output dtype are taken from the
            first element.

    Returns:
        A tuple ``(feat, mask)`` where ``feat`` has shape
        ``(batch, max_length, dim)`` with zeros after each sample's own
        length, and ``mask`` is a float32 array of shape
        ``(batch, max_length)`` holding 1 at real positions and 0 at
        padded positions.
    """
    batch = len(feat_list)
    lengths = [item.shape[0] for item in feat_list]
    longest = np.max(lengths)
    first = feat_list[0]

    padded = np.zeros((batch, longest, first.shape[1]), dtype=first.dtype)
    for row, (item, length) in enumerate(zip(feat_list, lengths)):
        padded[row, :length, :] = item

    # A position is real when its index is below that sample's length.
    positions = np.arange(longest)[np.newaxis, :]
    valid = positions < np.asarray(lengths)[:, np.newaxis]
    return padded, valid.astype('float32')
def collate_map(feat_list):
    """Zero-pad pairwise feature arrays into a single batch.

    Each sample's first-axis size is used for both padded axes, so a
    sample is written into the top-left ``size x size`` corner of its
    slot in the output.

    Args:
        feat_list: Non-empty sequence of 3-D arrays; the size of the last
            axis and the output dtype are taken from the first element.
            Samples are presumably shaped ``(n_i, n_i, dim)`` - confirm
            against the caller.

    Returns:
        Array of shape ``(batch, max_n, max_n, dim)``, zero outside each
        sample's own corner.
    """
    batch = len(feat_list)
    side = np.max([item.shape[0] for item in feat_list])
    head = feat_list[0]

    out_shape = (batch,) + (side,) * 2 + (head.shape[2],)
    padded = np.zeros(out_shape, dtype=head.dtype)

    # Iterating the output yields per-sample views, so writes land in place.
    for slot, item in zip(padded, feat_list):
        extent = item.shape[0]
        slot[:extent, :extent, :] = item
    return padded
def collate_cube(feat_list):
    """Zero-pad triple-wise feature arrays into a single batch.

    Each sample's first-axis size is used for all three padded axes, so
    a sample is written into the leading ``size x size x size`` corner of
    its slot in the output.

    Args:
        feat_list: Non-empty sequence of 4-D arrays; the size of the last
            axis and the output dtype are taken from the first element.
            Samples are presumably shaped ``(n_i, n_i, n_i, dim)`` -
            confirm against the caller.

    Returns:
        Array of shape ``(batch, max_n, max_n, max_n, dim)``, zero
        outside each sample's own corner.
    """
    batch = len(feat_list)
    side = np.max([item.shape[0] for item in feat_list])
    head = feat_list[0]

    padded = np.zeros(
        (batch, side, side, side, head.shape[3]), dtype=head.dtype)

    for index, item in enumerate(feat_list):
        extent = item.shape[0]
        # Same leading slice on all three padded axes; the trailing
        # feature axis is taken whole.
        corner = (index,) + (slice(None, extent),) * 3
        padded[corner] = item
    return padded