-
Notifications
You must be signed in to change notification settings - Fork 82
/
utils.py
98 lines (86 loc) · 4.01 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import networkx as nx
import numpy as np
class DBLPDataLoader:
    """Load a pickled NetworkX graph and draw edge batches with negative samples.

    Nodes are mapped to consecutive integer indices (``node_index`` and
    ``node_index_reversed``) in the order ``g.nodes`` yields them, and
    ``edges`` stores every edge as a pair of those indices.  Edges are sampled
    in proportion to their ``weight`` attribute; negative nodes are sampled in
    proportion to their weighted degree raised to the power 0.75.
    """

    # 'atlas' is the spelling used by fetch_batch's default arguments and is
    # kept for backward compatibility; 'alias' selects the same strategy.
    _ALIAS_METHODS = ('atlas', 'alias')
    _SAMPLING_METHODS = ('numpy', 'uniform') + _ALIAS_METHODS

    def __init__(self, graph_file):
        """Read the graph from ``graph_file`` and precompute sampling tables.

        Args:
            graph_file: Path (or open binary file object) of a pickled
                NetworkX graph.
        """
        self.g = self._load_graph(graph_file)
        self.num_of_nodes = self.g.number_of_nodes()
        self.num_of_edges = self.g.number_of_edges()
        self.edges_raw = self.g.edges(data=True)
        self.nodes_raw = self.g.nodes(data=True)

        # Edge distribution: normalised edge weights.  A missing 'weight'
        # counts as 1.0, which is also how the weighted degree below treats
        # edges without that attribute.
        self.edge_distribution = np.array(
            [attr.get('weight', 1.0) for _, _, attr in self.edges_raw],
            dtype=np.float32)
        self.edge_distribution /= np.sum(self.edge_distribution)
        self.edge_sampling = AliasSampling(prob=self.edge_distribution)

        # Negative-node distribution: weighted degree ** 0.75, normalised.
        weighted_degrees = np.array(
            [self.g.degree(node, weight='weight') for node, _ in self.nodes_raw],
            dtype=np.float32)
        self.node_negative_distribution = np.power(weighted_degrees, 0.75)
        self.node_negative_distribution /= np.sum(self.node_negative_distribution)
        self.node_sampling = AliasSampling(prob=self.node_negative_distribution)

        # Bidirectional mapping between graph nodes and integer indices.
        self.node_index = {}
        self.node_index_reversed = {}
        for index, (node, _) in enumerate(self.nodes_raw):
            self.node_index[node] = index
            self.node_index_reversed[index] = node
        self.edges = [(self.node_index[u], self.node_index[v]) for u, v, _ in self.edges_raw]

    @staticmethod
    def _load_graph(graph_file):
        """Unpickle the graph, with or without ``networkx.read_gpickle``."""
        reader = getattr(nx, 'read_gpickle', None)
        if reader is not None:
            return reader(graph_file)
        # read_gpickle was removed in NetworkX 3.0; plain pickle is the
        # documented replacement.  Unlike read_gpickle, this fallback does not
        # transparently decompress .gz/.bz2 files.
        # NOTE: unpickling can execute arbitrary code -- only load trusted files.
        import pickle
        if hasattr(graph_file, 'read'):
            return pickle.load(graph_file)
        with open(graph_file, 'rb') as handle:
            return pickle.load(handle)

    def fetch_batch(self, batch_size=16, K=10, edge_sampling='atlas', node_sampling='atlas'):
        """Sample ``batch_size`` positive edges, each followed by ``K`` negatives.

        Args:
            batch_size: Number of positive edges to draw.
            K: Number of negative nodes drawn per positive edge.
            edge_sampling: ``'numpy'`` (``np.random.choice`` with the edge
                distribution), ``'atlas'``/``'alias'`` (alias table) or
                ``'uniform'``.
            node_sampling: Same choices, applied to negative-node sampling.

        Returns:
            ``(u_i, u_j, label)``: three parallel lists of length
            ``batch_size * (K + 1)`` holding source index, target index and
            ``1`` for a positive pair or ``-1`` for a negative pair.

        Raises:
            ValueError: If either sampling strategy name is not recognised.
        """
        for arg_name, method in (('edge_sampling', edge_sampling),
                                 ('node_sampling', node_sampling)):
            if method not in self._SAMPLING_METHODS:
                raise ValueError('unknown {} method {!r}; expected one of {}'.format(
                    arg_name, method, ', '.join(self._SAMPLING_METHODS)))

        if edge_sampling == 'numpy':
            edge_batch_index = np.random.choice(
                self.num_of_edges, size=batch_size, p=self.edge_distribution)
        elif edge_sampling in self._ALIAS_METHODS:
            edge_batch_index = self.edge_sampling.sampling(batch_size)
        else:  # 'uniform'
            edge_batch_index = np.random.randint(0, self.num_of_edges, size=batch_size)
        # The alias sampler returns a bare scalar when asked for one draw;
        # normalise so that batch_size == 1 is still iterable.
        edge_batch_index = np.atleast_1d(edge_batch_index)

        undirected = not self.g.is_directed()
        u_i = []
        u_j = []
        label = []
        for edge_index in edge_batch_index:
            source, target = self.edges[edge_index]
            # important: second-order proximity is for directed edge, so an
            # undirected edge is used in a random direction.
            if undirected and np.random.rand() > 0.5:
                source, target = target, source
            u_i.append(source)
            u_j.append(target)
            label.append(1)
            for _ in range(K):
                u_i.append(source)
                u_j.append(self._sample_negative_node(source, node_sampling))
                label.append(-1)
        return u_i, u_j, label

    def _sample_negative_node(self, source, node_sampling):
        """Draw a node index that has no edge towards ``source`` in the graph."""
        source_node = self.node_index_reversed[source]
        # NOTE(review): rejection sampling -- this loops until a candidate
        # without an edge to `source` is drawn.
        while True:
            if node_sampling == 'numpy':
                candidate = np.random.choice(
                    self.num_of_nodes, p=self.node_negative_distribution)
            elif node_sampling in self._ALIAS_METHODS:
                candidate = self.node_sampling.sampling()
            else:  # 'uniform'
                candidate = np.random.randint(0, self.num_of_nodes)
            candidate = int(candidate)
            if not self.g.has_edge(self.node_index_reversed[candidate], source_node):
                return candidate

    def embedding_mapping(self, embedding):
        """Return ``{node: embedding[index_of_node]}`` for every graph node."""
        return {node: embedding[self.node_index[node]] for node, _ in self.nodes_raw}
class AliasSampling:
    """Draw indices from a discrete distribution using the alias method.

    ``U`` holds, per column, the probability of keeping that column's own
    index; ``K`` holds the alias index returned otherwise.
    """

    # Reference: https://en.wikipedia.org/wiki/Alias_method

    def __init__(self, prob):
        """Build the probability (``U``) and alias (``K``) tables.

        Args:
            prob: Sequence of probabilities, one per index.  The construction
                scales them by ``len(prob)`` and compares against 1, so they
                are expected to sum to 1.
        """
        self.n = len(prob)
        self.U = np.array(prob) * self.n
        # Every column starts out aliased to itself, so columns that are never
        # paired below still return their own index.
        self.K = np.arange(self.n)
        overfull = np.flatnonzero(self.U > 1).tolist()
        underfull = np.flatnonzero(self.U < 1).tolist()
        while overfull and underfull:
            i, j = overfull.pop(), underfull.pop()
            # Column j keeps U[j] of its own mass; the rest is donated by i.
            self.K[j] = i
            self.U[i] = self.U[i] - (1 - self.U[j])
            if self.U[i] > 1:
                overfull.append(i)
            elif self.U[i] < 1:
                underfull.append(i)

    def sampling(self, n=1):
        """Draw ``n`` indices according to the distribution.

        Args:
            n: Number of draws.

        Returns:
            A single ``int`` when ``n == 1``; otherwise a list of ``n`` ints.
        """
        x = np.random.rand(n)
        scaled = self.n * x
        column_float = np.floor(scaled)
        # Fractional part decides between the column itself and its alias.
        y = scaled - column_float
        column = column_float.astype(np.intp)
        picks = np.where(y < self.U[column], column, self.K[column])
        if n == 1:
            return int(picks[0])
        return picks.tolist()