forked from LuckyTiger123/DropMessage
-
Notifications
You must be signed in to change notification settings - Fork 7
/
train.py
117 lines (97 loc) · 4.97 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import sys
import torch
import argparse
from torch import Tensor
from torch_geometric.datasets import Planetoid
from torch_geometric.typing import Adj
import torch_geometric.transforms as T
import torch.nn.functional as F
sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
from src.layer import GNNLayer
# parse parameters
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--dataset', type=str, default='Cora',
help='The dataset to be trained on [Cora, Citeseer, PubMed].')
parser.add_argument('-c', '--cuda-device', type=int, default=0, help='which gpu device to use.')
parser.add_argument('-dr', '--dropping-rate', type=float, default=0, help='The chosen dropping rate (default: 0).')
parser.add_argument('-e', '--epoch', type=int, default=500, help='The epoch number (default: 500).')
parser.add_argument('-bb', '--backbone', type=str, default='GCN', help='The backbone model [GCN, GAT, APPNP].')
parser.add_argument('-dm', '--dropping-method', type=str, default='DropMessage',
help='The chosen dropping method [Dropout, DropEdge, DropNode, DropMessage].')
parser.add_argument('-hs', '--heads', type=int, default=1, help='The head number for GAT (default: 1).')
parser.add_argument('-k', '--K', type=int, default=10, help='The K value for APPNP (default: 10).')
parser.add_argument('-a', '--alpha', type=float, default=0.1, help='The alpha value for APPNP (default: 0.1).')
parser.add_argument('-fyd', '--first-layer-dimension', type=int, default=16,
help='The hidden dimension number for two-layer GNNs (default: 16).')
parser.add_argument('-r', '--rand-seed', type=int, default=0, help='The random seed (default: 0).')
args = parser.parse_args()
# random seed setting
random_seed = args.rand_seed
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# device selection
device = torch.device('cuda:{}'.format(args.cuda_device) if torch.cuda.is_available() else 'cpu')
# load dataset
dataset = Planetoid(root="./data", name=args.dataset, transform=T.NormalizeFeatures())
data = dataset[0].to(device=device)
# Model
class Model(torch.nn.Module):
def __init__(self, feature_num, output_num, backbone, dropping_method):
super(Model, self).__init__()
self.backbone = backbone
self.gnn1 = GNNLayer(feature_num, args.first_layer_dimension, dropping_method, backbone, heads=args.heads,
alpha=args.alpha, K=args.K)
self.gnn2 = GNNLayer(args.first_layer_dimension * args.heads, output_num, dropping_method, backbone,
alpha=args.alpha, K=args.K)
def forward(self, x: Tensor, edge_index: Adj, drop_rate: float = 0):
x = self.gnn1(x, edge_index, drop_rate)
if self.backbone == 'GAT':
x = F.elu(x)
else:
x = F.relu(x)
x = self.gnn2(x, edge_index, drop_rate)
return x
def reset_parameters(self):
self.gnn1.reset_parameters()
self.gnn2.reset_parameters()
model = Model(dataset.num_features, dataset.num_classes, args.backbone, args.dropping_method).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005)
epoch_num = args.epoch
def train():
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index, args.dropping_rate)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
print('the train loss is {}'.format(float(loss)))
optimizer.step()
@torch.no_grad()
def test():
model.eval()
out = model(data.x, data.edge_index)
_, pred = out.max(dim=1)
train_correct = int(pred[data.train_mask].eq(data.y[data.train_mask]).sum().item())
train_acc = train_correct / int(data.train_mask.sum())
validate_correct = int(pred[data.val_mask].eq(data.y[data.val_mask]).sum().item())
validate_acc = validate_correct / int(data.val_mask.sum())
test_correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
test_acc = test_correct / int(data.test_mask.sum())
return train_acc, validate_acc, test_acc
best_val_acc = test_acc = 0
for epoch in range(epoch_num):
train()
train_acc, val_acc, current_test_acc = test()
print('For the {} epoch, the train acc is {}, the val acc is {}, the test acc is {}.'.format(epoch, train_acc,
val_acc,
current_test_acc))
if val_acc > best_val_acc:
best_val_acc = val_acc
test_acc = current_test_acc
print('Mission completes.')
print('--------------------------------------------------------------------------')
print('Dataset: {}.'.format(args.dataset))
print('Backbone model: {}. Dropping method: {}.'.format(args.backbone, args.dropping_method))
print('The final test acc is {}.'.format(test_acc))