# File: hetero_rgcn.py (464 lines, 379 loc, 13 KB), from a fork of dmlc/dgl.
# NOTE: the GitHub page chrome (notification / sign-in / fork widgets, "Copy path")
# and the line-number gutter that follows are web-scrape residue, not program text.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
import argparse
import itertools
import sys
import dgl
import dgl.nn as dglnn
import psutil
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddReverse, Compose, ToSimple
from dgl.nn import HeteroEmbedding
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from tqdm import tqdm
# Version string of the installed DGL package; main() passes it to
# is_support_affinity() to decide whether the CPU-affinity code path is usable.
v_t = dgl.__version__
def prepare_data(args, device):
    """Load ogbn-mag and build everything the training loop needs.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``runs`` (number of logged runs) and ``num_workers``
        (worker count handed to the training DataLoader).
    device : str
        Device passed to the training DataLoader.

    Returns
    -------
    tuple
        ``(graph, paper_labels, num_classes, split_idx, logger, train_loader)``
        where ``paper_labels`` is the flattened label tensor of the "paper"
        nodes and ``graph`` has gone through ``ToSimple`` then ``AddReverse``.
    """
    mag = DglNodePropPredDataset(name="ogbn-mag")
    splits = mag.get_idx_split()

    # dataset[0] yields (dgl graph, per-node-type label dict); only the
    # "paper" labels are used, flattened to one dimension.
    raw_graph, label_dict = mag[0]
    paper_labels = label_dict["paper"].flatten()

    preprocess = Compose([ToSimple(), AddReverse()])
    graph = preprocess(raw_graph)
    print(f"Loaded graph: {graph}")

    run_logger = Logger(args.runs)

    # Two-layer neighbor sampling with fanouts 25 and 20 for training batches.
    neighbor_sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])
    loader = dgl.dataloading.DataLoader(
        graph,
        splits["train"],
        neighbor_sampler,
        batch_size=1024,
        shuffle=True,
        num_workers=args.num_workers,
        device=device,
    )

    return graph, paper_labels, mag.num_classes, splits, run_logger, loader
def extract_embed(node_embed, input_nodes):
    """Look up embeddings for every node type in *input_nodes* except "paper".

    Parameters
    ----------
    node_embed : callable
        Called once with a ``{ntype: ids}`` dict; its result is returned as is.
    input_nodes : dict
        Mapping from node type to the node ids of the current batch.

    Returns
    -------
    Whatever ``node_embed`` returns for the non-"paper" entries.
    """
    featureless_ids = {}
    for ntype in input_nodes:
        # "paper" nodes are skipped here; callers add their raw features.
        if ntype == "paper":
            continue
        featureless_ids[ntype] = input_nodes[ntype]
    return node_embed(featureless_ids)
def rel_graph_embed(graph, embed_size):
    """Create a ``HeteroEmbedding`` for every node type except "paper".

    Parameters
    ----------
    graph
        Object exposing ``ntypes`` and ``num_nodes(ntype)``.
    embed_size : int
        Embedding width handed to ``HeteroEmbedding``.

    Returns
    -------
    HeteroEmbedding
        One embedding table per non-"paper" node type, sized by that type's
        node count.
    """
    table_sizes = {
        ntype: graph.num_nodes(ntype)
        for ntype in graph.ntypes
        if ntype != "paper"
    }
    return HeteroEmbedding(table_sizes, embed_size)
class RelGraphConvLayer(nn.Module):
    """One relational graph-convolution layer over a heterogeneous graph.

    Each relation gets a weight-free ``GraphConv`` (``norm="right"``) whose
    weight matrix is supplied at call time from a per-relation ``nn.Linear``;
    a per-node-type ``nn.Linear`` (with bias) adds a self-loop term. The
    result goes through the optional activation and then dropout.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    ntypes : iterable of str
        Node types that receive a self-loop transform.
    rel_names : iterable of str
        Relation names that receive a convolution and a weight.
    activation : callable, optional
        Applied to the summed features when truthy.
    dropout : float
        Dropout probability applied to the output.
    """

    def __init__(
        self, in_feat, out_feat, ntypes, rel_names, activation=None, dropout=0.0
    ):
        super().__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.ntypes = ntypes
        self.rel_names = rel_names
        self.activation = activation

        # Parameter-free convolutions; the relation weights live in
        # ``self.weight`` and are passed in through ``mod_kwargs``.
        rel_convs = {}
        for rel in rel_names:
            rel_convs[rel] = dglnn.GraphConv(
                in_feat, out_feat, norm="right", weight=False, bias=False
            )
        self.conv = dglnn.HeteroGraphConv(rel_convs)

        rel_linears = {}
        for rel in self.rel_names:
            rel_linears[rel] = nn.Linear(in_feat, out_feat, bias=False)
        self.weight = nn.ModuleDict(rel_linears)

        # Self-loop transform, one per node type.
        loop_linears = {}
        for ntype in self.ntypes:
            loop_linears[ntype] = nn.Linear(in_feat, out_feat, bias=True)
        self.loop_weights = nn.ModuleDict(loop_linears)

        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the relation weights, then the self-loop weights."""
        for rel_linear in self.weight.values():
            rel_linear.reset_parameters()
        for loop_linear in self.loop_weights.values():
            loop_linear.reset_parameters()

    def forward(self, g, inputs):
        """Apply the layer.

        Parameters
        ----------
        g : DGLGraph
            Input graph.
        inputs : dict[str, torch.Tensor]
            Node features keyed by node type.

        Returns
        -------
        dict[str, torch.Tensor]
            Updated node features keyed by node type.
        """
        g = g.local_var()

        # GraphConv expects (in, out) weights, hence the transpose of the
        # nn.Linear weight.
        rel_kwargs = {}
        for rel in self.rel_names:
            rel_kwargs[rel] = {"weight": self.weight[rel].weight.T}

        # Self-loop input: the leading rows of each feature tensor, one per
        # destination node (presumably dst nodes come first, as in DGL blocks).
        dst_feats = {}
        for ntype, feat in inputs.items():
            dst_feats[ntype] = feat[: g.number_of_dst_nodes(ntype)]

        aggregated = self.conv(g, inputs, mod_kwargs=rel_kwargs)

        outputs = {}
        for ntype, h in aggregated.items():
            h = h + self.loop_weights[ntype](dst_feats[ntype])
            if self.activation:
                h = self.activation(h)
            outputs[ntype] = self.dropout(h)
        return outputs
class EntityClassify(nn.Module):
    """Two-layer RGCN classifier built from ``RelGraphConvLayer``.

    The first layer maps ``in_dim`` to a fixed hidden size of 64 with ReLU
    and dropout 0.5; the second maps the hidden size to ``out_dim`` with no
    activation.

    Parameters
    ----------
    g
        Graph providing ``etypes`` and ``ntypes``.
    in_dim : int
        Input feature size.
    out_dim : int
        Output size (number of classes).
    """

    def __init__(self, g, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.h_dim = 64
        self.out_dim = out_dim
        # De-duplicated relation names in sorted order.
        self.rel_names = sorted(set(g.etypes))
        self.dropout = 0.5

        # input -> hidden
        input_to_hidden = RelGraphConvLayer(
            self.in_dim,
            self.h_dim,
            g.ntypes,
            self.rel_names,
            activation=F.relu,
            dropout=self.dropout,
        )
        # hidden -> output
        hidden_to_output = RelGraphConvLayer(
            self.h_dim,
            self.out_dim,
            g.ntypes,
            self.rel_names,
            activation=None,
        )
        self.layers = nn.ModuleList([input_to_hidden, hidden_to_output])

    def reset_parameters(self):
        """Re-initialise every layer in order."""
        for rgcn_layer in self.layers:
            rgcn_layer.reset_parameters()

    def forward(self, h, blocks):
        """Run the features ``h`` through each layer, pairing layer i with blocks[i]."""
        feats = h
        for rgcn_layer, blk in zip(self.layers, blocks):
            feats = rgcn_layer(blk, feats)
        return feats
class Logger:
    r"""Collect per-epoch (train, valid, test) results for several runs.

    Taken directly from the PyG implementation so that performance is measured
    in precisely the same way:
    https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/mag/logger.py
    """

    def __init__(self, runs):
        # One result list per run; each entry is a 3-item result.
        self.results = []
        for _ in range(runs):
            self.results.append([])

    def add_result(self, run, result):
        """Append a 3-item ``result`` to the list of run number ``run``."""
        assert len(result) == 3
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        """Print a summary for one run, or mean ± std over all runs if ``run`` is None.

        Values are scaled by 100 before printing. "Final" numbers are taken at
        the epoch with the highest validation value.
        """
        if run is None:
            scaled = 100 * th.tensor(self.results)

            per_run_best = []
            for run_scores in scaled:
                best_epoch = run_scores[:, 1].argmax()
                per_run_best.append(
                    (
                        run_scores[:, 0].max().item(),
                        run_scores[:, 1].max().item(),
                        run_scores[best_epoch, 0].item(),
                        run_scores[best_epoch, 2].item(),
                    )
                )
            summary = th.tensor(per_run_best)

            print("All runs:")
            highest_train = summary[:, 0]
            print(f"Highest Train: {highest_train.mean():.2f} ± {highest_train.std():.2f}")
            highest_valid = summary[:, 1]
            print(f"Highest Valid: {highest_valid.mean():.2f} ± {highest_valid.std():.2f}")
            final_train = summary[:, 2]
            print(f" Final Train: {final_train.mean():.2f} ± {final_train.std():.2f}")
            final_test = summary[:, 3]
            print(f" Final Test: {final_test.mean():.2f} ± {final_test.std():.2f}")
            return

        scaled = 100 * th.tensor(self.results[run])
        best_epoch = scaled[:, 1].argmax().item()
        print(f"Run {run + 1:02d}:")
        print(f"Highest Train: {scaled[:, 0].max():.2f}")
        print(f"Highest Valid: {scaled[:, 1].max():.2f}")
        print(f" Final Train: {scaled[best_epoch, 0]:.2f}")
        print(f" Final Test: {scaled[best_epoch, 2]:.2f}")
def train(
    g,
    model,
    node_embed,
    optimizer,
    train_loader,
    split_idx,
    labels,
    logger,
    device,
    run,
    num_epochs=3,
):
    """Train ``model`` and ``node_embed`` for one run and log per-epoch accuracy.

    Parameters
    ----------
    g
        Graph whose ``ndata["feat"]["paper"]`` holds the raw "paper" features.
    model : callable
        Called as ``model(features_by_ntype, blocks)``; the "paper" entry of
        its result is used as logits.
    node_embed : callable
        Embedding lookup for the non-"paper" node types (see ``extract_embed``).
    optimizer
        Optimizer stepped once per mini-batch.
    train_loader
        Iterable yielding ``(input_nodes, seeds, blocks)`` mini-batches.
    split_idx : dict
        ``split_idx["train"]["paper"]`` gives the training node ids; the whole
        dict is forwarded to ``test``.
    labels : torch.Tensor
        Label tensor indexed by seed node id.
    logger : Logger
        Receives the ``(train, valid, test)`` accuracy after each epoch.
    device
        Device on which blocks, features and labels are placed.
    run : int
        Zero-based run index used for logging.
    num_epochs : int, optional
        Number of epochs to train; defaults to 3, the previously fixed value.

    Returns
    -------
    Logger
        The same ``logger`` that was passed in.
    """
    print("start training...")
    category = "paper"
    # Invariant across epochs, so computed once.
    num_train = split_idx["train"][category].shape[0]

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        # Context manager guarantees the progress bar is closed even if a
        # batch raises.
        with tqdm(total=num_train) as pbar:
            pbar.set_description(f"Epoch {epoch:02d}")

            for input_nodes, seeds, blocks in train_loader:
                blocks = [blk.to(device) for blk in blocks]
                # Only nodes of type ``category`` are predicted.
                seeds = seeds[category]
                batch_size = seeds.shape[0]
                input_nodes_indexes = input_nodes[category].to(g.device)
                seeds = seeds.to(labels.device)

                emb = extract_embed(node_embed, input_nodes)
                # Add the batch's raw "paper" features next to the learned
                # embeddings of the other node types.
                emb[category] = g.ndata["feat"][category][input_nodes_indexes]
                emb = {k: e.to(device) for k, e in emb.items()}
                lbl = labels[seeds].to(device)

                optimizer.zero_grad()
                logits = model(emb, blocks)[category]
                y_hat = logits.log_softmax(dim=-1)
                loss = F.nll_loss(y_hat, lbl)
                loss.backward()
                optimizer.step()

                # Weight by batch size so the epoch loss is a per-node mean.
                total_loss += loss.item() * batch_size
                pbar.update(batch_size)

        loss = total_loss / num_train

        result = test(g, model, node_embed, labels, device, split_idx)
        logger.add_result(run, result)
        train_acc, valid_acc, test_acc = result
        print(
            f"Run: {run + 1:02d}, "
            f"Epoch: {epoch + 1:02d}, "
            f"Loss: {loss:.4f}, "
            f"Train: {100 * train_acc:.2f}%, "
            f"Valid: {100 * valid_acc:.2f}%, "
            f"Test: {100 * test_acc:.2f}%"
        )

    return logger
@th.no_grad()
def test(g, model, node_embed, y_true, device, split_idx):
    """Predict a class for every "paper" node and score each data split.

    Parameters
    ----------
    g
        Graph providing ``num_nodes("paper")`` and ``ndata["feat"]["paper"]``.
    model : callable
        Called as ``model(features_by_ntype, blocks)``; put in eval mode here.
    node_embed : callable
        Embedding lookup for the non-"paper" node types.
    y_true : torch.Tensor
        Label tensor; a trailing dimension is added before evaluation.
    device
        Device on which blocks and features are placed.
    split_idx : dict
        ``split_idx[name]["paper"]`` for ``name`` in train/valid/test.

    Returns
    -------
    tuple
        ``(train_acc, valid_acc, test_acc)`` as reported by the OGB evaluator.
    """
    model.eval()
    category = "paper"
    evaluator = Evaluator(name="ogbn-mag")

    # Full-neighbor sampling over 2 layers, matching the 2 GNN layers.
    full_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    every_paper = {"paper": th.arange(g.num_nodes("paper"))}
    loader = dgl.dataloading.DataLoader(
        g,
        every_paper,
        full_sampler,
        batch_size=16384,
        shuffle=False,
        num_workers=0,
        device=device,
    )

    progress = tqdm(total=y_true.size(0))
    progress.set_description("Inference")

    batch_preds = []
    for input_nodes, seeds, blocks in loader:
        blocks = [blk.to(device) for blk in blocks]
        # Only nodes of type ``category`` are predicted.
        n_seeds = seeds[category].shape[0]
        paper_ids = input_nodes["paper"].to(g.device)

        feats = extract_embed(node_embed, input_nodes)
        # Raw "paper" features join the learned embeddings of the other types.
        feats["paper"] = g.ndata["feat"]["paper"][paper_ids]
        feats = {k: e.to(device) for k, e in feats.items()}

        logits = model(feats, blocks)[category]
        predicted = logits.log_softmax(dim=-1).argmax(dim=1, keepdims=True)
        batch_preds.append(predicted.cpu())
        progress.update(n_seeds)
    progress.close()

    y_pred = th.cat(batch_preds, dim=0)
    y_true = th.unsqueeze(y_true, 1)

    accuracies = []
    for split_name in ("train", "valid", "test"):
        ids = split_idx[split_name]["paper"]
        scores = evaluator.eval({"y_true": y_true[ids], "y_pred": y_pred[ids]})
        accuracies.append(scores["acc"])
    return tuple(accuracies)
def is_support_affinity(v_t):
    """Return True when DGL version string ``v_t`` is at least 0.9.1.

    dgl supports enable_cpu_affinity since 0.9.1. The version is compared
    numerically on its first three dot-separated components; a plain string
    comparison would wrongly rank e.g. "0.10.0" below "0.9.1".

    Parameters
    ----------
    v_t : str
        Version string such as "0.9.1", "1.1.2+cu117" or "0.9.1post1".

    Returns
    -------
    bool
    """
    import re

    numbers = []
    for piece in str(v_t).split(".")[:3]:
        # Keep only the leading digits so suffixes like "post1" or "+cu117"
        # do not break parsing.
        leading = re.match(r"\d+", piece)
        if leading is None:
            break
        numbers.append(int(leading.group()))
    # Missing components count as zero ("0.9" is treated as 0.9.0).
    numbers.extend([0] * (3 - len(numbers)))
    return tuple(numbers) >= (0, 9, 1)
def main(args):
    """Run ``args.runs`` independent training runs and print their statistics.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``runs`` and ``num_workers``.
    """
    device = "cuda:0" if th.cuda.is_available() else "cpu"

    g, labels, num_classes, split_idx, logger, train_loader = prepare_data(
        args, device
    )

    embed_layer = rel_graph_embed(g, 128).to(device)
    model = EntityClassify(g, 128, num_classes).to(device)

    print(
        f"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}"
    )
    print(
        f"Number of model parameters: {sum(p.numel() for p in model.parameters())}"
    )

    for run in range(args.runs):
        try:
            embed_layer.reset_parameters()
            model.reset_parameters()
        except AttributeError as err:
            # old pytorch version doesn't support reset_parameters() API.
            # Only that case is tolerated, and it is reported rather than
            # silently ignored because the run then starts from the
            # previous run's weights.
            print(
                f"[WARNING] reset_parameters() is unavailable ({err}); "
                f"run {run + 1} continues from the current weights",
                file=sys.stderr,
            )

        # optimizer over both the model and the embedding parameters
        all_params = itertools.chain(
            model.parameters(), embed_layer.parameters()
        )
        optimizer = th.optim.Adam(all_params, lr=0.01)

        train_args = (
            g,
            model,
            embed_layer,
            optimizer,
            train_loader,
            split_idx,
            labels,
            logger,
            device,
            run,
        )

        # CPU affinity is only used for multi-worker CPU loading on a DGL
        # version that supports it.
        use_affinity = (
            args.num_workers != 0
            and device == "cpu"
            and is_support_affinity(v_t)
        )
        if use_affinity:
            # psutil returns None when the physical core count is unknown;
            # skip the check in that case instead of failing.
            expected_max = psutil.cpu_count(logical=False)
            if expected_max is not None and args.num_workers >= expected_max:
                print(
                    f"[ERROR] You specified num_workers are larger than physical cores, please set any number less than {expected_max}",
                    file=sys.stderr,
                )
            with train_loader.enable_cpu_affinity():
                logger = train(*train_args)
        else:
            logger = train(*train_args)

        logger.print_statistics(run)

    print("Final performance: ")
    logger.print_statistics()
if __name__ == "__main__":
    # Command-line entry point: both options are integers with fixed defaults.
    parser = argparse.ArgumentParser(description="RGCN")
    for flag, default in (("--runs", 10), ("--num_workers", 0)):
        parser.add_argument(flag, type=int, default=default)

    args = parser.parse_args()
    main(args)