-
Notifications
You must be signed in to change notification settings - Fork 0
/
cluster_gcn.py
95 lines (87 loc) · 3.29 KB
/
cluster_gcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
class SAGE(nn.Module):
    """Three-layer GraphSAGE network using 'mean' aggregation.

    Layer widths run ``in_feats -> n_hidden -> n_hidden -> n_classes``.
    Every layer except the last is followed by ReLU and dropout (p=0.5);
    the last layer's raw output is returned unchanged (the training loop
    feeds it straight into ``F.cross_entropy``).
    """

    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        # One SAGEConv per consecutive (input width, output width) pair.
        widths = [in_feats, n_hidden, n_hidden, n_classes]
        self.layers = nn.ModuleList(
            dglnn.SAGEConv(d_in, d_out, 'mean')
            for d_in, d_out in zip(widths[:-1], widths[1:])
        )
        self.dropout = nn.Dropout(0.5)

    def forward(self, sg, x):
        """Apply every layer to graph ``sg`` starting from node features ``x``.

        Returns the output of the final layer, with no activation applied.
        """
        final_idx = len(self.layers) - 1
        h = x
        for idx, conv in enumerate(self.layers):
            h = conv(sg, h)
            # Hidden layers only: the final layer's output is left untouched.
            if idx < final_idx:
                h = self.dropout(F.relu(h))
        return h
# Load ogbn-products as a node-prediction dataset. The single graph already
# carries ndata['label'], ['train_mask'], ['val_mask'] and ['test_mask'].
dataset = dgl.data.AsNodePredDataset(DglNodePropPredDataset('ogbn-products'))
graph = dataset[0]

model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to('cuda')
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# Cluster-GCN sampling over `num_partitions` partitions of the graph; the
# listed node data is fetched for every sampled subgraph.
num_partitions = 1000
sampler = dgl.dataloading.ClusterGCNSampler(
    graph,
    num_partitions,
    prefetch_ndata=['feat', 'label', 'train_mask', 'val_mask', 'test_mask'],
)

# The DataLoader iterates over arbitrary indices -- here partition IDs, not
# node IDs. Each batch of 100 IDs is turned by `sampler` into one subgraph.
# The ID tensor is created directly on the GPU, matching device='cuda'.
dataloader = dgl.dataloading.DataLoader(
    graph,
    torch.arange(num_partitions, device='cuda'),
    sampler,
    device='cuda',
    batch_size=100,
    shuffle=True,
    drop_last=False,
    num_workers=0,
    use_uva=True,
)
# Train for 10 epochs. Each epoch makes one training pass over all partition
# batches and records its wall-clock time. It then re-iterates the loader
# without gradients to report validation and test accuracy.
durations = []
for _ in range(10):
    # CUDA kernels run asynchronously: drain any pending work so the timer
    # covers only this epoch's training.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    model.train()
    for it, sg in enumerate(dataloader):
        x = sg.ndata['feat']
        y = sg.ndata['label']
        # Only the training nodes inside this subgraph contribute to the loss.
        m = sg.ndata['train_mask'].bool()
        y_hat = model(sg, x)
        loss = F.cross_entropy(y_hat[m], y[m])
        opt.zero_grad()
        loss.backward()
        opt.step()
        if it % 20 == 0:
            # torchmetrics >= 0.11 requires an explicit task / num_classes.
            acc = MF.accuracy(
                y_hat[m], y[m],
                task='multiclass', num_classes=dataset.num_classes,
            )
            mem = torch.cuda.max_memory_allocated() / 1000000
            print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
    # Wait for queued kernels to finish before stopping the clock.
    torch.cuda.synchronize()
    tt = time.perf_counter()
    print(tt - t0)
    durations.append(tt - t0)

    model.eval()
    with torch.no_grad():
        # Collect predictions/labels for val and test nodes across every
        # subgraph, then score them in one shot.
        val_preds, test_preds = [], []
        val_labels, test_labels = [], []
        for sg in dataloader:
            x = sg.ndata['feat']
            y = sg.ndata['label']
            m_val = sg.ndata['val_mask'].bool()
            m_test = sg.ndata['test_mask'].bool()
            y_hat = model(sg, x)
            val_preds.append(y_hat[m_val])
            val_labels.append(y[m_val])
            test_preds.append(y_hat[m_test])
            test_labels.append(y[m_test])
        val_preds = torch.cat(val_preds, 0)
        val_labels = torch.cat(val_labels, 0)
        test_preds = torch.cat(test_preds, 0)
        test_labels = torch.cat(test_labels, 0)
        val_acc = MF.accuracy(
            val_preds, val_labels,
            task='multiclass', num_classes=dataset.num_classes,
        )
        test_acc = MF.accuracy(
            test_preds, test_labels,
            task='multiclass', num_classes=dataset.num_classes,
        )
        print('Validation acc:', val_acc.item(), 'Test acc:', test_acc.item())

# Mean / std of per-epoch training time, excluding the first 4 epochs
# (presumably treated as warm-up).
print(np.mean(durations[4:]), np.std(durations[4:]))