import dgl
import dgl.function as fn
import scipy.sparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.nn import AvgPooling, GraphConv, MaxPooling
from dgl.ops import edge_softmax
from torch import Tensor
from torch.nn import Parameter
from functions import edge_sparsemax
from utils import get_batch_id, topk


class WeightedGraphConv(GraphConv):
r"""
Description
-----------
GraphConv with edge weights on homogeneous graphs.
    If edge weights are not given, it falls back to the standard GraphConv.
Parameters
----------
    graph : DGLGraph
        The graph on which to perform the operation.
    n_feat : torch.Tensor
        The node features.
e_feat : torch.Tensor, optional
The edge features. Default: :obj:`None`
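    Example
    -------
    A minimal usage sketch (illustrative shapes and values, not part of the
    original docstring):

    >>> g = dgl.rand_graph(4, 8)           # 4 nodes, 8 random edges
    >>> conv = WeightedGraphConv(16, 32)
    >>> h = torch.randn(4, 16)             # node features
    >>> w = torch.rand(g.num_edges())      # one scalar weight per edge
    >>> out = conv(g, h, w)                # -> shape (4, 32)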
"""
def forward(self, graph:DGLGraph, n_feat, e_feat=None):
if e_feat is None:
return super(WeightedGraphConv, self).forward(graph, n_feat)
with graph.local_scope():
if self.weight is not None:
n_feat = torch.matmul(n_feat, self.weight)
            # symmetric normalization: D_out^{-1/2} * A * D_in^{-1/2}
            src_norm = torch.pow(graph.out_degrees().float().clamp(min=1), -0.5)
            src_norm = src_norm.view(-1, 1)
            dst_norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)
            dst_norm = dst_norm.view(-1, 1)
            n_feat = n_feat * src_norm
graph.ndata["h"] = n_feat
graph.edata["e"] = e_feat
            # weighted message passing: m = h_src * e, then sum per node
            graph.update_all(fn.u_mul_e("h", "e", "m"),
                             fn.sum("m", "h"))
n_feat = graph.ndata.pop("h")
n_feat = n_feat * dst_norm
if self.bias is not None:
n_feat = n_feat + self.bias
if self._activation is not None:
n_feat = self._activation(n_feat)
return n_feat


class NodeInfoScoreLayer(nn.Module):
r"""
Description
-----------
    Compute a score for each node for sort pooling. The score of a node is
    the Manhattan distance between its original features and its one-hop
    propagated features (a first-order random walk over the graph).
Arguments
---------
sym_norm : bool, optional
If true, use symmetric norm for adjacency.
Default: :obj:`True`
Parameters
----------
graph : DGLGraph
        The graph on which to perform the operation.
    feat : torch.Tensor
        The node features.
    e_feat : torch.Tensor
        The edge features.
Returns
-------
Tensor
Score for each node.
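    Example
    -------
    A minimal usage sketch (illustrative shapes and values, not part of the
    original docstring):

    >>> layer = NodeInfoScoreLayer()
    >>> g = dgl.rand_graph(4, 8)
    >>> h = torch.randn(4, 16)
    >>> w = torch.ones(g.num_edges())      # uniform edge weights
    >>> score = layer(g, h, w)             # -> shape (4,)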
"""
def __init__(self, sym_norm:bool=True):
super(NodeInfoScoreLayer, self).__init__()
self.sym_norm = sym_norm
def forward(self, graph:dgl.DGLGraph, feat:Tensor, e_feat:Tensor):
with graph.local_scope():
            if self.sym_norm:
                # one-step symmetric propagation: D^{-1/2} A D^{-1/2} h
src_norm = torch.pow(graph.out_degrees().float().clamp(min=1), -0.5)
src_norm = src_norm.view(-1, 1).to(feat.device)
dst_norm = torch.pow(graph.in_degrees().float().clamp(min=1), -0.5)
dst_norm = dst_norm.view(-1, 1).to(feat.device)
src_feat = feat * src_norm
graph.ndata["h"] = src_feat
graph.edata["e"] = e_feat
graph = dgl.remove_self_loop(graph)
                graph.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))
dst_feat = graph.ndata.pop("h") * dst_norm
feat = feat - dst_feat
else:
                dst_norm = 1. / graph.in_degrees().float().clamp(min=1)
                dst_norm = dst_norm.view(-1, 1).to(feat.device)
graph.ndata["h"] = feat
graph.edata["e"] = e_feat
graph = dgl.remove_self_loop(graph)
                graph.update_all(fn.u_mul_e("h", "e", "m"), fn.sum("m", "h"))
feat = feat - dst_norm * graph.ndata.pop("h")
score = torch.sum(torch.abs(feat), dim=1)
return score


class HGPSLPool(nn.Module):
r"""
Description
-----------
    The HGP-SL pooling layer from
    `Hierarchical Graph Pooling with Structure Learning
    <https://arxiv.org/pdf/1911.05954.pdf>`__
    Parameters
    ----------
    in_feat : int
        The number of input node feature channels.
    ratio : float, optional
        Pooling ratio. Default: 0.8
    sample : bool, optional
        Whether to use the k-hop union graph to speed up edge-weight
        learning on large graphs. Default: :obj:`True`
    sym_score_norm : bool, optional
        Whether to use symmetric normalization for the adjacency.
        Default: :obj:`True`
    sparse : bool, optional
        Use edge sparsemax instead of edge softmax. Default: :obj:`True`
    sl : bool, optional
        Whether to use the structure learning module. Default: :obj:`True`
    lamb : float, optional
        The lambda parameter weighting the raw adjacency, as described
        in the HGP-SL paper. Default: 1.0
    negative_slop : float, optional
        Negative slope for leaky_relu. Default: 0.2
    k_hop : int, optional
        Number of hops for the union graph when :attr:`sample` is True.
        Default: 3
Returns
-------
DGLGraph
The pooled graph.
torch.Tensor
Node features
torch.Tensor
Edge features
torch.Tensor
Permutation index
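    Example
    -------
    A minimal usage sketch (illustrative shapes and values, not part of the
    original docstring):

    >>> pool = HGPSLPool(16, ratio=0.5)
    >>> g = dgl.rand_graph(10, 30)
    >>> h = torch.randn(10, 16)
    >>> g_p, h_p, e_p, perm = pool(g, h)   # g_p keeps ~50% of the nodes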
"""
def __init__(self, in_feat:int, ratio=0.8, sample=True,
sym_score_norm=True, sparse=True, sl=True,
lamb=1.0, negative_slop=0.2, k_hop=3):
super(HGPSLPool, self).__init__()
self.in_feat = in_feat
self.ratio = ratio
self.sample = sample
self.sparse = sparse
self.sl = sl
self.lamb = lamb
self.negative_slop = negative_slop
self.k_hop = k_hop
self.att = Parameter(torch.Tensor(1, self.in_feat * 2))
self.calc_info_score = NodeInfoScoreLayer(sym_norm=sym_score_norm)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_normal_(self.att.data)
def forward(self, graph:DGLGraph, feat:Tensor, e_feat=None):
# top-k pool first
if e_feat is None:
e_feat = torch.ones((graph.number_of_edges(),),
dtype=feat.dtype, device=feat.device)
batch_num_nodes = graph.batch_num_nodes()
x_score = self.calc_info_score(graph, feat, e_feat)
perm, next_batch_num_nodes = topk(x_score, self.ratio,
get_batch_id(batch_num_nodes),
batch_num_nodes)
feat = feat[perm]
pool_graph = None
if not self.sample or not self.sl:
# pool graph
graph.edata["e"] = e_feat
pool_graph = dgl.node_subgraph(graph, perm)
e_feat = pool_graph.edata.pop("e")
pool_graph.set_batch_num_nodes(next_batch_num_nodes)
# no structure learning layer, directly return.
if not self.sl:
return pool_graph, feat, e_feat, perm
# Structure Learning
if self.sample:
            # A fast mode for large graphs. Learning a possible edge weight
            # between every pair of nodes is time consuming on large graphs,
            # so for each node we sample its k-hop neighbors and only learn
            # edge weights within that union graph.
            # First, build the multi-hop graph:
row, col = graph.all_edges()
num_nodes = graph.num_nodes()
            scipy_adj = scipy.sparse.coo_matrix(
                (e_feat.detach().cpu(), (row.detach().cpu(), col.detach().cpu())),
                shape=(num_nodes, num_nodes))
            for _ in range(self.k_hop):
                two_hop = scipy_adj ** 2
                # scale higher-hop weights to at most 1e-5 so they add
                # connectivity without dominating the raw edge weights
                two_hop = two_hop * (1e-5 / two_hop.max())
                scipy_adj = two_hop + scipy_adj
row, col = scipy_adj.nonzero()
row = torch.tensor(row, dtype=torch.long, device=graph.device)
col = torch.tensor(col, dtype=torch.long, device=graph.device)
e_feat = torch.tensor(scipy_adj.data, dtype=torch.float, device=feat.device)
# perform pooling on multi-hop graph
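            # relabel: mask maps original node ids to pooled ids (-1 = dropped)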
mask = perm.new_full((num_nodes, ), -1)
i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
mask[perm] = i
row, col = mask[row], mask[col]
            mask = (row >= 0) & (col >= 0)
row, col = row[mask], col[mask]
e_feat = e_feat[mask]
            # re-add self loops, restoring each kept node's original
            # self-loop weight (zero if it had none)
mask = row != col
num_nodes = perm.size(0) # num nodes after pool
loop_index = torch.arange(0, num_nodes, dtype=row.dtype, device=row.device)
inv_mask = ~mask
loop_weight = torch.full((num_nodes, ), 0, dtype=e_feat.dtype, device=e_feat.device)
remaining_e_feat = e_feat[inv_mask]
if remaining_e_feat.numel() > 0:
loop_weight[row[inv_mask]] = remaining_e_feat
e_feat = torch.cat([e_feat[mask], loop_weight], dim=0)
row, col = row[mask], col[mask]
row = torch.cat([row, loop_index], dim=0)
col = torch.cat([col, loop_index], dim=0)
# attention scores
weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
weights = F.leaky_relu(weights, self.negative_slop) + e_feat * self.lamb
# sl and normalization
sl_graph = dgl.graph((row, col))
if self.sparse:
weights = edge_sparsemax(sl_graph, weights)
else:
weights = edge_softmax(sl_graph, weights)
# get final graph
mask = torch.abs(weights) > 0
row, col, weights = row[mask], col[mask], weights[mask]
pool_graph = dgl.graph((row, col))
pool_graph.set_batch_num_nodes(next_batch_num_nodes)
e_feat = weights
        else:
            # Learn a possible edge weight between every pair of nodes in
            # the pooled subgraph; slower than the sampling mode.
            # Construct a complete graph for each graph in the batch: build
            # a dense block-diagonal mask first, then convert it to sparse
            # form. There may be a more efficient way to do this.
batch_num_nodes = next_batch_num_nodes
block_begin_idx = torch.cat([batch_num_nodes.new_zeros(1),
batch_num_nodes.cumsum(dim=0)[:-1]], dim=0)
block_end_idx = batch_num_nodes.cumsum(dim=0)
dense_adj = torch.zeros((pool_graph.num_nodes(),
pool_graph.num_nodes()),
dtype=torch.float,
device=feat.device)
for idx_b, idx_e in zip(block_begin_idx, block_end_idx):
dense_adj[idx_b:idx_e, idx_b:idx_e] = 1.
row, col = torch.nonzero(dense_adj).t().contiguous()
# compute weights for node-pairs
weights = (torch.cat([feat[row], feat[col]], dim=1) * self.att).sum(dim=-1)
weights = F.leaky_relu(weights, self.negative_slop)
dense_adj[row, col] = weights
# add pooled graph structure to weight matrix
pool_row, pool_col = pool_graph.all_edges()
dense_adj[pool_row, pool_col] += self.lamb * e_feat
weights = dense_adj[row, col]
del dense_adj
torch.cuda.empty_cache()
# edge softmax/sparsemax
complete_graph = dgl.graph((row, col))
if self.sparse:
weights = edge_sparsemax(complete_graph, weights)
else:
weights = edge_softmax(complete_graph, weights)
# get new e_feat and graph structure, clean up.
mask = torch.abs(weights) > 1e-9
row, col, weights = row[mask], col[mask], weights[mask]
e_feat = weights
pool_graph = dgl.graph((row, col))
pool_graph.set_batch_num_nodes(next_batch_num_nodes)
return pool_graph, feat, e_feat, perm


class ConvPoolReadout(torch.nn.Module):
"""A helper class. (GraphConv -> Pooling -> Readout)"""
def __init__(self, in_feat:int, out_feat:int, pool_ratio=0.8,
sample:bool=False, sparse:bool=True, sl:bool=True,
lamb:float=1., pool:bool=True):
super(ConvPoolReadout, self).__init__()
self.use_pool = pool
self.conv = WeightedGraphConv(in_feat, out_feat)
if pool:
self.pool = HGPSLPool(out_feat, ratio=pool_ratio, sparse=sparse,
sample=sample, sl=sl, lamb=lamb)
else:
self.pool = None
self.avgpool = AvgPooling()
self.maxpool = MaxPooling()
def forward(self, graph, feature, e_feat=None):
out = F.relu(self.conv(graph, feature, e_feat))
if self.use_pool:
graph, out, e_feat, _ = self.pool(graph, out, e_feat)
readout = torch.cat([self.avgpool(graph, out), self.maxpool(graph, out)], dim=-1)
return graph, out, e_feat, readout
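

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; random graphs and features).
    # Uses the dense structure-learning path (ConvPoolReadout defaults to
    # sample=False) and assumes the companion `functions` and `utils`
    # modules imported above are available.
    g1 = dgl.rand_graph(10, 30)
    g2 = dgl.rand_graph(8, 20)
    batch = dgl.batch([g1, g2])
    feat = torch.randn(batch.num_nodes(), 16)
    layer = ConvPoolReadout(16, 32, pool_ratio=0.5)
    graph, out, e_feat, readout = layer(batch, feat)
    # out: pooled node features; readout: (2, 64), avg and max concatenated
    print(out.shape, readout.shape)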