-
Notifications
You must be signed in to change notification settings - Fork 4
/
data_solve.py
96 lines (82 loc) · 2.99 KB
/
data_solve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pickle
import torch
import numpy as np
from adjacent_matrix_norm import calculate_scaled_laplacian, calculate_symmetric_normalized_laplacian, calculate_symmetric_message_passing_adj, calculate_transition_matrix
def load_pkl(pickle_file: str) -> object:
"""Load pickle data.
Args:
pickle_file (str): file path
Returns:
object: loaded objected
"""
try:
with open(pickle_file, "rb") as f:
pickle_data = pickle.load(f)
except UnicodeDecodeError:
with open(pickle_file, "rb") as f:
pickle_data = pickle.load(f, encoding="latin1")
except Exception as e:
print("Unable to load data ", pickle_file, ":", e)
raise
return pickle_data
def dump_pkl(obj: object, file_path: str):
"""Dumplicate pickle data.
Args:
obj (object): object
file_path (str): file path
"""
with open(file_path, "wb") as f:
pickle.dump(obj, f)
def load_adj(file_path: str, adj_type: str):
"""load adjacency matrix.
Args:
file_path (str): file path
adj_type (str): adjacency matrix type
Returns:
list of numpy.matrix: list of preproceesed adjacency matrices
np.ndarray: raw adjacency matrix
"""
try:
# METR and PEMS_BAY
_, _, adj_mx = load_pkl(file_path)
except ValueError:
# PeMS04
adj_mx = load_pkl(file_path)
if adj_type == "scalap":
adj = [calculate_scaled_laplacian(adj_mx).astype(np.float32).todense()]
elif adj_type == "normlap":
adj = [calculate_symmetric_normalized_laplacian(
adj_mx).astype(np.float32).todense()]
elif adj_type == "symnadj":
adj = [calculate_symmetric_message_passing_adj(
adj_mx).astype(np.float32).todense()]
elif adj_type == "transition":
adj = [calculate_transition_matrix(adj_mx).T]
elif adj_type == "doubletransition":
adj = [calculate_transition_matrix(adj_mx).T, calculate_transition_matrix(adj_mx.T).T]
elif adj_type == "identity":
adj = [np.diag(np.ones(adj_mx.shape[0])).astype(np.float32)]
elif adj_type == "original":
adj = [adj_mx]
else:
error = 0
assert error, "adj type not defined"
return adj, adj_mx
def load_node2vec_emb(file_path: str) -> torch.Tensor:
"""load node2vec embedding
Args:
file_path (str): file path
Returns:
torch.Tensor: node2vec embedding
"""
# spatial embedding
with open(file_path, mode="r") as f:
lines = f.readlines()
temp = lines[0].split(" ")
num_vertex, dims = int(temp[0]), int(temp[1])
spatial_embeddings = torch.zeros((num_vertex, dims), dtype=torch.float32)
for line in lines[1:]:
temp = line.split(" ")
index = int(temp[0])
spatial_embeddings[index] = torch.Tensor([float(ch) for ch in temp[1:]])
return spatial_embeddings