-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathtrain.py
117 lines (89 loc) · 3.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import numpy as np
import scipy.sparse as sp
import torch
from torch.autograd import Variable
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from utils import load_data, dotdict, eval_gae, make_sparse, plot_results
from models import GAE
from preprocessing import mask_test_edges, preprocess_graph
def main(args):
    """Train a Graph Auto-Encoder (GAE) with pyro SVI and report link-prediction metrics.

    Args:
        args: dotdict-style namespace with fields
            dataset_str (str)  : dataset name passed to load_data,
            dropout (float)    : dropout rate for the GAE,
            num_epochs (int)   : number of full-batch training epochs,
            test_freq (int)    : evaluate on the test split every this many epochs,
            lr (float)         : Adam learning rate.

    Side effects: prints per-epoch diagnostics and final test scores, and writes
    a results plot to '<dataset_str>_results.png'.
    """
    print("Using {} dataset".format(args.dataset_str))

    # Load data.
    # NOTE(review): this fixed seed pins the train/val/test edge split regardless
    # of args.seed set by the caller — presumably intentional (reproducible
    # splits); confirm before changing.
    np.random.seed(1)
    adj, features = load_data(args.dataset_str)
    N, D = features.shape

    # Keep the original adjacency matrix for evaluation against held-out edges.
    adj_orig = adj

    adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)

    # Preprocessing: symmetrically normalized adjacency for the encoder, and
    # the adjacency-plus-self-loops dense matrix as reconstruction targets.
    adj_train_norm = preprocess_graph(adj_train)
    adj_train_norm = Variable(make_sparse(adj_train_norm))
    adj_train_labels = Variable(torch.FloatTensor(adj_train + sp.eye(adj_train.shape[0]).todense()))
    features = Variable(make_sparse(features))

    n_edges = adj_train_labels.sum()

    data = {
        'adj_norm' : adj_train_norm,
        'adj_labels': adj_train_labels,
        'features' : features,
    }

    gae = GAE(data,
              n_hidden=32,
              n_latent=16,
              dropout=args.dropout)

    optimizer = Adam({"lr": args.lr, "betas": (0.95, 0.999)})

    svi = SVI(gae.model, gae.guide, optimizer, loss=Trace_ELBO())

    # Metric history, keyed by metric name.
    results = defaultdict(list)

    # Full-batch training loop: one SVI (ELBO gradient) step per epoch.
    for epoch in range(args.num_epochs):
        epoch_loss = svi.step()

        # Normalize the ELBO by the number of (directed) node pairs.
        normalized_loss = epoch_loss / (2 * N * N)
        results['train_elbo'].append(normalized_loss)

        # Validation metrics on held-out edges.
        emb = gae.get_embeddings()
        accuracy, roc_curr, ap_curr = eval_gae(val_edges, val_edges_false, emb, adj_orig)
        results['accuracy_train'].append(accuracy)
        results['roc_train'].append(roc_curr)
        results['ap_train'].append(ap_curr)

        print("Epoch:", '%04d' % (epoch + 1),
              "train_loss=", "{:.5f}".format(normalized_loss),
              "train_acc=", "{:.5f}".format(accuracy), "val_roc=", "{:.5f}".format(roc_curr), "val_ap=", "{:.5f}".format(ap_curr))

        # Periodic test-set evaluation.
        if epoch % args.test_freq == 0:
            emb = gae.get_embeddings()
            accuracy, roc_score, ap_score = eval_gae(test_edges, test_edges_false, emb, adj_orig)
            results['accuracy_test'].append(accuracy)
            # BUG FIX: previously appended roc_curr/ap_curr (the validation
            # metrics) here, so the recorded "test" curves were validation data.
            results['roc_test'].append(roc_score)
            results['ap_test'].append(ap_score)

    print("Optimization Finished!")

    # Final test-set evaluation.
    emb = gae.get_embeddings()
    accuracy, roc_score, ap_score = eval_gae(test_edges, test_edges_false, emb, adj_orig)
    print('Test Accuracy: ' + str(accuracy))
    print('Test ROC score: ' + str(roc_score))
    print('Test AP score: ' + str(ap_score))

    # Plot
    plot_results(results, args.test_freq, path= args.dataset_str + "_results.png")
if __name__ == '__main__':
    # Experiment configuration.
    config = {
        'seed': 2,
        'dropout': 0.0,
        'num_epochs': 50,
        # 'dataset_str': 'cora',
        'dataset_str': 'citeseer',
        'test_freq': 10,
        'lr': 0.01,
    }
    args = dotdict()
    for key, value in config.items():
        setattr(args, key, value)

    # Fresh pyro parameter store and reproducible RNG state before training.
    pyro.clear_param_store()
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    main(args)