train_full_cora.py
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 14 16:49:04 2020

@author: Ming Jin

Full-graph training (Algorithm 1) on the Cora dataset.
For a large graph like Reddit, full-graph training does not scale, which is why we need sampling instead.

** For simplicity, this script has not been adapted to CUDA **

Build version:
+ PyTorch 1.1.0
+ DGL 0.4.3.post2
"""
import torch
import torch.nn as nn
import networkx as nx
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import time
import numpy as np
from SageConv import SAGEConv


class GraphSAGE(nn.Module):
    '''
    Full-graph training SAGE network
    (This version is much simpler than the sampling one)
    '''
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 aggregator_type,
                 dropout=0.5):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(p=dropout)
        # first layer: in_feats -> n_hidden
        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type, activation))
        # hidden layers: n_hidden -> n_hidden
        for i in range(1, n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type, activation))
        # last layer: n_hidden -> n_classes (no activation)
        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type, None))

    def forward(self, g, features):
        # Similar to GCN: 'g' is the entire graph and 'features' are the node features
        # Returns h with shape [num_nodes, n_classes]
        h = features
        for l, layer in enumerate(self.layers):
            h = layer(g, h)
            # no dropout after the last layer (its activation is already None)
            if l != len(self.layers) - 1:
                h = self.dropout(h)
        return h
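
# Note on layer counting: with n_layers = 2, the hidden-layer loop above is
# empty (range(1, 1)), so the model consists of exactly two SAGEConv layers:
# in_feats -> n_hidden and n_hidden -> n_classes.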

# create a network instance to train/evaluate on Cora
model = GraphSAGE(1433,        # in_feats: Cora feature dimension
                  16,          # n_hidden
                  7,           # n_classes: Cora has 7 classes
                  2,           # n_layers (>= 2)
                  nn.ReLU(),   # activation
                  "mean")      # aggregator_type

def load_cora_data():
    '''
    Cora dataset loading function
    '''
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    # use BoolTensor for all three masks so they index tensors consistently
    mask = torch.BoolTensor(data.train_mask)
    val_mask = torch.BoolTensor(data.val_mask)
    test_mask = torch.BoolTensor(data.test_mask)
    # graph preprocessing: remove self-loops, then build the DGLGraph
    g = data.graph
    g.remove_edges_from(nx.selfloop_edges(g))
    g = DGLGraph(g)
    # return graph, node features, labels, and train/val/test masks
    return g, features, labels, mask, val_mask, test_mask

### train a 2-layer GraphSAGE on the Cora dataset
g, features, labels, mask, val_mask, test_mask = load_cora_data()
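
# Optional sanity check (illustrative): Cora has 2708 nodes with 1433-dim
# features, so an (untrained) forward pass should produce logits of shape [2708, 7]
with torch.no_grad():
    print("features:", features.shape, "| logits:", model(g, features).shape)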
optimizer = torch.optim.Adam(model.parameters(), lr=0.003, weight_decay=5e-4)
loss_fcn = nn.CrossEntropyLoss()

def evaluate(model, graph, features, labels, mask):
    model.eval()
    with torch.no_grad():
        # run prediction on all nodes
        logits = model(graph, features)
        # but only evaluate on the val or test nodes selected by 'mask'
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)  # predicted class index
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
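
# Example usage (hypothetical numbers): evaluate(model, g, features, labels, val_mask)
# returns a plain Python float in [0, 1], e.g. ~0.78 for 78% validation accuracy.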

dur = []
for epoch in range(200):
    model.train()
    if epoch >= 3:
        t0 = time.time()
    optimizer.zero_grad()
    # feed the model with all nodes and features;
    # logits has shape [num_nodes, n_classes]
    logits = model(g, features)
    # but only train on the training nodes
    # Note: the loss is not computed on val/test nodes
    loss = loss_fcn(logits[mask], labels[mask])
    loss.backward()
    optimizer.step()
    if epoch >= 3:
        dur.append(time.time() - t0)
    # evaluate on the validation set every epoch
    acc = evaluate(model, g, features, labels, val_mask)
    print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), acc, np.mean(dur) if dur else 0.0))

# final evaluation on the held-out test set
test_acc = evaluate(model, g, features, labels, test_mask)
print("\nTest accuracy {:.2%}".format(test_acc))