forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_ns.py
117 lines (99 loc) · 3.55 KB
/
train_ns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Training and testing for node selection tasks in bAbI
"""
import argparse
import time
import numpy as np
import torch
from data_utils import get_babi_dataloaders
from ggnn_ns import NodeSelectionGGNN
from torch.optim import Adam
def _evaluate(model, dataloader):
    """Return the accuracy of ``model`` over every example in ``dataloader``.

    Each batch is a ``(graph, labels)`` pair. ``model(graph)`` is called
    without labels and its predictions are converted to a ``torch.long``
    tensor before comparison with ``labels``.

    Returns:
        The fraction of correct predictions as a float, or ``nan`` when the
        dataloader yields no examples (instead of dividing by zero).
    """
    all_preds = []
    all_labels = []
    model.eval()
    with torch.no_grad():
        for g, labels in dataloader:
            preds = model(g)
            # NOTE(review): preds may be a list of scalars/0-d tensors or a
            # tensor, depending on NodeSelectionGGNN -- torch.tensor copes with
            # either; confirm against ggnn_ns.
            all_preds += torch.tensor(preds, dtype=torch.long).tolist()
            # .tolist() works on any device, unlike .numpy().
            all_labels += labels.tolist()
    if not all_labels:
        return float("nan")
    # np.float was removed in NumPy 1.24. The mean of the boolean match
    # array is the accuracy.
    return float(np.mean(np.equal(all_labels, all_preds)))


def main(args):
    """Train a NodeSelectionGGNN on one bAbI node-selection task, then test it.

    Reports the per-batch training loss and the dev accuracy after every
    epoch. It then reports the mean and standard deviation of the accuracy
    across all test dataloaders.

    Args:
        args: Parsed CLI namespace with ``task_id``, ``question_id``,
            ``train_num``, ``batch_size``, ``lr`` and ``epochs``.

    Raises:
        ValueError: If ``args.task_id`` is not a supported task (4, 15, 16).
    """
    # Per-task output size and number of edge types; only these tasks are supported.
    out_feats = {4: 4, 15: 5, 16: 6}
    n_etypes = {4: 4, 15: 2, 16: 2}
    if args.task_id not in out_feats:
        raise ValueError(
            f"Unsupported task_id {args.task_id}; "
            f"expected one of {sorted(out_feats)}"
        )

    train_dataloader, dev_dataloader, test_dataloaders = get_babi_dataloaders(
        batch_size=args.batch_size,
        train_size=args.train_num,
        task_id=args.task_id,
        q_type=args.question_id,
    )

    model = NodeSelectionGGNN(
        annotation_size=1,
        out_feats=out_feats[args.task_id],
        n_steps=5,
        n_etypes=n_etypes[args.task_id],
    )
    opt = Adam(model.parameters(), lr=args.lr)

    print(f"Task {args.task_id}, question_id {args.question_id}")
    print(f"Training set size: {len(train_dataloader.dataset)}")
    print(f"Dev set size: {len(dev_dataloader.dataset)}")

    # training and dev stage
    for epoch in range(args.epochs):
        model.train()
        for i, (g, labels) in enumerate(train_dataloader):
            # With labels, the model returns (loss, predictions).
            loss, _ = model(g, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            print(f"Epoch {epoch}, batch {i} loss: {loss.item()}")

        dev_acc = _evaluate(model, dev_dataloader)
        print(f"Epoch {epoch}, Dev acc {dev_acc}")

    # test stage
    for i, dataloader in enumerate(test_dataloaders):
        print(f"Test set {i} size: {len(dataloader.dataset)}")

    test_acc_list = [_evaluate(model, dataloader) for dataloader in test_dataloaders]
    test_acc_mean = np.mean(test_acc_list)
    test_acc_std = np.std(test_acc_list)
    print(
        f"Mean of accuracy in {len(test_acc_list)} test datasets: "
        f"{test_acc_mean}, std: {test_acc_std}"
    )
def build_arg_parser():
    """Build the command-line parser for the bAbI node-selection GGNN script.

    ``--task_id`` is restricted to the tasks ``main`` has model settings for
    (4, 15, 16). Other ids are rejected by argparse instead of crashing later.
    """
    parser = argparse.ArgumentParser(
        description="Gated Graph Neural Networks for node selection tasks in bAbI"
    )
    parser.add_argument(
        "--task_id",
        type=int,
        default=16,
        choices=(4, 15, 16),
        help="bAbI task id (node-selection tasks: 4, 15 or 16)",
    )
    parser.add_argument(
        "--question_id", type=int, default=1, help="question id for each task"
    )
    parser.add_argument(
        "--train_num", type=int, default=50, help="Number of training examples"
    )
    parser.add_argument("--batch_size", type=int, default=10, help="batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--epochs", type=int, default=100, help="number of training epochs"
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)