-
Notifications
You must be signed in to change notification settings - Fork 467
/
Copy pathmain.py
73 lines (61 loc) · 2.74 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#coding: utf-8
"""
基于Cora的GraphSage示例
"""
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from net import GraphSage
from data import CoraData
from sampling import multihop_sampling
from collections import namedtuple
# ---------------------------------------------------------------------------
# Hyper-parameters and module-level state shared by train() and test().
# ---------------------------------------------------------------------------
INPUT_DIM = 1433 # input feature dimension
# Note: the number of sampled neighbor hops must match the number of GCN
# layers (enforced by the assert below).
HIDDEN_DIM = [128, 7] # number of hidden units of each layer
NUM_NEIGHBORS_LIST = [10, 10] # number of neighbors sampled at each hop
assert len(HIDDEN_DIM) == len(NUM_NEIGHBORS_LIST)
BTACH_SIZE = 16 # batch size (name is misspelled, but train() reads it as-is)
EPOCHS = 20
NUM_BATCH_PER_EPOCH = 20 # number of batches run in each epoch
LEARNING_RATE = 0.01 # learning rate
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): this namedtuple is not referenced anywhere else in the visible
# code; presumably it mirrors the fields of CoraData().data -- confirm.
Data = namedtuple('Data', ['x', 'y', 'adjacency_dict',
                           'train_mask', 'val_mask', 'test_mask'])
data = CoraData().data
x = data.x / data.x.sum(1, keepdims=True) # normalize the features so that every row sums to 1
# Indices of the nodes selected by the train / test masks.
train_index = np.where(data.train_mask)[0]
# Full label array; train() indexes it with the sampled source-node ids.
train_label = data.y
test_index = np.where(data.test_mask)[0]
model = GraphSage(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM,
                  num_neighbors_list=NUM_NEIGHBORS_LIST).to(DEVICE)
print(model)
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)
def train():
    """Train the module-level model on randomly sampled mini-batches.

    Runs ``EPOCHS`` epochs of ``NUM_BATCH_PER_EPOCH`` batches each. For every
    batch, ``BTACH_SIZE`` source nodes are drawn from ``train_index`` (with
    replacement, the ``np.random.choice`` default), their multi-hop
    neighborhoods are sampled, and one optimizer step is taken on the
    cross-entropy loss of the source nodes. The loss of every batch is
    printed, and ``test()`` is called at the end of each epoch.
    """
    for e in range(EPOCHS):
        # test() puts the model into eval mode at the end of every epoch, so
        # training mode must be re-enabled per epoch rather than once up
        # front; otherwise every epoch after the first would run in eval mode.
        model.train()
        for batch in range(NUM_BATCH_PER_EPOCH):
            batch_src_index = np.random.choice(train_index, size=(BTACH_SIZE,))
            batch_src_label = torch.from_numpy(
                train_label[batch_src_index]).long().to(DEVICE)
            # One index array per hop, starting with the source nodes.
            batch_sampling_result = multihop_sampling(
                batch_src_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
            batch_sampling_x = [
                torch.from_numpy(x[idx]).float().to(DEVICE)
                for idx in batch_sampling_result
            ]
            batch_train_logits = model(batch_sampling_x)
            loss = criterion(batch_train_logits, batch_src_label)
            optimizer.zero_grad()
            loss.backward()  # back-propagate to compute parameter gradients
            optimizer.step()  # apply the gradient update
            print("Epoch {:03d} Batch {:03d} Loss: {:.4f}".format(e, batch, loss.item()))
        test()
def test():
    """Evaluate the module-level model on the test nodes and print the accuracy.

    Switches the model to eval mode, samples the multi-hop neighborhoods of
    every node in ``test_index``, runs a gradient-free forward pass and prints
    the fraction of test nodes whose predicted class equals the label.
    """
    model.eval()
    with torch.no_grad():
        hop_indices = multihop_sampling(test_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
        # Gather the feature rows for each hop and move them to the device.
        hop_features = []
        for node_ids in hop_indices:
            hop_features.append(torch.from_numpy(x[node_ids]).float().to(DEVICE))
        logits = model(hop_features)
        targets = torch.from_numpy(data.y[test_index]).long().to(DEVICE)
        # max over dim 1 returns (values, indices); the indices are the classes.
        _, predicted = logits.max(1)
        hits = torch.eq(predicted, targets).float()
        accuarcy = hits.mean().item()
        print("Test Accuracy: ", accuarcy)
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    train()